From b45f2c4f11f573a144eac36cf39829519a287513 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 29 Aug 2024 18:39:30 +0800 Subject: [PATCH 01/33] [init] amsgrad --- paddle/phi/infermeta/multiary.cc | 8 +++++ paddle/phi/infermeta/multiary.h | 3 ++ paddle/phi/infermeta/spmd_rules/optimizer.cc | 10 +++++- paddle/phi/infermeta/spmd_rules/optimizer.h | 4 ++- paddle/phi/kernels/adam_kernel.h | 6 ++++ paddle/phi/kernels/cpu/adam_kernel.cc | 24 ++++++++++++-- paddle/phi/kernels/cpu/adamw_kernel.cc | 3 ++ paddle/phi/kernels/cpu/fused_adam_kernel.cc | 3 ++ paddle/phi/kernels/funcs/adam_functors.h | 13 ++++++-- paddle/phi/kernels/funcs/jit/kernel_base.h | 17 ++++++++-- paddle/phi/kernels/funcs/jit/refer/refer.h | 24 ++++++++++++-- paddle/phi/kernels/gpu/adam_kernel.cu | 6 ++++ paddle/phi/ops/yaml/op_compat.yaml | 4 +-- paddle/phi/ops/yaml/ops.yaml | 10 +++--- python/paddle/optimizer/adam.py | 34 +++++++++++++++++++- 15 files changed, 150 insertions(+), 19 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 9ec9538b5aabfb..281aea978191c4 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -152,6 +152,7 @@ void AdamInferMeta(const MetaTensor& param, const MetaTensor& learning_rate, const MetaTensor& moment1, const MetaTensor& moment2, + const MetaTensor& moment2_max, const MetaTensor& beta1_pow, const MetaTensor& beta2_pow, const MetaTensor& master_param, @@ -163,9 +164,11 @@ void AdamInferMeta(const MetaTensor& param, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, MetaTensor* param_out, MetaTensor* moment1_out, MetaTensor* moment2_out, + MetaTensor* moment2_max_out, MetaTensor* beta1_pow_out, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs) { @@ -232,6 +235,8 @@ void AdamInferMeta(const MetaTensor& param, moment1_out->set_dtype(moment1.dtype()); moment2_out->set_dims(param_dims); moment2_out->set_dtype(moment2.dtype()); + moment2_max_out->set_dims(param_dims); + moment2_max_out->set_dtype(moment2_max.dtype()); beta1_pow_out->set_dims(beta1_pow_dims); beta1_pow_out->set_dtype(beta1_pow.dtype()); @@ -353,6 +358,7 @@ void AdamwInferMeta(const MetaTensor& param, learning_rate, moment1, moment2, + moment2, // TODO(megemini) beta1_pow, beta2_pow, master_param, @@ -364,9 +370,11 @@ void AdamwInferMeta(const MetaTensor& param, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow, + false, // TODO(megemini) param_out, moment1_out, moment2_out, + moment2_out, // TODO(megemini) beta1_pow_out, beta2_pow_out, master_param_outs); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 045dd747791665..ddd54c113cff80 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -86,6 +86,7 @@ void AdamInferMeta(const MetaTensor& param, const MetaTensor& learning_rate, const MetaTensor& moment1, const MetaTensor& moment2, + const MetaTensor& moment2_max, const MetaTensor& beta1_pow, const MetaTensor& beta2_pow, const MetaTensor& master_param, @@ -97,9 +98,11 @@ void AdamInferMeta(const MetaTensor& param, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, MetaTensor* param_out, MetaTensor* moment1_out, MetaTensor* moment2_out, + MetaTensor* moment2_max_out, MetaTensor* beta1_pow_out, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs); diff --git a/paddle/phi/infermeta/spmd_rules/optimizer.cc b/paddle/phi/infermeta/spmd_rules/optimizer.cc 
index 72cd06fda3fbbc..7b1e711a4b97b8 100644 --- a/paddle/phi/infermeta/spmd_rules/optimizer.cc +++ b/paddle/phi/infermeta/spmd_rules/optimizer.cc @@ -30,6 +30,7 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, const DistMetaTensor& learning_rate, const DistMetaTensor& moment1, const DistMetaTensor& moment2, + const DistMetaTensor& moment2_max, const DistMetaTensor& beta1_pow, const DistMetaTensor& beta2_pow, const DistMetaTensor& master_param, @@ -40,7 +41,8 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, - bool use_global_beta_pow) { + bool use_global_beta_pow, + bool amsgrad) { // shape check PADDLE_ENFORCE( param.dims().size() == grad.dims().size() && @@ -79,6 +81,8 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, CopyTensorDistAttrForOutput(moment1.dist_attr()); TensorDistAttr moment2_dist_attr = CopyTensorDistAttrForOutput(moment2.dist_attr()); + TensorDistAttr moment2_max_dist_attr = + CopyTensorDistAttrForOutput(moment2_max.dist_attr()); TensorDistAttr beta1_pow_dist_attr = CopyTensorDistAttrForOutput(beta1_pow.dist_attr()); TensorDistAttr beta2_pow_dist_attr = @@ -115,6 +119,7 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, auto grad_spmd_dims_mapping = grad_dist_attr_spmd.dims_mapping(); auto momentum1_src_dims_mapping = moment1.dist_attr().dims_mapping(); auto momentum2_src_dims_mapping = moment2.dist_attr().dims_mapping(); + auto momentum2_max_src_dims_mapping = moment2_max.dist_attr().dims_mapping(); // Get the final dist attr for param, master_param, grad and momentum. // Whatever the input dist attrs are, the output dist attr should be same. @@ -172,12 +177,14 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, } moment1_dist_attr.set_dims_mapping(dst_dims_mapping); moment2_dist_attr.set_dims_mapping(dst_dims_mapping); + moment2_max_dist_attr.set_dims_mapping(dst_dims_mapping); return {{param_dist_attr, grad_dist_attr, lr_dist_attr, moment1_dist_attr, moment2_dist_attr, + moment2_max_dist_attr, beta1_pow_dist_attr, beta2_pow_dist_attr, master_param_dist_attr, @@ -185,6 +192,7 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, {param_dist_attr, moment1_dist_attr, moment2_dist_attr, + moment2_max_dist_attr, beta1_pow_dist_attr, beta2_pow_dist_attr, master_param_dist_attr}}; diff --git a/paddle/phi/infermeta/spmd_rules/optimizer.h b/paddle/phi/infermeta/spmd_rules/optimizer.h index c45ddcd0c97e11..bc033f5310216a 100644 --- a/paddle/phi/infermeta/spmd_rules/optimizer.h +++ b/paddle/phi/infermeta/spmd_rules/optimizer.h @@ -28,6 +28,7 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, const DistMetaTensor& learning_rate, const DistMetaTensor& moment1, const DistMetaTensor& moment2, + const DistMetaTensor& moment2_max, const DistMetaTensor& beta1_pow, const DistMetaTensor& beta2_pow, const DistMetaTensor& master_param, @@ -38,7 +39,8 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, - bool use_global_beta_pow); + bool use_global_beta_pow, + bool amsgrad); SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, const DistMetaTensor& grad, diff --git a/paddle/phi/kernels/adam_kernel.h b/paddle/phi/kernels/adam_kernel.h index b1a7f5a686530c..a7d1033e00f854 100644 --- a/paddle/phi/kernels/adam_kernel.h +++ b/paddle/phi/kernels/adam_kernel.h @@ -26,6 +26,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& 
learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -37,9 +38,11 @@ void AdamDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs); @@ -52,6 +55,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, + const std::vector& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -60,9 +64,11 @@ void MergedAdamKernel( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector param_out, std::vector moment1_out, std::vector moment2_out, + std::vector moment2_max_out, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out); diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc index 1a63b779b02a19..0d30dc28a8220f 100644 --- a/paddle/phi/kernels/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/cpu/adam_kernel.cc @@ -35,6 +35,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -46,9 +47,11 @@ void AdamDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -72,6 +75,7 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -112,6 +116,7 @@ void AdamDenseKernel(const Context& dev_ctx, T* param_out_ptr = dev_ctx.template Alloc(param_out); T* mom1_out_ptr = dev_ctx.template Alloc(moment1_out); T* mom2_out_ptr = dev_ctx.template Alloc(moment2_out); + T* mom2_max_out_ptr = dev_ctx.template Alloc(moment2_max_out); T learning_rate_ = learning_rate.data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); @@ -123,6 +128,7 @@ void AdamDenseKernel(const Context& dev_ctx, const T* param_ptr = param.data(); const T* mom1_ptr = moment1.data(); const T* mom2_ptr = moment2.data(); + const T* mom2_max_ptr = moment2_max.data(); const T* grad_ptr = grad.data(); auto adam = @@ -144,10 +150,13 @@ void AdamDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, + mom2_max_ptr + offset, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - param_out_ptr + offset); + mom2_max_out_ptr + offset, + param_out_ptr + offset, + amsgrad); } if (numel % chunk_size != 0) { @@ -161,10 +170,13 @@ void 
AdamDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, + mom2_max_ptr + offset, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - param_out_ptr + offset); + mom2_max_out_ptr + offset, + param_out_ptr + offset, + amsgrad); } } @@ -176,6 +188,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, + const std::vector& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -184,9 +197,11 @@ void MergedAdamKernel( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector param_out, std::vector moment1_out, std::vector moment2_out, + std::vector moment2_max_out, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) { @@ -255,10 +270,13 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), + moment2_max[idx]->data(), + dev_ctx.template Alloc(moment2_max_out[idx]), learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), - dev_ctx.template Alloc(param_out[idx])); + dev_ctx.template Alloc(param_out[idx]), + amsgrad); functor(param[idx]->numel()); if (!use_global_beta_pow) { dev_ctx.template Alloc(beta1_pow_out[idx])[0] = diff --git a/paddle/phi/kernels/cpu/adamw_kernel.cc b/paddle/phi/kernels/cpu/adamw_kernel.cc index f8b8ea67e23bb6..ffef5d50c4e641 100644 --- a/paddle/phi/kernels/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/cpu/adamw_kernel.cc @@ -75,6 +75,7 @@ void AdamwDenseKernel(const Context& dev_ctx, learning_rate, moment1, moment2, + moment2, // TODO(megemini) beta1_pow, beta2_pow, master_param, @@ -86,9 +87,11 @@ void AdamwDenseKernel(const Context& dev_ctx, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow, + false, // TODO(megemini) param_out, moment1_out, moment2_out, + moment2_out, // TODO(megemini) beta1_pow_out, beta2_pow_out, master_param_outs); diff --git a/paddle/phi/kernels/cpu/fused_adam_kernel.cc b/paddle/phi/kernels/cpu/fused_adam_kernel.cc index c6434be8077d9a..df935c57da7361 100644 --- a/paddle/phi/kernels/cpu/fused_adam_kernel.cc +++ b/paddle/phi/kernels/cpu/fused_adam_kernel.cc @@ -106,6 +106,7 @@ void FusedAdamKernel( learning_rate, *moments1[idx], *moments2[idx], + *moments2[idx], // TODO(megemini) *beta1_pows[idx], *beta2_pows[idx], master_params_tmp, @@ -117,9 +118,11 @@ void FusedAdamKernel( 1000, multi_precision, use_global_beta_pow, + false, // TODO(megemini) params_out[idx], moments1_out[idx], moments2_out[idx], + moments2_out[idx], // TODO(megemini) beta1_pows_out[idx], beta2_pows_out[idx], master_params_out.empty() ? 
nullptr : master_params_out[idx]); diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index 936b1d518fa95f..5ad0a64fcd766d 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -244,10 +244,13 @@ class AdamFunctor { T* moment1_out_; const T* moment2_; T* moment2_out_; + const T* moment2_max_; + T* moment2_max_out_; const T* lr_; const T* grad_; const T* param_; T* param_out_; + bool amsgrad_; public: AdamFunctor(T beta1, @@ -259,10 +262,13 @@ class AdamFunctor { T* mom1_out, const T* mom2, T* mom2_out, + const T* mom2_max, + T* mom2_max_out, const T* lr, const T* grad, const T* param, - T* param_out) + T* param_out, + bool amsgrad) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -272,10 +278,13 @@ class AdamFunctor { moment1_out_(mom1_out), moment2_(mom2), moment2_out_(mom2_out), + moment2_max_(mom2_max), + moment2_max_out_(mom2_max_out), lr_(lr), grad_(grad), param_(param), - param_out_(param_out) {} + param_out_(param_out), + amsgrad_(amsgrad) {} void operator()(size_t numel) const { Eigen::Map> g{ diff --git a/paddle/phi/kernels/funcs/jit/kernel_base.h b/paddle/phi/kernels/funcs/jit/kernel_base.h index e08f7821793c02..b5467e611f2494 100644 --- a/paddle/phi/kernels/funcs/jit/kernel_base.h +++ b/paddle/phi/kernels/funcs/jit/kernel_base.h @@ -275,8 +275,21 @@ struct AdamTuple { static constexpr KernelType kernel_type = kAdam; typedef T data_type; typedef adam_attr_t attr_type; - typedef void (*func_type)( - T, T, T, T, int64_t, const T*, const T*, const T*, const T*, T*, T*, T*); + typedef void (*func_type)(T, + T, + T, + T, + int64_t, + const T*, + const T*, + const T*, + const T*, + const T*, + T*, + T*, + T*, + T*, + bool); }; template diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.h b/paddle/phi/kernels/funcs/jit/refer/refer.h index 926e07751232ea..348f7a30078b97 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.h +++ b/paddle/phi/kernels/funcs/jit/refer/refer.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -523,16 +524,35 @@ void Adam(T beta1, const T* grad_ptr, const T* mom1_ptr, const T* mom2_ptr, + const T* mom2_max_ptr, const T* param_ptr, T* mom1_out_ptr, T* mom2_out_ptr, - T* param_out_ptr) { + T* mom2_max_out_ptr, + T* param_out_ptr, + bool amsgrad) { for (int i = 0; i < numel; ++i) { mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i]; mom2_out_ptr[i] = beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; + + // T mom2 = mom2_out_ptr[i]; + T mom2 = std::max(mom2_out_ptr[i], mom2_max_out_ptr[i]); + + // if (tmp > mom2_out_ptr[i]) { + // std::cout << tmp << " | " << mom2_ptr[i] << " | " << mom2_out_ptr[i] << + // " | " << mom2 << std::endl; std::cout << "---------- old " << + // std::endl; + + // } + // else { + // std::cout << "---------- new " << std::endl; + + // } + mom2_max_out_ptr[i] = mom2; + param_out_ptr[i] = - param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps)); + param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2) + eps)); } } diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 56be43fecb0d17..9f406d4615b8e5 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -134,6 +134,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, 
const paddle::optional& master_param, @@ -145,9 +146,11 @@ void AdamDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -318,6 +321,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, + const std::vector& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -326,9 +330,11 @@ void MergedAdamKernel( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector param_out, std::vector moment1_out, std::vector moment2_out, + std::vector moment2_max_out, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) { diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index e634caf1ad4268..c25873b63f1d4e 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -58,9 +58,9 @@ - op : adam_ (adam) inputs : - {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam, skip_update: SkipUpdate} + {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, moment2_max: Moment2Max, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam, skip_update: SkipUpdate} outputs : - {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} + {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, moment2_max_out: Moment2MaxOut, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} scalar : beta1 : data_type : float diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 49e93ed77d3c9b..ed5d776e3edf59 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -88,17 +88,17 @@ traits : pir::SideEffectTrait - op : adam_ - args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1 = 0.9f, Scalar beta2 = 0.999f, Scalar epsilon = 1.0e-8f, bool lazy_mode = false, int64_t min_row_size_to_use_multithread = 1000, bool multi_precision = false, bool use_global_beta_pow = false) - output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_out) + args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor moment2_max, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1 = 0.9f, Scalar beta2 = 0.999f, Scalar epsilon = 1.0e-8f, bool lazy_mode = false, int64_t min_row_size_to_use_multithread = 1000, bool multi_precision = false, bool use_global_beta_pow = false, bool amsgrad = false) + output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(moment2_max_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_out) infer_meta : func : AdamInferMeta spmd_rule : AdamInferSpmdDynamic kernel : - func : adam {dense, dense, dense, dense, dense, dense, dense, 
dense, dense -> dense, dense, dense, dense, dense, dense}, - adam_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense} + func : adam {dense, dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense, dense}, + adam_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense, dense} data_type : param optional : master_param, skip_update, master_param_out - inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) + inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (moment2_max -> moment2_max_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) traits : pir::SideEffectTrait - op : adamax_ diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 01db526dba9deb..87e28db080821f 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -191,6 +191,7 @@ class Adam(Optimizer): type: str _moment1_acc_str = "moment1" _moment2_acc_str = "moment2" + _moment2_acc_max_str = "moment2_max" _beta1_pow_acc_str = "beta1_pow_acc" _beta2_pow_acc_str = "beta2_pow_acc" @@ -208,6 +209,7 @@ def __init__( lazy_mode: bool = False, multi_precision: bool = False, use_multi_tensor: bool = False, + amsgrad: bool = False, name: str | None = None, ) -> None: assert learning_rate is not None @@ -255,11 +257,14 @@ def __init__( self._param_dict = self._create_multi_tensor_dict() self._moment1_dict = self._create_multi_tensor_dict() self._moment2_dict = self._create_multi_tensor_dict() + self._moment2_max_dict = self._create_multi_tensor_dict() self._beta1_pow_acc_dict = self._create_multi_tensor_dict() self._beta2_pow_acc_dict = self._create_multi_tensor_dict() self._master_weight_dict = self._create_multi_tensor_dict() self._master_weight_dict['FP32_LODTensor'] = None + self._amsgrad = amsgrad + def _add_moments_pows(self, p): acc_dtype = p.dtype if self._is_dtype_fp16_or_bf16(acc_dtype): @@ -269,6 +274,7 @@ def _add_moments_pows(self, p): acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) + self._add_accumulator(self._moment2_acc_max_str, p, dtype=acc_dtype) self._add_accumulator( name=self._beta1_pow_acc_str, param=p, @@ -332,6 +338,9 @@ def _append_optimize_op(self, block, param_and_grad): moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) + moment2_max = self._get_accumulator_master( + self._moment2_acc_max_str, param_and_grad[0] + ) beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) @@ -364,12 +373,13 @@ def _append_optimize_op(self, block, param_and_grad): self._get_auxiliary_var('found_inf') if in_pir_mode() else None ) - _, _, _, _, _, _ = _C_ops.adam_( + _ = _C_ops.adam_( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, + moment2_max, beta1_pow_acc, beta2_pow_acc, master_weight, @@ -381,6 +391,7 @@ def _append_optimize_op(self, block, param_and_grad): 1000, find_master, False, + self._amsgrad, ) return None @@ -391,6 +402,7 @@ def _append_optimize_op(self, block, param_and_grad): "LearningRate": [lr], "Moment1": [moment1], "Moment2": [moment2], + 
"Moment2Max": [moment2_max], "Beta1Pow": [beta1_pow_acc], "Beta2Pow": [beta2_pow_acc], } @@ -405,6 +417,7 @@ def _append_optimize_op(self, block, param_and_grad): "ParamOut": [param_and_grad[0]], "Moment1Out": [moment1], "Moment2Out": [moment2], + "Moment2MaxOut": [moment2_max], "Beta1PowOut": [beta1_pow_acc], "Beta2PowOut": [beta2_pow_acc], } @@ -534,6 +547,9 @@ def _multi_tensor_init(self, target_block, parameters, param_group_idx): for param in parameters: moment1 = self._get_accumulator_master(self._moment1_acc_str, param) moment2 = self._get_accumulator_master(self._moment2_acc_str, param) + moment2_max = self._get_accumulator_master( + self._moment2_acc_max_str, param + ) beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param ) @@ -551,6 +567,9 @@ def _multi_tensor_init(self, target_block, parameters, param_group_idx): self._moment2_dict['FP32_LODTensor'][param_group_idx].append( moment2 ) + self._moment2_max_dict['FP32_LODTensor'][ + param_group_idx + ].append(moment2_max) self._beta1_pow_acc_dict['FP32_LODTensor'][ param_group_idx ].append(beta1_pow_acc) @@ -567,6 +586,9 @@ def _multi_tensor_init(self, target_block, parameters, param_group_idx): self._moment2_dict['FP16_LODTensor'][param_group_idx].append( moment2 ) + self._moment2_max_dict['FP16_LODTensor'][ + param_group_idx + ].append(moment2_max) self._beta1_pow_acc_dict['FP16_LODTensor'][ param_group_idx ].append(beta1_pow_acc) @@ -762,6 +784,7 @@ def _append_optimize_multi_tensor_op( lr_dict[key], self._moment1_dict[key][param_group_idx], self._moment2_dict[key][param_group_idx], + self._moment2_max_dict[key][param_group_idx], self._beta1_pow_acc_dict[key][param_group_idx], self._beta2_pow_acc_dict[key][param_group_idx], master_weight, @@ -770,6 +793,7 @@ def _append_optimize_multi_tensor_op( self._epsilon, find_master, False, + self._amsgrad, ) elif in_pir_mode(): master_weight = self._master_weight_dict[key] @@ -784,6 +808,7 @@ def _append_optimize_multi_tensor_op( lr_dict[key], self._moment1_dict[key][param_group_idx], self._moment2_dict[key][param_group_idx], + self._moment2_max_dict[key][param_group_idx], self._beta1_pow_acc_dict[key][param_group_idx], self._beta2_pow_acc_dict[key][param_group_idx], master_weight, @@ -792,6 +817,7 @@ def _append_optimize_multi_tensor_op( self._epsilon, find_master, False, + self._amsgrad, ) else: inputs = { @@ -800,6 +826,9 @@ def _append_optimize_multi_tensor_op( "LearningRate": lr_dict[key], "Moment1": self._moment1_dict[key][param_group_idx], "Moment2": self._moment2_dict[key][param_group_idx], + "Moment2Max": self._moment2_max_dict[key][ + param_group_idx + ], "Beta1Pow": self._beta1_pow_acc_dict[key][ param_group_idx ], @@ -811,6 +840,9 @@ def _append_optimize_multi_tensor_op( "ParamOut": self._param_dict[key][param_group_idx], "Moment1Out": self._moment1_dict[key][param_group_idx], "Moment2Out": self._moment2_dict[key][param_group_idx], + "Moment2MaxOut": self._moment2_max_dict[key][ + param_group_idx + ], "Beta1PowOut": self._beta1_pow_acc_dict[key][ param_group_idx ], From 640be9b49d16a42bf7b7016f0f72f49b7b6beac8 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 29 Aug 2024 23:06:46 +0800 Subject: [PATCH 02/33] [update] refer.h --- paddle/phi/kernels/funcs/jit/refer/refer.h | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.h b/paddle/phi/kernels/funcs/jit/refer/refer.h index 348f7a30078b97..a98b925dec8ea0 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.h +++ 
b/paddle/phi/kernels/funcs/jit/refer/refer.h @@ -536,20 +536,13 @@ void Adam(T beta1, mom2_out_ptr[i] = beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; - // T mom2 = mom2_out_ptr[i]; - T mom2 = std::max(mom2_out_ptr[i], mom2_max_out_ptr[i]); - - // if (tmp > mom2_out_ptr[i]) { - // std::cout << tmp << " | " << mom2_ptr[i] << " | " << mom2_out_ptr[i] << - // " | " << mom2 << std::endl; std::cout << "---------- old " << - // std::endl; - - // } - // else { - // std::cout << "---------- new " << std::endl; - - // } - mom2_max_out_ptr[i] = mom2; + T mom2; + if (amsgrad) { + mom2 = std::max(mom2_out_ptr[i], mom2_max_out_ptr[i]); + mom2_max_out_ptr[i] = mom2; + } else { + mom2 = mom2_out_ptr[i]; + } param_out_ptr[i] = param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2) + eps)); From caf919ac2d8c266a0bce0e8b0b8ac197b02a1e0f Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 4 Sep 2024 17:41:29 +0800 Subject: [PATCH 03/33] [Add] amsgrad gpu --- paddle/phi/kernels/funcs/adam_functors.h | 40 +++++++++-- paddle/phi/kernels/gpu/adam_kernel.cu | 89 +++++++++++++++++++----- 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index 5ad0a64fcd766d..598c3bb2cd5773 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -174,10 +174,13 @@ class AdamFunctor { T* moment1_out_; const T* moment2_; T* moment2_out_; + const T* moment2_max_; + T* moment2_max_out_; const T* lr_; const T* grad_; const T* param_; T* param_out_; + bool amsgrad_; public: AdamFunctor(T beta1, @@ -189,10 +192,13 @@ class AdamFunctor { T* mom1_out, const T* mom2, T* mom2_out, + const T* mom2_max, + T* mom2_max_out, const T* lr, const T* grad, const T* param, - T* param_out) + T* param_out, + bool amsgrad) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -202,16 +208,20 @@ class AdamFunctor { moment1_out_(mom1_out), moment2_(mom2), moment2_out_(mom2_out), + moment2_max_(mom2_max), + moment2_max_out_(mom2_max_out), lr_(lr), grad_(grad), param_(param), - param_out_(param_out) {} + param_out_(param_out), + amsgrad_(amsgrad) {} inline HOSTDEVICE void operator()(size_t i) const { // Merge all memory access together. 
T g = grad_[i]; T mom1 = moment1_[i]; T mom2 = moment2_[i]; + T mom2_max = moment2_max_[i]; T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -222,7 +232,16 @@ class AdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); + + T mom2_max_; + if (amsgrad_) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out_[i] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); // Write back to global memory moment1_out_[i] = mom1; @@ -293,6 +312,8 @@ class AdamFunctor { moment1_, static_cast(numel)}; Eigen::Map> mom2{ moment2_, static_cast(numel)}; + Eigen::Map> mom2_max{ + moment2_max_, static_cast(numel)}; Eigen::Map> param{ param_, static_cast(numel)}; @@ -302,6 +323,8 @@ class AdamFunctor { moment1_out_, static_cast(numel)}; Eigen::Map> moment2_out{ moment2_out_, static_cast(numel)}; + Eigen::Map> moment2_max_out{ + moment2_max_out_, static_cast(numel)}; T lr = *lr_; T beta1_pow = *beta1_pow_; @@ -312,8 +335,15 @@ class AdamFunctor { moment1_out = beta1_ * mom1 + (1 - beta1_) * g; moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g; - param_out = param - lr * (moment1_out / (moment2_out.sqrt() + - epsilon_ * sqrt(1 - beta2_pow))); + + if (amsgrad_) { + moment2_max_out = moment2_out.cwiseMax(mom2_max); + param_out = param - lr * (moment1_out / (moment2_max_out.sqrt() + + epsilon_ * sqrt(1 - beta2_pow))); + } else { + param_out = param - lr * (moment1_out / (moment2_out.sqrt() + + epsilon_ * sqrt(1 - beta2_pow))); + } } }; diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 9f406d4615b8e5..2a7613b628eb94 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -40,13 +40,16 @@ __global__ void AdamKernelREG(MT beta1, MT* moment1_out, const MT* moment2, MT* moment2_out, + const MT* moment2_max, + MT* moment2_max_out, const MT* lr_, const TG* grad, const T* param, T* param_out, const MT* master_param, MT* master_param_out, - int64_t ndim) { + int64_t ndim, + bool amsgrad) { MT lr = *lr_; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; @@ -58,10 +61,21 @@ __global__ void AdamKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + MT mom2_max_; + if (amsgrad) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + MT denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; @@ -83,13 +97,16 @@ __global__ void AdamKernelMEM(MT beta1, MT* moment1_out, const MT* moment2, MT* moment2_out, + const MT* moment2_max, + MT* moment2_max_out, const MT* lr_, const TG* grad, const T* param, T* param_out, const MT* master_param, MT* master_param_out, - int64_t ndim) { + int64_t ndim, + bool amsgrad) { MT lr = *lr_; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; @@ -101,10 +118,21 @@ __global__ void AdamKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = 
static_cast(moment2_max[id]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + MT mom2_max_; + if (amsgrad) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + MT denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; @@ -177,6 +205,7 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -228,13 +257,16 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } else { AdamKernelREG<<>>( beta1_, @@ -246,13 +278,16 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } if (!use_global_beta_pow) { // Cpu update @@ -274,13 +309,16 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } else { AdamKernelMEM<<>>( beta1_, @@ -292,13 +330,16 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } if (!use_global_beta_pow) { // Update with gpu @@ -373,13 +414,16 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), + moment2_max[idx]->data(), + dev_ctx.template Alloc(moment2_max_out[idx]), learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), dev_ctx.template Alloc(param_out[idx]), master_in_data, master_out_data, - param[idx]->numel()); + param[idx]->numel(), + amsgrad); } else { AdamKernelREG<<>>( beta1_, @@ -391,13 +435,16 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), + moment2_max[idx]->data(), + dev_ctx.template Alloc(moment2_max_out[idx]), learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), dev_ctx.template 
Alloc(param_out[idx]), master_in_data, master_out_data, - param[idx]->numel()); + param[idx]->numel(), + amsgrad); } if (!use_global_beta_pow) { // Cpu update @@ -419,13 +466,16 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), + moment2_max[idx]->data(), + dev_ctx.template Alloc(moment2_max_out[idx]), learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), dev_ctx.template Alloc(param_out[idx]), master_in_data, master_out_data, - param[idx]->numel()); + param[idx]->numel(), + amsgrad); } else { AdamKernelMEM<<>>( beta1_, @@ -437,13 +487,16 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), + moment2_max[idx]->data(), + dev_ctx.template Alloc(moment2_max_out[idx]), learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), dev_ctx.template Alloc(param_out[idx]), master_in_data, master_out_data, - param[idx]->numel()); + param[idx]->numel(), + amsgrad); } if (!use_global_beta_pow) { // Update with gpu @@ -470,9 +523,9 @@ PD_REGISTER_KERNEL(adam, phi::dtype::float16, phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); if (kernel_key.dtype() == phi::DataType::FLOAT16 || kernel_key.dtype() == phi::DataType::BFLOAT16) { @@ -481,9 +534,10 @@ PD_REGISTER_KERNEL(adam, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); } - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } PD_REGISTER_KERNEL(merged_adam, @@ -495,8 +549,8 @@ PD_REGISTER_KERNEL(merged_adam, phi::dtype::float16, phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); if (kernel_key.dtype() == phi::DataType::FLOAT16 || kernel_key.dtype() == phi::DataType::BFLOAT16) { @@ -505,7 +559,8 @@ PD_REGISTER_KERNEL(merged_adam, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); } - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } From aa289ad823bf608bf09dd7f035d3b5ca7fa48a90 Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 4 Sep 2024 23:08:27 +0800 Subject: [PATCH 04/33] [Add] amsgrad for adamw and fused --- paddle/fluid/operators/fused/fused_adam_op.cc | 9 ++ .../fluid/operators/ops_signature/adam_sig.cc | 3 + .../operators/ops_signature/fused_adam_sig.cc | 5 +- paddle/fluid/pybind/eager_generator.cc | 12 +++ paddle/phi/infermeta/multiary.cc | 17 +++- paddle/phi/infermeta/multiary.h | 9 ++ paddle/phi/infermeta/spmd_rules/optimizer.cc | 8 +- 
paddle/phi/infermeta/spmd_rules/optimizer.h | 4 +- paddle/phi/kernels/adamw_kernel.h | 3 + paddle/phi/kernels/cpu/adamw_kernel.cc | 21 +++-- paddle/phi/kernels/cpu/fused_adam_kernel.cc | 22 ++++- paddle/phi/kernels/funcs/adam_functors.h | 87 ++++++++++++++++--- paddle/phi/kernels/funcs/jit/kernel_base.h | 5 +- paddle/phi/kernels/funcs/jit/refer/refer.h | 17 +++- paddle/phi/kernels/fused_adam_kernel.h | 3 + paddle/phi/kernels/gpu/adamw_kernel.cu | 58 +++++++++++-- paddle/phi/kernels/gpu/fused_adam_kernel.cu | 7 ++ .../phi/kernels/selected_rows/adam_kernel.h | 3 + .../phi/kernels/selected_rows/adamw_kernel.h | 3 + .../kernels/selected_rows/cpu/adam_kernel.cc | 9 +- .../kernels/selected_rows/cpu/adamw_kernel.cc | 9 ++ .../kernels/selected_rows/gpu/adam_kernel.cu | 39 +++++++-- .../kernels/selected_rows/gpu/adamw_kernel.cu | 37 ++++++-- .../ops/yaml/inconsistent/dygraph_ops.yaml | 6 +- .../phi/ops/yaml/inconsistent/static_ops.yaml | 6 +- paddle/phi/ops/yaml/op_compat.yaml | 8 +- paddle/phi/ops/yaml/ops.yaml | 12 +-- python/paddle/optimizer/adam.py | 10 ++- python/paddle/optimizer/adamw.py | 16 +++- 29 files changed, 375 insertions(+), 73 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_adam_op.cc b/paddle/fluid/operators/fused/fused_adam_op.cc index d786dbd7c2728f..3649410a6459fd 100644 --- a/paddle/fluid/operators/fused/fused_adam_op.cc +++ b/paddle/fluid/operators/fused/fused_adam_op.cc @@ -57,6 +57,8 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("LearningRate", "(Tensor, default Tensor) Learning rate"); AddInput("Moments1", "(Tensor) Input first moments").AsDuplicable(); AddInput("Moments2", "(Tensor) Input second moments").AsDuplicable(); + AddInput("Moments2Max", "(Tensor) Input second moments max for amsgrad") + .AsDuplicable(); AddInput("Beta1Pows", "(Tensor, default Tensor) Input beta1 power accumulator") .AsDuplicable(); @@ -72,6 +74,9 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("ParamsOut", "(Tensor) Output parameters").AsDuplicable(); AddOutput("Moments1Out", "(Tensor) Output first moments").AsDuplicable(); AddOutput("Moments2Out", "(Tensor) Output second moments").AsDuplicable(); + AddOutput("Moments2MaxOut", + "(Tensor) Output second moments max for amsgrad") + .AsDuplicable(); AddOutput("Beta1PowsOut", "(Tensor) Output beta1 power accumulator") .AsDuplicable(); AddOutput("Beta2PowsOut", "(Tensor) Output beta2 power accumulator") @@ -122,6 +127,10 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker { "Whether to use global beta_pow for whole model instead of " "creating beta_pow for each parameter.") .SetDefault(false); + AddAttr("amsgrad", + "(bool, default false) " + "Whether to use the AMSGrad of this algorithm.") + .SetDefault(false); AddComment(R"DOC( Adam Optimizer. 
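
For reference, the element-wise update implemented across this series is the AMSGrad variant of Adam: when amsgrad is enabled, the second-moment estimate used in the denominator is replaced by its running maximum (the new moment2_max accumulator); otherwise the classic Adam path is unchanged. Below is a minimal NumPy sketch of that update, mirroring the AdamKernelREG math from PATCH 03 and the refer.h branch from PATCH 02; the function and variable names are illustrative only, not Paddle APIs.

import numpy as np

def adam_update(param, grad, mom1, mom2, mom2_max, lr, beta1_pow, beta2_pow,
                beta1=0.9, beta2=0.999, epsilon=1e-8, amsgrad=False):
    # Moment estimates, identical to the non-AMSGrad kernels.
    mom1 = beta1 * mom1 + (1 - beta1) * grad
    mom2 = beta2 * mom2 + (1 - beta2) * grad * grad
    if amsgrad:
        # AMSGrad: keep the running maximum of the second moment and use it
        # in the denominator (mirrors moment2_max_out in the kernels).
        mom2_max = np.maximum(mom2, mom2_max)
        denom = np.sqrt(mom2_max) / np.sqrt(1 - beta2_pow) + epsilon
    else:
        denom = np.sqrt(mom2) / np.sqrt(1 - beta2_pow) + epsilon
    param = param - (lr / (1 - beta1_pow)) * (mom1 / denom)
    return param, mom1, mom2, mom2_max

Because mom2_max only feeds the denominator when amsgrad is true, the extra accumulator leaves the default Adam update numerically unchanged.
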
diff --git a/paddle/fluid/operators/ops_signature/adam_sig.cc b/paddle/fluid/operators/ops_signature/adam_sig.cc index f3e7eeb6b67629..7815a2a3166efd 100644 --- a/paddle/fluid/operators/ops_signature/adam_sig.cc +++ b/paddle/fluid/operators/ops_signature/adam_sig.cc @@ -24,6 +24,7 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) { "LearningRate", "Moment1", "Moment2", + "Moment2Max", "Beta1Pow", "Beta2Pow", "MasterParam", @@ -31,6 +32,7 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) { paddle::small_vector out_names = {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}; @@ -46,6 +48,7 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) { attr_names.emplace_back("min_row_size_to_use_multithread"); attr_names.emplace_back("multi_precision"); attr_names.emplace_back("use_global_beta_pow"); + attr_names.emplace_back("amsgrad"); if (ctx.IsSelectedRowsInput("Grad")) { return KernelSignature("adam_dense_param_sparse_grad", diff --git a/paddle/fluid/operators/ops_signature/fused_adam_sig.cc b/paddle/fluid/operators/ops_signature/fused_adam_sig.cc index dc787529a02a2f..f619beee9f718b 100644 --- a/paddle/fluid/operators/ops_signature/fused_adam_sig.cc +++ b/paddle/fluid/operators/ops_signature/fused_adam_sig.cc @@ -25,6 +25,7 @@ KernelSignature FusedAdamOpArgumentMapping( "LearningRate", "Moments1", "Moments2", + "Moments2Max", "Beta1Pows", "Beta2Pows", "MasterParams", @@ -32,6 +33,7 @@ KernelSignature FusedAdamOpArgumentMapping( paddle::small_vector out_names = {"ParamsOut", "Moments1Out", "Moments2Out", + "Moments2MaxOut", "Beta1PowsOut", "Beta2PowsOut", "MasterParamsOut"}; @@ -42,7 +44,8 @@ KernelSignature FusedAdamOpArgumentMapping( "weight_decay", "use_adamw", "multi_precision", - "use_global_beta_pow"}; + "use_global_beta_pow", + "amsgrad"}; return KernelSignature("fused_adam", std::move(in_names), diff --git a/paddle/fluid/pybind/eager_generator.cc b/paddle/fluid/pybind/eager_generator.cc index af35e04585c868..d806dac650bd48 100644 --- a/paddle/fluid/pybind/eager_generator.cc +++ b/paddle/fluid/pybind/eager_generator.cc @@ -3344,6 +3344,7 @@ std::map> op_passing_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, @@ -3351,6 +3352,7 @@ std::map> op_passing_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, @@ -3358,6 +3360,7 @@ std::map> op_passing_outs_map = { {"ParamsOut", "Moments1Out", "Moments2Out", + "Moments2MaxOut", "Beta1PowsOut", "Beta2PowsOut", "MasterParamsOut"}}, @@ -3365,6 +3368,7 @@ std::map> op_passing_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, @@ -3553,6 +3557,7 @@ std::map> op_ins_map = { "LearningRate", "Moment1", "Moment2", + "Moment2Max", "Beta1Pow", "Beta2Pow", "MasterParam"}}, @@ -3562,6 +3567,7 @@ std::map> op_ins_map = { "LearningRate", "Moment1", "Moment2", + "Moment2Max", "Beta1Pow", "Beta2Pow", "MasterParam"}}, @@ -3571,6 +3577,7 @@ std::map> op_ins_map = { "LearningRate", "Moments1", "Moments2", + "Moments2Max", "Beta1Pows", "Beta2Pows", "MasterParams", @@ -3581,6 +3588,7 @@ std::map> op_ins_map = { "LearningRate", "Moment1", "Moment2", + "Moment2Max", "Beta1Pow", "Beta2Pow", "MasterParam"}}, @@ -3732,6 +3740,7 @@ std::map> op_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", 
"Beta2PowOut", "MasterParamOut"}}, @@ -3739,6 +3748,7 @@ std::map> op_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, @@ -3746,6 +3756,7 @@ std::map> op_outs_map = { {"ParamsOut", "Moments1Out", "Moments2Out", + "Moments2MaxOut", "Beta1PowsOut", "Beta2PowsOut", "MasterParamsOut"}}, @@ -3753,6 +3764,7 @@ std::map> op_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", + "Moment2MaxOut", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 7945d6bffc9c71..c89e19ee54f145 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -333,6 +333,7 @@ void AdamwInferMeta(const MetaTensor& param, const MetaTensor& learning_rate, const MetaTensor& moment1, const MetaTensor& moment2, + const MetaTensor& moment2_max, const MetaTensor& beta1_pow, const MetaTensor& beta2_pow, const MetaTensor& master_param, @@ -347,9 +348,11 @@ void AdamwInferMeta(const MetaTensor& param, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, MetaTensor* param_out, MetaTensor* moment1_out, MetaTensor* moment2_out, + MetaTensor* moment2_max_out, MetaTensor* beta1_pow_out, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs) { @@ -358,7 +361,7 @@ void AdamwInferMeta(const MetaTensor& param, learning_rate, moment1, moment2, - moment2, // TODO(megemini) + moment2_max, beta1_pow, beta2_pow, master_param, @@ -370,11 +373,11 @@ void AdamwInferMeta(const MetaTensor& param, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow, - false, // TODO(megemini) + amsgrad, param_out, moment1_out, moment2_out, - moment2_out, // TODO(megemini) + moment2_max_out, beta1_pow_out, beta2_pow_out, master_param_outs); @@ -3862,6 +3865,7 @@ void MergedAdamInferMeta( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, + const std::vector& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -3870,9 +3874,11 @@ void MergedAdamInferMeta( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector param_out, std::vector moment1_out, std::vector moment2_out, + std::vector moment2_max_out, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) {} @@ -5790,6 +5796,7 @@ void FusedAdamInferMeta( const MetaTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, + const std::vector& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -5802,9 +5809,11 @@ void FusedAdamInferMeta( bool use_adamw, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector params_out, std::vector moments1_out, std::vector moments2_out, + std::vector moments2_max_out, std::vector beta1_pows_out, std::vector beta2_pows_out, std::vector master_params_out) { @@ -5816,6 +5825,8 @@ void FusedAdamInferMeta( moments1_out[i]->set_dtype(moments1[i]->dtype()); moments2_out[i]->set_dims(moments2[i]->dims()); moments2_out[i]->set_dtype(moments2[i]->dtype()); + moments2_max_out[i]->set_dims(moments2_max[i]->dims()); + moments2_max_out[i]->set_dtype(moments2_max[i]->dtype()); beta1_pows_out[i]->set_dims(beta1_pows[i]->dims()); beta1_pows_out[i]->set_dtype(beta1_pows[i]->dtype()); beta2_pows_out[i]->set_dims(beta2_pows[i]->dims()); diff --git 
a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 360be7938206c0..7c6f2afc69f5c5 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -112,6 +112,7 @@ void AdamwInferMeta(const MetaTensor& param, const MetaTensor& learning_rate, const MetaTensor& moment1, const MetaTensor& moment2, + const MetaTensor& moment2_max, const MetaTensor& beta1_pow, const MetaTensor& beta2_pow, const MetaTensor& master_param, @@ -126,9 +127,11 @@ void AdamwInferMeta(const MetaTensor& param, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, MetaTensor* param_out, MetaTensor* moment1_out, MetaTensor* moment2_out, + MetaTensor* moment2_max_out, MetaTensor* beta1_pow_out, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs); @@ -713,6 +716,7 @@ void MergedAdamInferMeta( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, + const std::vector& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -721,9 +725,11 @@ void MergedAdamInferMeta( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector param_out, std::vector moment1_out, std::vector moment2_out, + std::vector moment2_max_out, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out); @@ -1119,6 +1125,7 @@ void FusedAdamInferMeta( const MetaTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, + const std::vector& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -1131,9 +1138,11 @@ void FusedAdamInferMeta( bool use_adamw, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector params_out, std::vector moments1_out, std::vector moments2_out, + std::vector moments2_max_out, std::vector beta1_pows_out, std::vector beta2_pows_out, std::vector master_params_out); diff --git a/paddle/phi/infermeta/spmd_rules/optimizer.cc b/paddle/phi/infermeta/spmd_rules/optimizer.cc index 7b1e711a4b97b8..60382b1ffefad0 100644 --- a/paddle/phi/infermeta/spmd_rules/optimizer.cc +++ b/paddle/phi/infermeta/spmd_rules/optimizer.cc @@ -203,6 +203,7 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, const DistMetaTensor& learning_rate, const DistMetaTensor& moment1, const DistMetaTensor& moment2, + const DistMetaTensor& moment2_max, const DistMetaTensor& beta1_pow, const DistMetaTensor& beta2_pow, const DistMetaTensor& master_param, @@ -216,12 +217,14 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, - bool use_global_beta_pow) { + bool use_global_beta_pow, + bool amsgrad) { return AdamInferSpmdDynamic(param, grad, learning_rate, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_param, @@ -232,7 +235,8 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, lazy_mode, min_row_size_to_use_multithread, multi_precision, - use_global_beta_pow); + use_global_beta_pow, + amsgrad); } SpmdInfo SgdInferSpmd(const DistMetaTensor& param, diff --git a/paddle/phi/infermeta/spmd_rules/optimizer.h b/paddle/phi/infermeta/spmd_rules/optimizer.h index bc033f5310216a..3fd825a2e14965 100644 --- a/paddle/phi/infermeta/spmd_rules/optimizer.h +++ b/paddle/phi/infermeta/spmd_rules/optimizer.h @@ -47,6 +47,7 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, const DistMetaTensor& 
learning_rate, const DistMetaTensor& moment1, const DistMetaTensor& moment2, + const DistMetaTensor& moment2_max, const DistMetaTensor& beta1_pow, const DistMetaTensor& beta2_pow, const DistMetaTensor& master_param, @@ -60,7 +61,8 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, - bool use_global_beta_pow); + bool use_global_beta_pow, + bool amsgrad); SpmdInfo SgdInferSpmd(const DistMetaTensor& param, const DistMetaTensor& learning_rate, diff --git a/paddle/phi/kernels/adamw_kernel.h b/paddle/phi/kernels/adamw_kernel.h index 5cbb38143ff6f7..ea34b9ca289855 100644 --- a/paddle/phi/kernels/adamw_kernel.h +++ b/paddle/phi/kernels/adamw_kernel.h @@ -26,6 +26,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -40,9 +41,11 @@ void AdamwDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs); diff --git a/paddle/phi/kernels/cpu/adamw_kernel.cc b/paddle/phi/kernels/cpu/adamw_kernel.cc index ffef5d50c4e641..97a5d44cfab4f7 100644 --- a/paddle/phi/kernels/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/cpu/adamw_kernel.cc @@ -35,6 +35,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -49,9 +50,11 @@ void AdamwDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -75,7 +78,7 @@ void AdamwDenseKernel(const Context& dev_ctx, learning_rate, moment1, moment2, - moment2, // TODO(megemini) + moment2_max, beta1_pow, beta2_pow, master_param, @@ -87,11 +90,11 @@ void AdamwDenseKernel(const Context& dev_ctx, min_row_size_to_use_multithread, multi_precision, use_global_beta_pow, - false, // TODO(megemini) + amsgrad, param_out, moment1_out, moment2_out, - moment2_out, // TODO(megemini) + moment2_max_out, beta1_pow_out, beta2_pow_out, master_param_outs); @@ -133,6 +136,7 @@ void AdamwDenseKernel(const Context& dev_ctx, T* param_out_ptr = dev_ctx.template Alloc(param_out); T* mom1_out_ptr = dev_ctx.template Alloc(moment1_out); T* mom2_out_ptr = dev_ctx.template Alloc(moment2_out); + T* mom2_max_out_ptr = dev_ctx.template Alloc(moment2_max_out); T old_lr = learning_rate.data()[0]; T learning_rate_ = learning_rate.data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); @@ -143,6 +147,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const T* param_ptr = param.data(); const T* mom1_ptr = moment1.data(); const T* mom2_ptr = moment2.data(); + const T* mom2_max_ptr = moment2_max.data(); const T* grad_ptr = grad.data(); auto adamw = @@ -167,10 +172,13 @@ void AdamwDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, 
mom2_ptr + offset, + mom2_max_ptr + offset, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - param_out_ptr + offset); + mom2_max_out_ptr + offset, + param_out_ptr + offset, + amsgrad); } if (numel % chunk_size != 0) { @@ -187,10 +195,13 @@ void AdamwDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, + mom2_max_ptr + offset, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - param_out_ptr + offset); + mom2_max_out_ptr + offset, + param_out_ptr + offset, + amsgrad); } } diff --git a/paddle/phi/kernels/cpu/fused_adam_kernel.cc b/paddle/phi/kernels/cpu/fused_adam_kernel.cc index df935c57da7361..66712300080f6d 100644 --- a/paddle/phi/kernels/cpu/fused_adam_kernel.cc +++ b/paddle/phi/kernels/cpu/fused_adam_kernel.cc @@ -36,6 +36,7 @@ void FusedAdamKernel( const DenseTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, + const std::vector& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -48,9 +49,11 @@ void FusedAdamKernel( bool use_adamw, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector params_out, std::vector moments1_out, std::vector moments2_out, + std::vector moments2_max_out, std::vector beta1_pows_out, std::vector beta2_pows_out, std::vector master_params_out) { @@ -79,6 +82,15 @@ void FusedAdamKernel( "is %d, the size of Input(params) is %d.", moments2.size(), params_num)); + PADDLE_ENFORCE_EQ( + params_num, + moments2_max.size(), + errors::InvalidArgument( + "The size of Input(moments2 max) must be equal to " + "Input(params), but got the size of Input(moments2 max) " + "is %d, the size of Input(params) is %d.", + moments2_max.size(), + params_num)); PADDLE_ENFORCE_EQ(params_num, beta1_pows.size(), errors::InvalidArgument( @@ -106,7 +118,7 @@ void FusedAdamKernel( learning_rate, *moments1[idx], *moments2[idx], - *moments2[idx], // TODO(megemini) + *moments2_max[idx], *beta1_pows[idx], *beta2_pows[idx], master_params_tmp, @@ -118,11 +130,11 @@ void FusedAdamKernel( 1000, multi_precision, use_global_beta_pow, - false, // TODO(megemini) + amsgrad, params_out[idx], moments1_out[idx], moments2_out[idx], - moments2_out[idx], // TODO(megemini) + moments2_max_out[idx], beta1_pows_out[idx], beta2_pows_out[idx], master_params_out.empty() ? nullptr : master_params_out[idx]); @@ -134,6 +146,7 @@ void FusedAdamKernel( learning_rate, *moments1[idx], *moments2[idx], + *moments2_max[idx], *beta1_pows[idx], *beta2_pows[idx], master_params_tmp, @@ -148,9 +161,11 @@ void FusedAdamKernel( 1000, multi_precision, use_global_beta_pow, + amsgrad, params_out[idx], moments1_out[idx], moments2_out[idx], + moments2_max_out[idx], beta1_pows_out[idx], beta2_pows_out[idx], master_params_out.empty() ? 
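For reference, every kernel touched by this series implements the same AMSGrad variant of Adam: alongside moment2 it tracks an element-wise running maximum (the new moment2_max state) and uses that maximum in the denominator when amsgrad is true. In LaTeX, with g_t the gradient:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \hat{v}_t = \max(\hat{v}_{t-1}, v_t) \quad \text{(amsgrad)}, \qquad \hat{v}_t = v_t \quad \text{(otherwise)}
    \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \frac{m_t}{\sqrt{\hat{v}_t} + \epsilon \sqrt{1 - \beta_2^t}}

The Adam functors use the epsilon-scaled denominator shown above, while the AdamW GPU kernels later in the patch use the closely related form sqrt(\hat{v}_t)/sqrt(1 - \beta_2^t) + \epsilon with an explicit 1/(1 - \beta_1^t) factor on the step. With amsgrad false, \hat{v}_t = v_t and everything reduces to the existing Adam/AdamW behaviour.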
nullptr : master_params_out[idx]); @@ -167,4 +182,5 @@ PD_REGISTER_KERNEL( kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(6).SetDataType(phi::DataType::UNDEFINED); } diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index 598c3bb2cd5773..57bdeeee56dd45 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -363,6 +363,8 @@ class SparseAdamFunctor { MT* moment1_out_; const MT* moment2_; MT* moment2_out_; + const MT* moment2_max_; + MT* moment2_max_out_; const MT* lr_; const T* grad_; const T* param_; @@ -374,6 +376,7 @@ class SparseAdamFunctor { int64_t row_numel_; int64_t row_count_; bool lazy_mode_; + bool amsgrad_; public: SparseAdamFunctor(MT beta1, @@ -385,6 +388,8 @@ class SparseAdamFunctor { MT* mom1_out, const MT* mom2, MT* mom2_out, + const MT* mom2_max, + MT* mom2_max_out, const MT* lr, const T* grad, const T* param, @@ -394,7 +399,8 @@ class SparseAdamFunctor { const int64_t* rows, int64_t row_numel, int64_t row_count, - bool lazy_mode) + bool lazy_mode, + bool amsgrad) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -404,6 +410,8 @@ class SparseAdamFunctor { moment1_out_(mom1_out), moment2_(mom2), moment2_out_(mom2_out), + moment2_max_(mom2_max), + moment2_max_out_(mom2_max_out), lr_(lr), grad_(grad), param_(param), @@ -413,12 +421,14 @@ class SparseAdamFunctor { rows_(rows), row_numel_(row_numel), row_count_(row_count), - lazy_mode_(lazy_mode) {} + lazy_mode_(lazy_mode), + amsgrad_(amsgrad) {} inline HOSTDEVICE void adam_update(size_t i, MT g) const { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; + MT mom2_max = moment2_max_[i]; MT lr = *lr_; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; @@ -430,7 +440,16 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; - p -= lr * (mom1 / (sqrt(mom2) + + + MT mom2_max_; + if (amsgrad_) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out_[i] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); // Write back to global memory @@ -469,6 +488,8 @@ class SparseAdamFunctor { T* moment1_out_; const T* moment2_; T* moment2_out_; + const T* moment2_max_; + T* moment2_max_out_; const T* lr_; const T* grad_; const T* param_; @@ -477,6 +498,7 @@ class SparseAdamFunctor { const int64_t* rows_; int64_t row_numel_; int64_t row_count_; + bool amsgrad_; public: SparseAdamFunctor(T beta1, @@ -488,6 +510,8 @@ class SparseAdamFunctor { T* mom1_out, const T* mom2, T* mom2_out, + const T* mom2_max, + T* mom2_max_out, const T* lr, const T* grad, const T* param, @@ -495,7 +519,8 @@ class SparseAdamFunctor { const int64_t* rows, int64_t row_numel, int64_t row_count, - bool lazy_mode UNUSED) + bool lazy_mode UNUSED, + bool amsgrad) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -505,18 +530,22 @@ class SparseAdamFunctor { moment1_out_(mom1_out), moment2_(mom2), moment2_out_(mom2_out), + moment2_max_(mom2_max), + moment2_max_out_(mom2_max_out), lr_(lr), grad_(grad), param_(param), param_out_(param_out), rows_(rows), row_numel_(row_numel), - row_count_(row_count) {} + row_count_(row_count), + amsgrad_(amsgrad) {} inline HOSTDEVICE void adam_update(size_t i, T g) const { // The 
following code is the same as dense T mom1 = moment1_[i]; T mom2 = moment2_[i]; + T mom2_max = moment2_max_[i]; T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -527,7 +556,16 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); + + T mom2_max_; + if (amsgrad_) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out_[i] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); // Write back to global memory moment1_out_[i] = mom1; @@ -554,12 +592,22 @@ class SparseAdamFunctor { for (int64_t k = 0; k != row_numel_; ++k) { T mom1 = moment1_[i * row_numel_ + k]; T mom2 = moment2_[i * row_numel_ + k]; + T mom2_max = moment2_max_[i * row_numel_ + k]; + T p = param_[i * row_numel_ + k]; mom1 = beta1_ * mom1; mom2 = beta2_ * mom2; - p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + T mom2_max_; + if (amsgrad_) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out_[i * row_numel_ + k] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); // Write back to global memory moment1_out_[i * row_numel_ + k] = mom1; moment2_out_[i * row_numel_ + k] = mom2; @@ -617,6 +665,8 @@ class SparseAdamWFunctor { MT* moment1_out_; const MT* moment2_; MT* moment2_out_; + const MT* moment2_max_; + MT* moment2_max_out_; const MT* lr_; const T* grad_; const T* param_; @@ -628,6 +678,7 @@ class SparseAdamWFunctor { int64_t row_numel_; int64_t row_count_; bool lazy_mode_; + bool amsgrad_; public: SparseAdamWFunctor(MT beta1, @@ -641,6 +692,8 @@ class SparseAdamWFunctor { MT* mom1_out, const MT* mom2, MT* mom2_out, + const MT* mom2_max, + MT* mom2_max_out, const MT* lr, const T* grad, const T* param, @@ -650,7 +703,8 @@ class SparseAdamWFunctor { const int64_t* rows, int64_t row_numel, int64_t row_count, - bool lazy_mode) + bool lazy_mode, + bool amsgrad) : beta1_(beta1), beta2_(beta2), epsilon_(epsilon), @@ -662,6 +716,8 @@ class SparseAdamWFunctor { moment1_out_(mom1_out), moment2_(mom2), moment2_out_(mom2_out), + moment2_max_(mom2_max), + moment2_max_out_(mom2_max_out), lr_(lr), grad_(grad), param_(param), @@ -671,12 +727,14 @@ class SparseAdamWFunctor { rows_(rows), row_numel_(row_numel), row_count_(row_count), - lazy_mode_(lazy_mode) {} + lazy_mode_(lazy_mode), + amsgrad_(amsgrad) {} inline HOSTDEVICE void adamw_update(size_t i, MT g) const { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; + MT mom2_max = moment2_max_[i]; MT lr = *lr_ * lr_ratio_; MT lr_orig = lr; MT beta1_pow = *beta1_pow_; @@ -689,8 +747,17 @@ class SparseAdamWFunctor { mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; + + MT mom2_max_; + if (amsgrad_) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out_[i] = mom2_max_; + } else { + mom2_max_ = mom2; + } + p -= lr_orig * coeff_ * p; - p -= lr * (mom1 / (sqrt(mom2) + + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); // Write back to global memory diff --git a/paddle/phi/kernels/funcs/jit/kernel_base.h b/paddle/phi/kernels/funcs/jit/kernel_base.h index b5467e611f2494..e0c35a51644eb3 100644 --- a/paddle/phi/kernels/funcs/jit/kernel_base.h +++ b/paddle/phi/kernels/funcs/jit/kernel_base.h @@ -309,9 +309,12 @@ struct AdamWTuple { const T*, const T*, const T*, + const T*, + T*, T*, T*, - T*); 
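A rough NumPy sketch of what SparseAdamFunctor::adam_update and the lazy-mode row loop above now compute per element; names are illustrative, and lr is assumed to be the step size the surrounding kernel already prepared:

    import numpy as np

    def adam_update(p, g, mom1, mom2, mom2_max, lr, beta1, beta2,
                    beta2_pow, eps, amsgrad):
        mom1 = beta1 * mom1 + (1 - beta1) * g
        mom2 = beta2 * mom2 + (1 - beta2) * g * g
        # AMSGrad keeps an element-wise running max of the second moment;
        # it is only written back when the flag is on.
        v_hat = np.maximum(mom2, mom2_max) if amsgrad else mom2
        p = p - lr * mom1 / (np.sqrt(v_hat) + eps * np.sqrt(1 - beta2_pow))
        return p, mom1, mom2, (v_hat if amsgrad else mom2_max)

    def lazy_row_update(p, mom1, mom2, mom2_max, lr, beta1, beta2, eps, amsgrad):
        # In lazy mode, rows with no gradient still decay the moments.
        mom1 = beta1 * mom1
        mom2 = beta2 * mom2
        v_hat = np.maximum(mom2, mom2_max) if amsgrad else mom2
        p = p - lr * mom1 / (np.sqrt(v_hat) + eps)
        return p, mom1, mom2, (v_hat if amsgrad else mom2_max)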
+ T*, + bool); }; typedef struct matmul_attr_s { diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.h b/paddle/phi/kernels/funcs/jit/refer/refer.h index a98b925dec8ea0..82c17350e7d438 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.h +++ b/paddle/phi/kernels/funcs/jit/refer/refer.h @@ -561,17 +561,28 @@ void AdamW(T beta1, const T* grad_ptr, const T* mom1_ptr, const T* mom2_ptr, + const T* mom2_max_ptr, const T* param_ptr, T* mom1_out_ptr, T* mom2_out_ptr, - T* param_out_ptr) { + T* mom2_max_out_ptr, + T* param_out_ptr, + bool amsgrad) { for (int i = 0; i < numel; ++i) { auto param_tmp = param_ptr[i] - old_lr * lr_ratio * coeff * param_ptr[i]; mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i]; mom2_out_ptr[i] = beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; - param_out_ptr[i] = - param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps)); + + T mom2; + if (amsgrad) { + mom2 = std::max(mom2_out_ptr[i], mom2_max_out_ptr[i]); + mom2_max_out_ptr[i] = mom2; + } else { + mom2 = mom2_out_ptr[i]; + } + + param_out_ptr[i] = param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2) + eps)); } } diff --git a/paddle/phi/kernels/fused_adam_kernel.h b/paddle/phi/kernels/fused_adam_kernel.h index b44c7250d148ff..16944abdb8b1a1 100644 --- a/paddle/phi/kernels/fused_adam_kernel.h +++ b/paddle/phi/kernels/fused_adam_kernel.h @@ -27,6 +27,7 @@ void FusedAdamKernel( const DenseTensor &learning_rate, const std::vector &moments1, const std::vector &moments2, + const std::vector &moments2_max, const std::vector &beta1_pows, const std::vector &beta2_pows, const paddle::optional> &master_params, @@ -39,9 +40,11 @@ void FusedAdamKernel( bool use_adamw, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector params_out, std::vector moments1_out, std::vector moments2_out, + std::vector moments2_max_out, std::vector beta1_pows_out, std::vector beta2_pows_out, std::vector master_params_out); diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 3adeb258bc624f..f7213d40378477 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -43,13 +43,16 @@ __global__ void AdamWKernelREG(MT beta1, MT* moment1_out, const MT* moment2, MT* moment2_out, + const MT* moment2_max, + MT* moment2_max_out, const MT* lr_, const TG* grad, const T* param, T* param_out, const MT* master_param, MT* master_param_out, - int64_t ndim) { + int64_t ndim, + bool amsgrad) { MT lr = *lr_ * lr_ratio; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; @@ -61,13 +64,23 @@ __global__ void AdamWKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + MT mom2_max_; + if (amsgrad) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + MT denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); @@ -92,13 +105,16 @@ __global__ void AdamWKernelMEM(MT beta1, MT* moment1_out, const MT* moment2, MT* moment2_out, + const MT* moment2_max, + MT* moment2_max_out, const MT* lr_, const TG* grad, const T* param, T* param_out, const MT* master_param, 
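The refer.h routine above is the scalar reference the JIT AdamW kernels are checked against. A Python transcription of the same loop (lr is the pre-scaled step and old_lr the raw learning rate supplied by the caller; per ops.yaml the moment2_max input and output alias in place, so the reference reads the previous max from the output buffer):

    import math

    def ref_adamw(beta1, beta2, lr, eps, old_lr, lr_ratio, coeff, numel,
                  grad, mom1, mom2, param,
                  mom1_out, mom2_out, mom2_max_out, param_out, amsgrad):
        for i in range(numel):
            # Decoupled weight decay, applied before the Adam step.
            param_tmp = param[i] - old_lr * lr_ratio * coeff * param[i]
            mom1_out[i] = beta1 * mom1[i] + (1 - beta1) * grad[i]
            mom2_out[i] = beta2 * mom2[i] + (1 - beta2) * grad[i] * grad[i]
            if amsgrad:
                v_hat = max(mom2_out[i], mom2_max_out[i])
                mom2_max_out[i] = v_hat
            else:
                v_hat = mom2_out[i]
            param_out[i] = param_tmp + lr * (mom1_out[i] / (math.sqrt(v_hat) + eps))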
MT* master_param_out, - int64_t ndim) { + int64_t ndim, + bool amsgrad) { MT lr = *lr_ * lr_ratio; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; @@ -110,13 +126,23 @@ __global__ void AdamWKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + MT mom2_max_; + if (amsgrad) { + mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; + } else { + mom2_max_ = mom2; + } + + MT denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); @@ -147,6 +173,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -161,9 +188,11 @@ void AdamwDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -196,6 +225,7 @@ void AdamwDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -254,13 +284,16 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } else { AdamWKernelREG<<>>( beta1_, @@ -274,13 +307,16 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } if (!use_global_beta_pow) { // Cpu update @@ -304,13 +340,16 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } else { AdamWKernelMEM<<>>( beta1_, @@ -324,13 +363,16 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template 
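AdamWKernelREG and AdamWKernelMEM differ only in whether the beta pows arrive by value or by pointer; the per-element math both now perform looks roughly like this (illustrative NumPy sketch; moment2_max_out is only written on the amsgrad path, so the extra state costs one read per element when the flag is off):

    import numpy as np

    def adamw_gpu_step(p, g, mom1, mom2, mom2_max, lr, lr_ratio, coeff,
                       beta1, beta2, beta1_pow, beta2_pow, eps, amsgrad):
        lr = lr * lr_ratio
        p = p * (1 - lr * coeff)                      # decoupled weight decay
        mom1 = beta1 * mom1 + (1 - beta1) * g
        mom2 = beta2 * mom2 + (1 - beta2) * g * g
        v_hat = np.maximum(mom2, mom2_max) if amsgrad else mom2
        denom = np.sqrt(v_hat) / np.sqrt(1 - beta2_pow) + eps
        p = p + (mom1 / denom) * (-(lr / (1 - beta1_pow)))
        return p, mom1, mom2, (v_hat if amsgrad else mom2_max)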
Alloc(moment2_max_out), learning_rate.data(), grad.data(), param.data(), dev_ctx.template Alloc(param_out), master_in_data, master_out_data, - param.numel()); + param.numel(), + amsgrad); } if (!use_global_beta_pow) { // Update with gpu diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index a7b49ddea5d25c..1d9e91c4b5169c 100644 --- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -268,6 +268,7 @@ void FusedAdamKernel( const DenseTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, + const std::vector& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -280,9 +281,11 @@ void FusedAdamKernel( bool use_adamw, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, std::vector params_out, std::vector moments1_out, std::vector moments2_out, + std::vector moments2_max_out, std::vector beta1_pows_out, std::vector beta2_pows_out, std::vector master_params_out) { @@ -316,6 +319,7 @@ void FusedAdamKernel( CopyTensorIfDifferent(dev_ctx, params, params_out); CopyTensorIfDifferent(dev_ctx, moments1, moments1_out); CopyTensorIfDifferent(dev_ctx, moments2, moments2_out); + CopyTensorIfDifferent(dev_ctx, moments2_max, moments2_max_out); CopyTensorIfDifferent(dev_ctx, beta1_pows, beta1_pows_out, true); CopyTensorIfDifferent(dev_ctx, beta2_pows, beta2_pows_out, true); if (master_params) { @@ -351,6 +355,7 @@ void FusedAdamKernel( input_vector.push_back(params_out); input_vector.push_back(moments1_out); input_vector.push_back(moments2_out); + input_vector.push_back(moments2_max_out); if (multi_precision) { input_vector.push_back(master_params_out); } @@ -438,6 +443,8 @@ void FusedAdamKernel( int vec_size = GetVecSizeFromTensors(params_out); vec_size = GetVecSizeFromTensors(moments1_out, vec_size); vec_size = GetVecSizeFromTensors(moments2_out, vec_size); + // TODO(megemini): + vec_size = GetVecSizeFromTensors(moments2_max_out, vec_size); if (master_params) { vec_size = GetVecSizeFromTensors(master_params_out, vec_size); } diff --git a/paddle/phi/kernels/selected_rows/adam_kernel.h b/paddle/phi/kernels/selected_rows/adam_kernel.h index 79f87a8ed75c0c..2ac909903a4089 100644 --- a/paddle/phi/kernels/selected_rows/adam_kernel.h +++ b/paddle/phi/kernels/selected_rows/adam_kernel.h @@ -29,6 +29,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -40,9 +41,11 @@ void AdamDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs); diff --git a/paddle/phi/kernels/selected_rows/adamw_kernel.h b/paddle/phi/kernels/selected_rows/adamw_kernel.h index 5dda8107d52e3e..25321c87b321dd 100644 --- a/paddle/phi/kernels/selected_rows/adamw_kernel.h +++ b/paddle/phi/kernels/selected_rows/adamw_kernel.h @@ -29,6 +29,7 @@ void AdamwDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const 
paddle::optional& master_param, @@ -43,9 +44,11 @@ void AdamwDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs); diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc index f6b4db05abd3bf..ab98fd298e7475 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc @@ -37,6 +37,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param UNUSED, @@ -48,9 +49,11 @@ void AdamDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision UNUSED, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs UNUSED) { @@ -74,6 +77,7 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); @@ -147,6 +151,8 @@ void AdamDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad_data, param.data(), @@ -154,7 +160,8 @@ void AdamDenseParamSparseGradKernel( rows, row_numel, grad_merge.rows().size(), - lazy_mode); + lazy_mode, + amsgrad); // update beta1 and beta2 if (!use_global_beta_pow) { dev_ctx.template Alloc(beta1_pow_out)[0] = diff --git a/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc index b7d8b18324de22..9b7197a3e95e9d 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc @@ -34,6 +34,7 @@ void AdamwDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -48,9 +49,11 @@ void AdamwDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -74,6 +77,7 @@ void AdamwDenseParamSparseGradKernel( learning_rate, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_param, @@ -85,9 +89,11 @@ void AdamwDenseParamSparseGradKernel( min_row_size_to_use_multithread, multi_precision, use_global_beta_pow, + amsgrad, param_out, 
moment1_out, moment2_out, + moment2_max_out, beta1_pow_out, beta2_pow_out, master_param_outs); @@ -111,6 +117,7 @@ void AdamwDenseParamSparseGradKernel( learning_rate, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_param, @@ -122,9 +129,11 @@ void AdamwDenseParamSparseGradKernel( min_row_size_to_use_multithread, multi_precision, use_global_beta_pow, + amsgrad, param_out, moment1_out, moment2_out, + moment2_max_out, beta1_pow_out, beta2_pow_out, master_param_outs); diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index 084721a721ee56..ff2e6285110e82 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -49,6 +49,8 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, MT* mom1_out_, const MT* mom2_, MT* mom2_out_, + const MT* mom2_max_, + MT* mom2_max_out_, const MT* lr_, const T* grad_, const T* param_, @@ -59,7 +61,8 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, int64_t row_numel, int64_t row_count, bool lazy_mode, - int ndim) { + int ndim, + bool amsgrad) { int id = blockIdx.x * blockDim.x + threadIdx.x; MT lr = *lr_; @@ -71,6 +74,7 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, } else { MT mom1 = mom1_[id]; MT mom2 = mom2_[id]; + MT mom2_max = mom2_max_[id]; MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 ? static_cast(grad_[row_idx * row_numel + id % row_numel]) @@ -78,8 +82,16 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT denom = - (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + MT moment2_max_; + if (amsgrad) { + moment2_max_ = std::max(mom2, mom2_max); + mom2_max_out_[id] = moment2_max_; + } else { + moment2_max_ = mom2; + } + + MT denom = (sqrt(moment2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); // Write back to global memory @@ -101,6 +113,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -112,9 +125,11 @@ void AdamDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -140,6 +155,7 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -222,6 +238,8 @@ void AdamDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad_data, param.data(), @@ -232,7 +250,8 @@ void AdamDenseParamSparseGradKernel( 
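In the selected-rows CUDA kernels each parameter element first locates its merged gradient row before running the shared update. A compact sketch of that control flow with the new amsgrad flag threaded through (helper names are illustrative; find_row stands in for the kernel's BinarySearch, and the elided fallback gradient is assumed to be zero as in the pre-existing kernel):

    import bisect

    def find_row(rows, target):
        # Stand-in for the kernel's BinarySearch: position in rows, or -1.
        i = bisect.bisect_left(rows, target)
        return i if i < len(rows) and rows[i] == target else -1

    def sparse_grad_element_update(idx, rows, grad, row_numel, lazy_mode,
                                   amsgrad, state, update_fn):
        row_idx = find_row(rows, idx // row_numel)     # -1 if the row is absent
        if lazy_mode and row_idx < 0:
            return                                     # leave the element untouched
        g = grad[row_idx * row_numel + idx % row_numel] if row_idx >= 0 else 0.0
        # update_fn is the element-wise Adam/AdamW + AMSGrad step sketched above.
        update_fn(idx, g, state, amsgrad)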
row_numel, grad_merge.rows().size(), lazy_mode, - ndim); + ndim, + amsgrad); if (!use_global_beta_pow) { // Update with cpu dev_ctx.template HostAlloc(beta1_pow_out)[0] = @@ -251,6 +270,8 @@ void AdamDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad_data, param.data(), @@ -260,7 +281,8 @@ void AdamDenseParamSparseGradKernel( rows, row_numel, grad_merge.rows().size(), - lazy_mode); + lazy_mode, + amsgrad); // FIXME(minqiyang): remove BinarySearch in GPU later funcs::ForRange for_range(dev_ctx, param.numel()); @@ -289,9 +311,9 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, double, phi::dtype::float16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); if (kernel_key.dtype() == phi::DataType::FLOAT16) { kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); @@ -299,7 +321,8 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); } - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index ee7eab855220aa..a6ac2ef3aa8114 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -55,6 +55,8 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, MT* mom1_out_, const MT* mom2_, MT* mom2_out_, + const MT* mom2_max_, + MT* mom2_max_out_, const MT* lr_, const T* grad_, const T* param_, @@ -65,7 +67,8 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, int64_t row_numel, int64_t row_count, bool lazy_mode, - int ndim) { + int ndim, + bool amsgrad) { int id = blockIdx.x * blockDim.x + threadIdx.x; MT lr = *lr_ * lr_ratio; @@ -77,6 +80,7 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, } else { MT mom1 = static_cast(mom1_[id]); MT mom2 = static_cast(mom2_[id]); + MT mom2_max = static_cast(mom2_max_[id]); MT p = master_param ? 
master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 @@ -88,8 +92,16 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT mom2_max_; + if (amsgrad) { + mom2_max_ = std::max(mom2, mom2_max); + mom2_max_out_[id] = mom2_max_; + } else { + mom2_max_ = mom2; + } + MT denom = - (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); @@ -112,6 +124,7 @@ void AdamwDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -126,9 +139,11 @@ void AdamwDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -158,6 +173,7 @@ void AdamwDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -247,6 +263,8 @@ void AdamwDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad_data, param.data(), @@ -257,7 +275,8 @@ void AdamwDenseParamSparseGradKernel( row_numel, grad_merge.rows().size(), lazy_mode, - ndim); + ndim, + amsgrad); if (!use_global_beta_pow) { // Update with cpu dev_ctx.template HostAlloc(beta1_pow_out)[0] = @@ -278,6 +297,8 @@ void AdamwDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), + moment2_max.data(), + dev_ctx.template Alloc(moment2_max_out), learning_rate.data(), grad_data, param.data(), @@ -287,7 +308,8 @@ void AdamwDenseParamSparseGradKernel( rows, row_numel, grad_merge.rows().size(), - lazy_mode); + lazy_mode, + amsgrad); // FIXME(minqiyang): remove BinarySearch in GPU later funcs::ForRange for_range(dev_ctx, param.numel()); @@ -316,9 +338,9 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad, double, phi::dtype::float16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); if (kernel_key.dtype() == phi::DataType::FLOAT16) { kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); @@ -326,7 +348,8 @@ PD_REGISTER_KERNEL(adamw_dense_param_sparse_grad, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + 
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); } - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml index 048f71a849d2d2..e8b251b0484c11 100755 --- a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml @@ -155,15 +155,15 @@ traits : paddle::dialect::ForwardOnlyTrait - op : fused_adam_ - args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow) - output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()} + args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] moments2_max, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow, bool amsgrad = false) + output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](moments2_max_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()} infer_meta : func : FusedAdamInferMeta kernel : func : fused_adam data_type : params optional : skip_update, master_params - inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) + inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (moments2_max -> moments2_max_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) - op : fused_gemm_epilogue args : (Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation) diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 5946a8cec8a796..88582b2f568f3c 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -324,15 +324,15 @@ traits : paddle::dialect::ForwardOnlyTrait - op : fused_adam_ - args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow) - output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()} + args : (Tensor[] params, Tensor[] grads, Tensor learning_rate, Tensor[] moments1, Tensor[] moments2, Tensor[] 
moments2_max, Tensor[] beta1_pows, Tensor[] beta2_pows, Tensor[] master_params, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, int chunk_size, float weight_decay, bool use_adamw, bool multi_precision, bool use_global_beta_pow, bool amsgrad = false) + output : Tensor[](params_out){params.size()}, Tensor[](moments1_out){params.size()}, Tensor[](moments2_out){params.size()}, Tensor[](moments2_max_out){params.size()}, Tensor[](beta1_pows_out){params.size()}, Tensor[](beta2_pows_out){params.size()}, Tensor[](master_params_out){params.size()} infer_meta : func : FusedAdamInferMeta kernel : func : fused_adam data_type : params optional : skip_update, master_params, master_params_out - inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) + inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (moments2_max -> moments2_max_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) - op : fused_gate_attention args: (Tensor query, Tensor key, Tensor query_weight, Tensor key_weight, Tensor diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index f61fd6d1a3270b..f9a14db20273d9 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -1425,10 +1425,10 @@ - op : fused_adam_(fused_adam) inputs : {params : Params, grads : Grads, learning_rate : LearningRate, moments1 : Moments1, - moments2 : Moments2, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams, + moments2 : Moments2, moments2_max : Moments2Max, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams, skip_update : SkipUpdate} outputs : - {params_out : ParamsOut, moments1_out : Moments1Out, moments2_out : Moments2Out, + {params_out : ParamsOut, moments1_out : Moments1Out, moments2_out : Moments2Out, moments2_max_out : Moments2MaxOut, beta1_pows_out : Beta1PowsOut, beta2_pows_out : Beta2PowsOut, master_params_out : MasterParamsOut} - op : fused_attention @@ -2500,9 +2500,9 @@ - op : merged_adam_ inputs : - {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam} + {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, moment2_max: Moment2Max, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam} outputs : - {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} + {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, moment2_max_out: Moment2MaxOut, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} scalar : beta1 : data_type : float diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index f4af290fb2084f..116186197d8dac 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -114,8 +114,8 @@ traits : pir::SideEffectTrait - op : adamw_ - args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1 = 0.9f, Scalar beta2 = 0.999f, Scalar epsilon = 1.0e-8f, float lr_ratio = 1.0f, float coeff = 0.01f, bool with_decay = false, bool lazy_mode =
false, int64_t min_row_size_to_use_multithread = 1000, bool multi_precision = false, bool use_global_beta_pow = false) - output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_out) + args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor moment2_max, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1 = 0.9f, Scalar beta2 = 0.999f, Scalar epsilon = 1.0e-8f, float lr_ratio = 1.0f, float coeff = 0.01f, bool with_decay = false, bool lazy_mode = false, int64_t min_row_size_to_use_multithread = 1000, bool multi_precision = false, bool use_global_beta_pow = false, bool amsgrad = false) + output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(moment2_max_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_out) infer_meta : func : AdamwInferMeta spmd_rule : AdamwInferSpmdDynamic @@ -123,7 +123,7 @@ func : adamw data_type : param optional : master_param, skip_update, master_param_out - inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) + inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (moment2_max -> moment2_max_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) traits : pir::SideEffectTrait - op : add_position_encoding @@ -3267,15 +3267,15 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface - op : merged_adam_ - args : (Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1 = 0.9f, Scalar beta2 = 0.999f, Scalar epsilon = 1.0e-8f, bool multi_precision = false, bool use_global_beta_pow = false) - output : Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()} + args : (Tensor[] param, Tensor[] grad, Tensor[] learning_rate, Tensor[] moment1, Tensor[] moment2, Tensor[] moment2_max, Tensor[] beta1_pow, Tensor[] beta2_pow, Tensor[] master_param, Scalar beta1 = 0.9f, Scalar beta2 = 0.999f, Scalar epsilon = 1.0e-8f, bool multi_precision = false, bool use_global_beta_pow = false, bool amsgrad = false) + output : Tensor[](param_out){param.size()}, Tensor[](moment1_out){param.size()}, Tensor[](moment2_out){param.size()}, Tensor[](moment2_max_out){param.size()}, Tensor[](beta1_pow_out){param.size()}, Tensor[](beta2_pow_out){param.size()}, Tensor[](master_param_out){param.size()} infer_meta : func : MergedAdamInferMeta kernel : func : merged_adam data_type : param optional: master_param, master_param_out - inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) + inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (moment2_max -> moment2_max_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) traits : pir::SideEffectTrait - op : merged_momentum_ diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 87e28db080821f..6b4236f4b6cd8f 100644 --- a/python/paddle/optimizer/adam.py +++ 
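The index changes in the PD_REGISTER_KERNEL blocks follow directly from inserting moment2_max after moment2. For the adamw_ signature in ops.yaml above the positional layout is now as sketched below; the adam_ and sparse-grad variants shift the same way, per their registrations:

    # Inputs of adamw_ after this patch. beta1_pow, beta2_pow and skip_update
    # keep their "skip data transform" treatment, which is why the
    # registrations change InputAt(5)/(6)/(8) to InputAt(6)/(7)/(9).
    ADAMW_INPUTS = [
        "param",          # 0
        "grad",           # 1
        "learning_rate",  # 2
        "moment1",        # 3
        "moment2",        # 4
        "moment2_max",    # 5  (new)
        "beta1_pow",      # 6  (was 5)
        "beta2_pow",      # 7  (was 6)
        "master_param",   # 8
        "skip_update",    # 9  (was 8)
    ]

    # Outputs: param_out 0, moment1_out 1, moment2_out 2, moment2_max_out 3,
    # beta1_pow_out 4, beta2_pow_out 5, master_param_out 6 -- hence the
    # OutputAt(4)/(5) backend changes and the extra OutputAt(6) dtype entry.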
b/python/paddle/optimizer/adam.py @@ -117,6 +117,7 @@ class Adam(Optimizer): The default value is False. multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once . Default is false. + amsgrad (bool, optional): Whether to use the AMSGrad of this algorithm. Default is false. name (str|None, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -263,6 +264,7 @@ def __init__( self._master_weight_dict = self._create_multi_tensor_dict() self._master_weight_dict['FP32_LODTensor'] = None + # whether to use AMSGrad self._amsgrad = amsgrad def _add_moments_pows(self, p): @@ -373,7 +375,7 @@ def _append_optimize_op(self, block, param_and_grad): self._get_auxiliary_var('found_inf') if in_pir_mode() else None ) - _ = _C_ops.adam_( + _, _, _, _, _, _, _ = _C_ops.adam_( param_and_grad[0], param_and_grad[1], lr, @@ -425,6 +427,7 @@ def _append_optimize_op(self, block, param_and_grad): "lazy_mode": self._lazy_mode, "min_row_size_to_use_multithread": 1000, "multi_precision": find_master, + "amsgrad": self._amsgrad, } if isinstance(self._beta1, Variable): @@ -778,7 +781,7 @@ def _append_optimize_multi_tensor_op( found_inf, (core.eager.Tensor, pir.Value) ): self._set_auxiliary_var('found_inf', False) - _, _, _, _, _, _ = _C_ops.merged_adam_( + _, _, _, _, _, _, _ = _C_ops.merged_adam_( self._param_dict[key][param_group_idx], grad_dict[key], lr_dict[key], @@ -802,7 +805,7 @@ def _append_optimize_multi_tensor_op( if master_weight is not None else None ) - _, _, _, _, _, _ = _C_ops.merged_adam_( + _, _, _, _, _, _, _ = _C_ops.merged_adam_( self._param_dict[key][param_group_idx], grad_dict[key], lr_dict[key], @@ -854,6 +857,7 @@ def _append_optimize_multi_tensor_op( "epsilon": self._epsilon, "beta1": _beta1, "beta2": _beta2, + "amsgrad": self._amsgrad, } if find_master: inputs["MasterParam"] = self._master_weight_dict[key][ diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index ab24d7d9c9fed6..8fb56300b67310 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -104,6 +104,7 @@ class AdamW(Optimizer): different semantics with the original Adam algorithm and may lead to different result. The default value is False. multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. + amsgrad (bool, optional): Whether to use the AMSGrad of this algorithm. Default is false. name (str|None, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. 
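With the docstring additions above, the user-facing change is a single keyword. A minimal usage sketch once the new argument is available (the default keeps amsgrad off, so existing code is unaffected):

    import paddle

    model = paddle.nn.Linear(10, 10)
    opt = paddle.optimizer.AdamW(
        learning_rate=1e-3,
        parameters=model.parameters(),
        weight_decay=0.01,
        amsgrad=True,   # new in this patch; paddle.optimizer.Adam gains it too
    )

    loss = model(paddle.rand([4, 10])).mean()
    loss.backward()
    opt.step()          # now dispatches to _C_ops.adamw_, which returns 7 tensors
    opt.clear_grad()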
@@ -165,6 +166,7 @@ class AdamW(Optimizer): type: str _moment1_acc_str = "moment1" _moment2_acc_str = "moment2" + _moment2_acc_max_str = "moment2_max" _beta1_pow_acc_str = "beta1_pow_acc" _beta2_pow_acc_str = "beta2_pow_acc" @@ -183,6 +185,7 @@ def __init__( grad_clip: GradientClipBase | None = None, lazy_mode: bool = False, multi_precision: bool = False, + amsgrad: bool = False, name: str | None = None, ) -> None: assert learning_rate is not None @@ -284,6 +287,8 @@ def __init__( self._lazy_mode = lazy_mode self._multi_precision = multi_precision self._master_weights = {} + # whether to use AMSGrad + self._amsgrad = amsgrad self._default_dict = { 'weight_decay': weight_decay, @@ -381,6 +386,7 @@ def _add_moments_pows(self, p): else: self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) + self._add_accumulator(self._moment2_acc_max_str, p, dtype=acc_dtype) self._add_accumulator( name=self._beta1_pow_acc_str, param=p, @@ -453,6 +459,9 @@ def _append_optimize_op(self, block, param_and_grad): moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) + moment2_max = self._get_accumulator_master( + self._moment2_acc_max_str, param_and_grad[0] + ) beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] ) @@ -492,12 +501,13 @@ def _append_optimize_op(self, block, param_and_grad): self._get_auxiliary_var('found_inf') if in_pir_mode() else None ) - _, _, _, _, _, _ = _C_ops.adamw_( + _, _, _, _, _, _, _ = _C_ops.adamw_( param_and_grad[0], param_and_grad[1], lr, moment1, moment2, + moment2_max, beta1_pow_acc, beta2_pow_acc, master_weight, @@ -512,6 +522,7 @@ def _append_optimize_op(self, block, param_and_grad): 1000, find_master, False, + self._amsgrad, ) return None else: @@ -521,6 +532,7 @@ def _append_optimize_op(self, block, param_and_grad): "LearningRate": [lr], "Moment1": [moment1], "Moment2": [moment2], + "Moment2Max": [moment2_max], "Beta1Pow": [beta1_pow_acc], "Beta2Pow": [beta2_pow_acc], } @@ -535,6 +547,7 @@ def _append_optimize_op(self, block, param_and_grad): "ParamOut": [param_and_grad[0]], "Moment1Out": [moment1], "Moment2Out": [moment2], + "Moment2MaxOut": [moment2_max], "Beta1PowOut": [beta1_pow_acc], "Beta2PowOut": [beta2_pow_acc], } @@ -549,6 +562,7 @@ def _append_optimize_op(self, block, param_and_grad): if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) ), + "amsgrad": self._amsgrad, } if isinstance(self._beta1, Variable): From 106f817fe623cb8565a90909f48e6dbd09045f98 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 5 Sep 2024 12:09:39 +0800 Subject: [PATCH 05/33] [Fix] adamw gpu kernel --- paddle/phi/kernels/gpu/adamw_kernel.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index f7213d40378477..ed93e810652be2 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -398,9 +398,9 @@ PD_REGISTER_KERNEL(adamw, phi::dtype::float16, phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); if (kernel_key.dtype() == phi::DataType::FLOAT16 || kernel_key.dtype() == 
phi::DataType::BFLOAT16) { @@ -409,7 +409,8 @@ PD_REGISTER_KERNEL(adamw, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); } - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } From fddb46a3d31c8c44bde3d81e678f21baacb7085f Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 5 Sep 2024 14:56:48 +0800 Subject: [PATCH 06/33] [Update] fused adam kernel for gpu --- paddle/phi/kernels/funcs/multi_tensor_apply.h | 2 +- paddle/phi/kernels/gpu/fused_adam_kernel.cu | 119 ++++++++++++++---- 2 files changed, 95 insertions(+), 26 deletions(-) diff --git a/paddle/phi/kernels/funcs/multi_tensor_apply.h b/paddle/phi/kernels/funcs/multi_tensor_apply.h index 6811793c02dcb2..6fe90864881381 100644 --- a/paddle/phi/kernels/funcs/multi_tensor_apply.h +++ b/paddle/phi/kernels/funcs/multi_tensor_apply.h @@ -76,7 +76,7 @@ void LaunchMultiTensorApplyKernel( errors::InvalidArgument( "input_vector.size() != InputNum - 1, the input vector's size is " "unequal to InputNum - 1, please cheack grads, params, momemts1, " - "moments2, and, master_params.")); + "moments2, moments2_max, and, master_params.")); size_t length = input_vector[0].size(); PADDLE_ENFORCE_GT( length, diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index 1d9e91c4b5169c..55b71a239f1098 100644 --- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +// TODO(megemini): + #include "paddle/phi/kernels/fused_adam_kernel.h" #include #include "glog/logging.h" @@ -70,6 +72,7 @@ template @@ -88,7 +91,9 @@ struct FusedAdamFunctor { MT beta2_pow = beta_pow.GetBeta2PowValue(); T* __restrict__ p_ptr; const T* __restrict__ g_ptr; - MT* __restrict__ mom1_ptr, * __restrict__ mom2_ptr; + MT* __restrict__ mom1_ptr; + MT* __restrict__ mom2_ptr; + MT* __restrict__ mom2_max_ptr; MT* __restrict__ mp_ptr; int n; @@ -102,9 +107,11 @@ struct FusedAdamFunctor { p_ptr = static_cast(t_info.tensor_addrs[0][tensor_id]) + offset; mom1_ptr = static_cast(t_info.tensor_addrs[1][tensor_id]) + offset; mom2_ptr = static_cast(t_info.tensor_addrs[2][tensor_id]) + offset; + mom2_max_ptr = + static_cast(t_info.tensor_addrs[3][tensor_id]) + offset; mp_ptr = IsMultiPrecision - ? static_cast(t_info.tensor_addrs[3][tensor_id]) + offset + ? 
static_cast(t_info.tensor_addrs[4][tensor_id]) + offset : nullptr; n -= offset; @@ -122,6 +129,7 @@ struct FusedAdamFunctor { phi::AlignedVector mp_vec; phi::AlignedVector mom1_vec; phi::AlignedVector mom2_vec; + phi::AlignedVector mom2_max_vec; if (idx <= n - VecSize) { if (IsMultiPrecision) { phi::Load(mp_ptr + idx, &mp_vec); @@ -131,6 +139,7 @@ struct FusedAdamFunctor { phi::Load(g_ptr + idx, &g_vec); phi::Load(mom1_ptr + idx, &mom1_vec); phi::Load(mom2_ptr + idx, &mom2_vec); + phi::Load(mom2_max_ptr + idx, &mom2_max_vec); } else { int size = n - idx; for (int j = 0; j < size; j++) { @@ -142,6 +151,7 @@ struct FusedAdamFunctor { g_vec[j] = g_ptr[idx + j]; mom1_vec[j] = static_cast(mom1_ptr[idx + j]); mom2_vec[j] = static_cast(mom2_ptr[idx + j]); + mom2_max_vec[j] = static_cast(mom2_max_ptr[idx + j]); } #pragma unroll for (int j = size; j < VecSize; j++) { @@ -150,6 +160,7 @@ struct FusedAdamFunctor { mp_vec[j] = MT(0); mom1_vec[j] = MT(0); mom2_vec[j] = MT(0); + mom2_max_vec[j] = MT(0); } } @@ -158,12 +169,14 @@ struct FusedAdamFunctor { MT p = IsMultiPrecision ? mp_vec[j] : static_cast(p_vec[j]); UpdateMoments(&mom1_vec[j], &mom2_vec[j], + &mom2_max_vec[j], static_cast(g_vec[j]), beta1, beta2); mp_vec[j] = UpdateParameter(p, mom1_vec[j], mom2_vec[j], + mom2_max_vec[j], beta1_pow, beta2_pow, lr, @@ -174,6 +187,7 @@ struct FusedAdamFunctor { if (idx <= n - VecSize) { phi::Store(mom1_vec, mom1_ptr + idx); phi::Store(mom2_vec, mom2_ptr + idx); + phi::Store(mom2_max_vec, mom2_max_ptr + idx); if (IsMultiPrecision) { phi::Store(mp_vec, mp_ptr + idx); } @@ -189,6 +203,7 @@ struct FusedAdamFunctor { p_ptr[idx + j] = static_cast(mp_vec[j]); mom1_ptr[idx + j] = mom1_vec[j]; mom2_ptr[idx + j] = mom2_vec[j]; + mom2_max_ptr[idx + j] = mom2_max_vec[j]; } } } @@ -198,21 +213,26 @@ struct FusedAdamFunctor { static __device__ __forceinline__ void UpdateMoments( MT* __restrict__ mom1_ptr, MT* __restrict__ mom2_ptr, + MT* __restrict__ mom2_max_ptr, MT g, MT beta1, MT beta2) { MT mom1 = static_cast(mom1_ptr[0]); MT mom2 = static_cast(mom2_ptr[0]); + MT mom2_max = static_cast(mom2_max_ptr[0]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; mom1_ptr[0] = mom1; mom2_ptr[0] = mom2; + mom2_max_ptr[0] = std::max(mom2, mom2_max); } static __device__ __forceinline__ MT UpdateParameter(MT p, MT mom1, MT mom2, + MT mom2_max, MT beta1_pow, MT beta2_pow, MT lr, @@ -221,7 +241,15 @@ struct FusedAdamFunctor { if (UseAdamW) { p *= (static_cast(1.0) - lr * decay); } - MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + + MT denom; + if (AMSGrad) { + denom = + (sqrt(mom2_max) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + } else { + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + } + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); return p; } @@ -350,7 +378,7 @@ void FusedAdamKernel( MPDType beta2_tmp = beta2.to(); std::vector> input_vector; - input_vector.reserve(4); + input_vector.reserve(5); input_vector.push_back(params_out); input_vector.push_back(moments1_out); @@ -364,9 +392,9 @@ void FusedAdamKernel( VLOG(4) << "multi_precision: " << multi_precision; #define PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - __multi_precision, __is_cpu_betapow, __use_adamw, __vec_size) \ + __multi_precision, __is_cpu_betapow, __use_adamw, __amsgrad, __vec_size) \ do { \ - constexpr int kInputNum = __multi_precision ? 5 : 4; \ + constexpr int kInputNum = __multi_precision ? 
6 : 5; \ constexpr int kMaxTensorSize = __multi_precision ? 48 : 60; \ constexpr int kMaxBlockSize = __multi_precision ? 320 : 320; \ constexpr int kBlockSize = 512; \ @@ -378,6 +406,7 @@ void FusedAdamKernel( __multi_precision, \ __is_cpu_betapow, \ __use_adamw, \ + __amsgrad, \ kInputNum, \ kMaxTensorSize, \ kMaxBlockSize> \ @@ -404,37 +433,77 @@ void FusedAdamKernel( if (multi_precision) { \ if (is_cpu_betapow) { \ if (use_adamw) { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - true, true, true, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, true, true, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, true, true, false, __vec_size); \ + } \ } else { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - true, true, false, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, true, false, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, true, false, false, __vec_size); \ + } \ } \ } else { \ if (use_adamw) { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - true, false, true, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, false, true, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, false, true, false, __vec_size); \ + } \ } else { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - true, false, false, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, false, false, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + true, false, false, false, __vec_size); \ + } \ } \ } \ } else { \ if (is_cpu_betapow) { \ if (use_adamw) { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - false, true, true, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, true, true, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, true, true, false, __vec_size); \ + } \ } else { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - false, true, false, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, true, false, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, true, false, false, __vec_size); \ + } \ } \ } else { \ if (use_adamw) { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - false, false, true, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, false, true, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, false, true, false, __vec_size); \ + } \ } else { \ - PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ - false, false, false, __vec_size); \ + if (amsgrad) { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, false, false, true, __vec_size); \ + } else { \ + PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ + false, false, false, false, __vec_size); \ + } \ } \ } \ } \ @@ -443,7 +512,6 @@ void FusedAdamKernel( int vec_size = GetVecSizeFromTensors(params_out); vec_size = GetVecSizeFromTensors(moments1_out, vec_size); vec_size = GetVecSizeFromTensors(moments2_out, vec_size); - // TODO(megemini): vec_size = GetVecSizeFromTensors(moments2_max_out, vec_size); if (master_params) { vec_size = GetVecSizeFromTensors(master_params_out, vec_size); @@ -503,12 +571,13 @@ 
PD_REGISTER_KERNEL(fused_adam, float, double) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(3).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(4).SetDataType(phi::DataType::UNDEFINED); kernel->OutputAt(5).SetDataType(phi::DataType::UNDEFINED); + kernel->OutputAt(6).SetDataType(phi::DataType::UNDEFINED); } From d20644293270da976cd078b18e37dbf8e78f01bd Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 5 Sep 2024 20:04:00 +0800 Subject: [PATCH 07/33] [Update] xpu adam/adamw param list --- paddle/phi/kernels/gpu/fused_adam_kernel.cu | 2 -- .../phi/kernels/selected_rows/xpu/adam_kernel.cc | 9 ++++++--- paddle/phi/kernels/xpu/adam_kernel.cc | 16 +++++++++++----- paddle/phi/kernels/xpu/adamw_kernel.cc | 10 +++++++--- python/paddle/optimizer/adamw.py | 8 ++++++++ 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index 55b71a239f1098..a1f8010b314fae 100644 --- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// TODO(megemini): - #include "paddle/phi/kernels/fused_adam_kernel.h" #include #include "glog/logging.h" diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc index 47cd016506c004..232a67e79ec454 100644 --- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc @@ -34,6 +34,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -45,9 +46,11 @@ void AdamDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad UNUSED, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out UNUSED, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -347,9 +350,9 @@ PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, float, phi::dtype::float16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc index a9c7e497567c1e..609b4133be079a 100644 --- a/paddle/phi/kernels/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -32,6 +32,7 
@@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -43,9 +44,11 @@ void AdamDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad UNUSED, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out UNUSED, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -261,6 +264,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, + const std::vector& moment2_max UNUSED, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -269,9 +273,11 @@ void MergedAdamKernel( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, + bool amsgrad UNUSED, std::vector param_out, std::vector moment1_out, std::vector moment2_out, + std::vector moment2_max_out UNUSED, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) { @@ -480,18 +486,18 @@ void MergedAdamKernel( PD_REGISTER_KERNEL( adam, XPU, ALL_LAYOUT, phi::AdamDenseKernel, float, phi::dtype::float16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } PD_REGISTER_KERNEL(merged_adam, XPU, ALL_LAYOUT, phi::MergedAdamKernel, float) { // Skip beta1_pow, beta2_pow data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index 72c1c5d578eaf4..efb0c19b11265a 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -483,6 +483,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, + const DenseTensor& moment2_max UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -497,9 +498,11 @@ void AdamwDenseKernel(const Context& dev_ctx, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, + bool amsgrad UNUSED, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, + DenseTensor* moment2_max_out UNUSED, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -885,9 +888,9 @@ PD_REGISTER_KERNEL(adamw, phi::dtype::float16, phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow, skip_update data transform - kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); 
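// Editor's note, not part of the patch: the InputAt/OutputAt index shifts in
// this and the neighboring PD_REGISTER_KERNEL blocks all follow from one
// change -- moment2_max is added as the input right after moment2, and
// moment2_max_out as the output right after moment2_out, so every later index
// moves up by one. Based on the kernel signatures in this series, the assumed
// ordering is roughly:
//   inputs : 0 param, 1 grad, 2 learning_rate, 3 moment1, 4 moment2,
//            5 moment2_max (new), 6 beta1_pow, 7 beta2_pow, 8 master_param,
//            9 skip_update
//   outputs: 0 param_out, 1 moment1_out, 2 moment2_out,
//            3 moment2_max_out (new), 4 beta1_pow_out, 5 beta2_pow_out,
//            6 master_param_outs
// which is why beta1_pow/beta2_pow/skip_update move from InputAt(5)/(6)/(8)
// to InputAt(6)/(7)/(9), and the beta_pow outputs from OutputAt(3)/(4) to
// OutputAt(4)/(5).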
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); - kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); if (kernel_key.dtype() == phi::DataType::FLOAT16 || kernel_key.dtype() == phi::DataType::BFLOAT16) { @@ -896,7 +899,8 @@ PD_REGISTER_KERNEL(adamw, kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); } - kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED); kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); } diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 8fb56300b67310..c5a4f8334c1e5f 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -380,9 +380,17 @@ def _add_moments_pows(self, p): self._add_accumulator( self._moment2_acc_str, p, dtype=core.VarDesc.VarType.FP16 ) + self._add_accumulator( + self._moment2_acc_max_str, + p, + dtype=core.VarDesc.VarType.FP16, + ) else: self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) + self._add_accumulator( + self._moment2_acc_max_str, p, dtype=acc_dtype + ) else: self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) From 8cc9b5b95d8af90c02ffc6347521477092024151 Mon Sep 17 00:00:00 2001 From: megemini Date: Fri, 6 Sep 2024 15:12:32 +0800 Subject: [PATCH 08/33] [Update] tests for amsgrad --- paddle/phi/kernels/funcs/adam_functors.h | 69 +++-- paddle/phi/kernels/gpu/adam_kernel.cu | 24 +- paddle/phi/kernels/gpu/adamw_kernel.cu | 26 +- paddle/phi/kernels/gpu/fused_adam_kernel.cu | 7 +- .../kernels/selected_rows/gpu/adam_kernel.cu | 12 +- .../kernels/selected_rows/gpu/adamw_kernel.cu | 13 +- .../hybrid_parallel_sharding_state_dict.py | 2 +- .../cpp/phi/kernels/test_fused_adam_kernel.cc | 51 +++- test/legacy_test/test_adam_op.py | 147 ++++++++-- test/legacy_test/test_adamw_op.py | 258 ++++++++++++------ test/xpu/test_adam_op_xpu.py | 56 +++- test/xpu/test_adamw_op_xpu.py | 206 +++++++++----- test/xpu/test_merged_adam_op_xpu.py | 21 +- 13 files changed, 623 insertions(+), 269 deletions(-) diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index 57bdeeee56dd45..3143cf65def218 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -221,7 +221,6 @@ class AdamFunctor { T g = grad_[i]; T mom1 = moment1_[i]; T mom2 = moment2_[i]; - T mom2_max = moment2_max_[i]; T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -233,16 +232,15 @@ class AdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - T mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + T mom2_max = moment2_max_[i]; + T mom2_max_ = std::max(mom2, mom2_max); moment2_max_out_[i] = mom2_max_; + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); } else { - mom2_max_ = mom2; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); } - p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); - // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; @@ -312,8 +310,6 @@ class 
AdamFunctor { moment1_, static_cast(numel)}; Eigen::Map> mom2{ moment2_, static_cast(numel)}; - Eigen::Map> mom2_max{ - moment2_max_, static_cast(numel)}; Eigen::Map> param{ param_, static_cast(numel)}; @@ -323,8 +319,6 @@ class AdamFunctor { moment1_out_, static_cast(numel)}; Eigen::Map> moment2_out{ moment2_out_, static_cast(numel)}; - Eigen::Map> moment2_max_out{ - moment2_max_out_, static_cast(numel)}; T lr = *lr_; T beta1_pow = *beta1_pow_; @@ -337,6 +331,11 @@ class AdamFunctor { moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g; if (amsgrad_) { + Eigen::Map> mom2_max{ + moment2_max_, static_cast(numel)}; + Eigen::Map> moment2_max_out{ + moment2_max_out_, static_cast(numel)}; + moment2_max_out = moment2_out.cwiseMax(mom2_max); param_out = param - lr * (moment1_out / (moment2_max_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow))); @@ -428,7 +427,6 @@ class SparseAdamFunctor { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; - MT mom2_max = moment2_max_[i]; MT lr = *lr_; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; @@ -441,17 +439,19 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; - MT mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = moment2_max_[i]; + MT mom2_max_ = std::max(mom2, mom2_max); moment2_max_out_[i] = mom2_max_; + + p -= lr * (mom1 / (sqrt(mom2_max_) + + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); + } else { - mom2_max_ = mom2; + p -= lr * (mom1 / (sqrt(mom2) + + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } - p -= lr * (mom1 / (sqrt(mom2_max_) + - epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); - // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; @@ -545,7 +545,6 @@ class SparseAdamFunctor { // The following code is the same as dense T mom1 = moment1_[i]; T mom2 = moment2_[i]; - T mom2_max = moment2_max_[i]; T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -557,16 +556,15 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - T mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + T mom2_max = moment2_max_[i]; + T mom2_max_ = std::max(mom2, mom2_max); moment2_max_out_[i] = mom2_max_; + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); } else { - mom2_max_ = mom2; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); } - p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); - // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; @@ -592,22 +590,21 @@ class SparseAdamFunctor { for (int64_t k = 0; k != row_numel_; ++k) { T mom1 = moment1_[i * row_numel_ + k]; T mom2 = moment2_[i * row_numel_ + k]; - T mom2_max = moment2_max_[i * row_numel_ + k]; T p = param_[i * row_numel_ + k]; mom1 = beta1_ * mom1; mom2 = beta2_ * mom2; - T mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + T mom2_max = moment2_max_[i * row_numel_ + k]; + T mom2_max_ = std::max(mom2, mom2_max); moment2_max_out_[i * row_numel_ + k] = mom2_max_; + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); } else { - mom2_max_ = mom2; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); } - p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); // Write back to global memory moment1_out_[i * row_numel_ + k] = mom1; moment2_out_[i * row_numel_ + k] = mom2; @@ -734,7 +731,6 @@ class SparseAdamWFunctor { // The following code is the 
same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; - MT mom2_max = moment2_max_[i]; MT lr = *lr_ * lr_ratio_; MT lr_orig = lr; MT beta1_pow = *beta1_pow_; @@ -748,18 +744,19 @@ class SparseAdamWFunctor { mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; - MT mom2_max_; + p -= lr_orig * coeff_ * p; + if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = moment2_max_[i]; + MT mom2_max_ = std::max(mom2, mom2_max); moment2_max_out_[i] = mom2_max_; + p -= lr * (mom1 / (sqrt(mom2_max_) + + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } else { - mom2_max_ = mom2; + p -= lr * (mom1 / (sqrt(mom2) + + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } - p -= lr_orig * coeff_ * p; - p -= lr * (mom1 / (sqrt(mom2_max_) + - epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); - // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 2a7613b628eb94..2b68d917085ccf 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -61,21 +61,21 @@ __global__ void AdamKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; + MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); moment2_max_out[id] = mom2_max_; + denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } - MT denom = - (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; @@ -118,21 +118,21 @@ __global__ void AdamKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; + MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); moment2_max_out[id] = mom2_max_; + denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } - MT denom = - (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index ed93e810652be2..5df2568d4f12a3 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -64,24 +64,23 @@ __global__ void AdamWKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; + MT denom; if (amsgrad) { - 
mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); moment2_max_out[id] = mom2_max_; + denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } - MT denom = - (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; - p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; @@ -126,24 +125,23 @@ __global__ void AdamWKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; + MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); moment2_max_out[id] = mom2_max_; + denom = + (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } - MT denom = - (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; - p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index a1f8010b314fae..094768c3e7caff 100644 --- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -217,14 +217,17 @@ struct FusedAdamFunctor { MT beta2) { MT mom1 = static_cast(mom1_ptr[0]); MT mom2 = static_cast(mom2_ptr[0]); - MT mom2_max = static_cast(mom2_max_ptr[0]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; mom1_ptr[0] = mom1; mom2_ptr[0] = mom2; - mom2_max_ptr[0] = std::max(mom2, mom2_max); + + if (AMSGrad) { + MT mom2_max = static_cast(mom2_max_ptr[0]); + mom2_max_ptr[0] = std::max(mom2, mom2_max); + } } static __device__ __forceinline__ MT UpdateParameter(MT p, diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index ff2e6285110e82..c7c3a4e14c1a03 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -74,7 +74,6 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, } else { MT mom1 = mom1_[id]; MT mom2 = mom2_[id]; - MT mom2_max = mom2_max_[id]; MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 ? 
static_cast(grad_[row_idx * row_numel + id % row_numel]) @@ -82,16 +81,17 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT moment2_max_; + MT denom; if (amsgrad) { - moment2_max_ = std::max(mom2, mom2_max); + MT mom2_max = mom2_max_[id]; + MT moment2_max_ = std::max(mom2, mom2_max); mom2_max_out_[id] = moment2_max_; + denom = (sqrt(moment2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + + epsilon; } else { - moment2_max_ = mom2; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } - MT denom = (sqrt(moment2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + - epsilon; p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); // Write back to global memory diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index a6ac2ef3aa8114..a428d98d2c01a0 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -80,7 +80,6 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, } else { MT mom1 = static_cast(mom1_[id]); MT mom2 = static_cast(mom2_[id]); - MT mom2_max = static_cast(mom2_max_[id]); MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 @@ -92,17 +91,17 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; + MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(mom2_max_[id]); + MT mom2_max_ = std::max(mom2, mom2_max); mom2_max_out_[id] = mom2_max_; + denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + + epsilon; } else { - mom2_max_ = mom2; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } - MT denom = - (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; - p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); // Write back to global memory diff --git a/test/collective/fleet/hybrid_parallel_sharding_state_dict.py b/test/collective/fleet/hybrid_parallel_sharding_state_dict.py index b9a5f55dc188ab..1a51407ddb5f39 100644 --- a/test/collective/fleet/hybrid_parallel_sharding_state_dict.py +++ b/test/collective/fleet/hybrid_parallel_sharding_state_dict.py @@ -146,7 +146,7 @@ def test_set_state_dict(self): # master_weights and accumulators state_dict["master_weights"] = {} all_param_names = [] - accumulator_names = ["moment1", "moment2"] + accumulator_names = ["moment1", "moment2", "moment2_max"] # local_params = dist_optimizer._rank2params[ dist_optimizer._sharding_rank diff --git a/test/cpp/phi/kernels/test_fused_adam_kernel.cc b/test/cpp/phi/kernels/test_fused_adam_kernel.cc index ec0926508c9e89..1b15c33481e22f 100644 --- a/test/cpp/phi/kernels/test_fused_adam_kernel.cc +++ b/test/cpp/phi/kernels/test_fused_adam_kernel.cc @@ -126,6 +126,7 @@ struct AdamInfo { std::vector master_params; std::vector moment1s; std::vector moment2s; + std::vector moment2s_max; std::vector beta1_pows; std::vector beta2_pows; DenseTensor learning_rate; @@ -136,6 +137,7 @@ struct AdamInfo { bool multi_precision; bool use_adamw; int chunk_size = 4096; + bool amsgrad; using MT = typename phi::dtype::MPTypeTrait::Type; @@ -145,14 +147,16 @@ struct AdamInfo { float beta2, float weight_decay, bool multi_precision, - bool use_adamw) + bool use_adamw, + bool amsgrad) : ctx(&ctx_ref), 
shapes(shapes), beta1(beta1), beta2(beta2), weight_decay(weight_decay), multi_precision(multi_precision), - use_adamw(use_adamw) { + use_adamw(use_adamw), + amsgrad(amsgrad) { std::vector> one_shapes(shapes.size(), std::vector(1, 1)); std::vector> learning_rate_shapes( @@ -163,6 +167,7 @@ struct AdamInfo { *ctx, learning_rate_shapes, 1e-3)[0]; moment1s = GenerateConstantTensorVectors(*ctx, shapes, 0); moment2s = GenerateConstantTensorVectors(*ctx, shapes, 0); + moment2s_max = GenerateConstantTensorVectors(*ctx, shapes, 0); if (multi_precision) { master_params.resize(shapes.size()); @@ -199,7 +204,8 @@ struct AdamInfo { other.beta2, other.weight_decay, other.multi_precision, - other.use_adamw); + other.use_adamw, + other.amsgrad); auto copy_tensor = [&other](const DenseTensor &x, DenseTensor *y) { Copy(*other.ctx, x, x.place(), false, y); }; @@ -215,6 +221,7 @@ struct AdamInfo { copy_tensors(other.master_params, &copied.master_params); copy_tensors(other.moment1s, &copied.moment1s); copy_tensors(other.moment2s, &copied.moment2s); + copy_tensors(other.moment2s_max, &copied.moment2s_max); copy_tensors(other.beta1_pows, &copied.beta1_pows); copy_tensors(other.beta2_pows, &copied.beta2_pows); copy_tensor(other.learning_rate, &copied.learning_rate); @@ -231,6 +238,7 @@ struct AdamInfo { auto master_param_metas = ToMetaTensorVector(master_params); auto moment1_metas = ToMetaTensorVector(moment1s); auto moment2_metas = ToMetaTensorVector(moment2s); + auto moment2_max_metas = ToMetaTensorVector(moment2s_max); auto beta1_pow_metas = ToMetaTensorVector(beta1_pows); auto beta2_pow_metas = ToMetaTensorVector(beta2_pows); @@ -239,6 +247,7 @@ struct AdamInfo { learning_rate, ToConstMetaTensorPtrVector(moment1_metas), ToConstMetaTensorPtrVector(moment2_metas), + ToConstMetaTensorPtrVector(moment2_max_metas), ToConstMetaTensorPtrVector(beta1_pow_metas), ToConstMetaTensorPtrVector(beta2_pow_metas), multi_precision @@ -254,9 +263,11 @@ struct AdamInfo { use_adamw, multi_precision, false, + amsgrad, ToMutableMetaTensorPtrVector(param_metas), ToMutableMetaTensorPtrVector(moment1_metas), ToMutableMetaTensorPtrVector(moment2_metas), + ToMutableMetaTensorPtrVector(moment2_max_metas), ToMutableMetaTensorPtrVector(beta1_pow_metas), ToMutableMetaTensorPtrVector(beta2_pow_metas), ToMutableMetaTensorPtrVector(master_param_metas)); @@ -268,6 +279,7 @@ struct AdamInfo { learning_rate, ToConstTensorPtrVector(moment1s), ToConstTensorPtrVector(moment2s), + ToConstTensorPtrVector(moment2s_max), ToConstTensorPtrVector(beta1_pows), ToConstTensorPtrVector(beta2_pows), multi_precision @@ -282,9 +294,11 @@ struct AdamInfo { use_adamw, multi_precision, false, + amsgrad, ToMutableTensorPtrVector(params), ToMutableTensorPtrVector(moment1s), ToMutableTensorPtrVector(moment2s), + ToMutableTensorPtrVector(moment2s_max), ToMutableTensorPtrVector(beta1_pows), ToMutableTensorPtrVector(beta2_pows), ToMutableTensorPtrVector(master_params)); @@ -299,6 +313,7 @@ struct AdamInfo { learning_rate, moment1s[idx], moment2s[idx], + moment2s_max[idx], beta1_pows[idx], beta2_pows[idx], multi_precision ? paddle::make_optional(master_params[idx]) @@ -314,9 +329,11 @@ struct AdamInfo { 1000, multi_precision, false, + amsgrad, ¶ms[idx], &moment1s[idx], &moment2s[idx], + &moment2s_max[idx], &beta1_pows[idx], &beta2_pows[idx], multi_precision ? &master_params[idx] : nullptr); @@ -331,6 +348,7 @@ struct AdamInfo { learning_rate, moment1s[idx], moment2s[idx], + moment2s_max[idx], beta1_pows[idx], beta2_pows[idx], multi_precision ? 
paddle::make_optional(master_params[idx]) @@ -343,9 +361,11 @@ struct AdamInfo { 1000, multi_precision, false, + amsgrad, ¶ms[idx], &moment1s[idx], &moment2s[idx], + &moment2s_max[idx], &beta1_pows[idx], &beta2_pows[idx], multi_precision ? &master_params[idx] : nullptr); @@ -401,6 +421,7 @@ template void TestFusedAdamBase(const std::vector> &shapes, float atol, bool use_adamw, + bool amsgrad, bool multi_precision = false, float beta1 = 0.9, float beta2 = 0.99, @@ -411,8 +432,14 @@ void TestFusedAdamBase(const std::vector> &shapes, using Context = typename std::remove_const< typename std::remove_pointer::type>::type; ctx.GetGenerator()->SetCurrentSeed(seed); - AdamInfo info1( - ctx, shapes, beta1, beta2, weight_decay, multi_precision, use_adamw); + AdamInfo info1(ctx, + shapes, + beta1, + beta2, + weight_decay, + multi_precision, + use_adamw, + amsgrad); auto info2 = AdamInfo::DeepCopy(info1); for (size_t i = 0; i < steps; ++i) { @@ -437,6 +464,7 @@ void TestFusedAdamBase(const std::vector> &shapes, PD_ADAM_TEST_COMP(master_params, MT); PD_ADAM_TEST_COMP(moment1s, MT); PD_ADAM_TEST_COMP(moment2s, MT); + PD_ADAM_TEST_COMP(moment2s_max, MT); } static auto GenerateRandomShapes(size_t n, uint64_t low, uint64_t high) { @@ -454,7 +482,9 @@ TEST(fused_adam, test_fp32_cpu) { auto shapes = GenerateRandomShapes(30, 10, 20); float atol = 0.0f; for (auto use_adamw : {false, true}) { - TestFusedAdamBase(shapes, atol, use_adamw); + for (auto amsgrad : {false, true}) { + TestFusedAdamBase(shapes, atol, use_adamw, amsgrad); + } } } @@ -463,7 +493,9 @@ TEST(fused_adam, test_fp32_gpu) { auto shapes = GenerateRandomShapes(40, 0, 2 << 18); float atol = 0.0f; for (auto use_adamw : {false, true}) { - TestFusedAdamBase(shapes, atol, use_adamw); + for (auto amsgrad : {false, true}) { + TestFusedAdamBase(shapes, atol, use_adamw, amsgrad); + } } } @@ -471,7 +503,10 @@ TEST(fused_adam, test_fp16_gpu) { auto shapes = GenerateRandomShapes(40, 0, 2 << 18); float atol = 5e-3f; for (auto use_adamw : {false, true}) { - TestFusedAdamBase(shapes, atol, use_adamw, true); + for (auto amsgrad : {false, true}) { + TestFusedAdamBase( + shapes, atol, use_adamw, amsgrad, true); + } } } #endif diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py index 3be5104817acb1..b781dd3bf1263b 100644 --- a/test/legacy_test/test_adam_op.py +++ b/test/legacy_test/test_adam_op.py @@ -30,6 +30,7 @@ def adam_wrapper( LearningRate, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_weight=None, @@ -38,13 +39,15 @@ def adam_wrapper( beta2=0.836, epsilon=1e-4, lazy_mode=False, + amsgrad=False, ): - _, _, _, _, _, _ = paddle._C_ops.adam_( + _, _, _, _, _, _, _ = paddle._C_ops.adam_( param, grad, LearningRate, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_weight, @@ -56,6 +59,7 @@ def adam_wrapper( 1000, False, False, + amsgrad, ) @@ -70,6 +74,7 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") learning_rate = 0.004 beta1 = 0.78 @@ -77,24 +82,34 @@ def setUp(self): epsilon = 1e-4 beta1_pow = beta1**10 beta2_pow = beta2**10 + amsgrad = False self.inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), } - 
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + self.attrs = { + 'epsilon': epsilon, + 'beta1': beta1, + 'beta2': beta2, + 'amsgrad': amsgrad, + } - param_out, moment1_out, moment2_out = adam_step(self.inputs, self.attrs) + param_out, moment1_out, moment2_out, moment2_max_out = adam_step( + self.inputs, self.attrs + ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -119,6 +134,7 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, self.shape).astype("float32") # The second moment is positive moment2 = np.random.random(self.shape).astype("float32") + moment2_max = np.zeros(self.shape).astype("float32") learning_rate = 0.001 beta1 = 0.9 @@ -126,24 +142,34 @@ def setUp(self): epsilon = 1e-8 beta1_pow = beta1**10 beta2_pow = beta2**10 + amsgrad = False self.inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), } - attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + attributes = { + 'epsilon': epsilon, + 'beta1': beta1, + 'beta2': beta2, + 'amsgrad': amsgrad, + } - param_out, moment1_out, moment2_out = adam_step(self.inputs, attributes) + param_out, moment1_out, moment2_out, moment2_max_out = adam_step( + self.inputs, attributes + ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -171,6 +197,7 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") learning_rate = 0.001 self.beta1 = 0.9 @@ -178,12 +205,14 @@ def setUp(self): epsilon = 1e-8 self.beta1_pow = self.beta1**10 self.beta2_pow = self.beta2**10 + self.amsgrad = False self.inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([self.beta1_pow]).astype("float32"), 'Beta2Pow': np.array([self.beta2_pow]).astype("float32"), @@ -193,11 +222,12 @@ def setUp(self): 'epsilon': epsilon, 'beta1': self.beta1, 'beta2': self.beta2, + 'amsgrad': self.amsgrad, } def test_check_output(self): for _ in range(self.num_steps): - param_out, moment1_out, moment2_out = adam_step( + param_out, moment1_out, moment2_out, moment2_max_out = adam_step( self.inputs, self.attrs ) @@ -206,6 +236,7 @@ def test_check_output(self): self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': beta1_pow_out, 'Beta2PowOut': beta2_pow_out, @@ -218,6 +249,7 @@ def test_check_output(self): self.inputs['Param'] = param_out self.inputs['Moment1'] = moment1_out self.inputs['Moment2'] = moment2_out + self.inputs['Moment2Max'] = moment2_max_out # Update powers of Beta1 and Beta2 for next time step self.inputs['Beta1Pow'] = beta1_pow_out @@ -241,6 +273,7 @@ def adam_step(inputs, attributes): grad = inputs['Grad'] moment1 = 
inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -256,11 +289,25 @@ def adam_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] + amsgrad = attributes['amsgrad'] + moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) - param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon)) - return param_out, moment1_out, moment2_out + + if amsgrad: + moment2_max_out = np.maximum(moment2_out, moment2_max) + param_out = param - lr_t * ( + moment1_out / (np.sqrt(moment2_max_out) + epsilon) + ) + else: + moment2_max_out = np.zeros_like(moment2_out) + param_out = param - lr_t * ( + moment1_out / (np.sqrt(moment2_out) + epsilon) + ) + + return param_out, moment1_out, moment2_out, moment2_max_out def adamw_step(inputs, attributes): @@ -275,6 +322,7 @@ def adamw_step(inputs, attributes): grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -294,12 +342,25 @@ def adamw_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] + amsgrad = attributes["amsgrad"] + moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) - param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon)) - return param_out, moment1_out, moment2_out + if amsgrad: + moment2_max_out = np.maximum(moment2_out, moment2_max) + param_out = param - lr_t * ( + moment1_out / (np.sqrt(moment2_max_out) + epsilon) + ) + else: + moment2_max_out = np.zeros_like(moment2_out) + param_out = param - lr_t * ( + moment1_out / (np.sqrt(moment2_out) + epsilon) + ) + + return param_out, moment1_out, moment2_out, moment2_max_out def adam_step_sparse( @@ -316,6 +377,7 @@ def adam_step_sparse( # grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -323,9 +385,11 @@ def adam_step_sparse( beta1 = attributes['beta1'] beta2 = attributes['beta2'] epsilon = attributes['epsilon'] + amsgrad = attributes['amsgrad'] moment1_out = np.zeros(shape=[height, row_numel]) moment2_out = np.zeros(shape=[height, row_numel]) + moment2_max_out = np.zeros(shape=[height, row_numel]) param_out = np.zeros(shape=[height, row_numel]) def update_row(row_id, update_value): @@ -336,9 +400,19 @@ def update_row(row_id, update_value): update_value ) lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) - param_out[row_id] = param[row_id] - lr_t * ( - moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) - ) + + if amsgrad: + moment2_max_out[row_id] = np.maximum( + moment2_out[row_id], moment2_max[row_id] + ) + param_out[row_id] = param[row_id] - lr_t * ( + moment1_out[row_id] + / (np.sqrt(moment2_max_out[row_id]) + epsilon) + ) + else: + param_out[row_id] = param[row_id] - lr_t * ( + moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) + ) if lazy_mode: for idx, row_id in enumerate(rows): @@ -350,7 +424,7 @@ def update_row(row_id, update_value): update_value = np_grad[rows.index(row_id)] update_row(row_id, update_value) - return param_out, moment1_out, moment2_out + return 
param_out, moment1_out, moment2_out, moment2_max_out class TestSparseAdamOp(unittest.TestCase): @@ -370,6 +444,7 @@ def setup(self, scope, place, lazy_mode): "Param": np.full((height, row_numel), 5.0).astype("float32"), "Moment1": np.full((height, row_numel), 5.0).astype("float32"), "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2Max": np.zeros((height, row_numel)).astype("float32"), 'Beta1Pow': beta1_pow, 'Beta2Pow': beta2_pow, "LearningRate": np.full((1), 2.0).astype("float32"), @@ -380,6 +455,7 @@ def setup(self, scope, place, lazy_mode): 'beta1': beta1, 'beta2': beta2, 'min_row_size_to_use_multithread': 2, + 'amsgrad': False, } grad_selected_rows = scope.var('Grad').get_selected_rows() @@ -394,7 +470,7 @@ def setup(self, scope, place, lazy_mode): self.sparse_inputs = ["Grad"] - param_out, mom1, mom2 = adam_step_sparse( + param_out, mom1, mom2, mom2_max = adam_step_sparse( self.dense_inputs, self.attrs, height, @@ -407,6 +483,7 @@ def setup(self, scope, place, lazy_mode): "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2, + "Moment2MaxOut": mom2_max, 'Beta1PowOut': beta1_pow * beta1, 'Beta2PowOut': beta2_pow * beta2, } @@ -469,6 +546,8 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") + beta1 = 0.85 beta2 = 0.95 @@ -482,6 +561,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -489,13 +569,16 @@ def setUp(self): "Beta2Tensor": np.array([beta2]).astype("float32"), } - attributes = {'epsilon': epsilon} + attributes = {'epsilon': epsilon, 'amsgrad': False} - param_out, moment1_out, moment2_out = adam_step(self.inputs, attributes) + param_out, moment1_out, moment2_out, moment2_max_out = adam_step( + self.inputs, attributes + ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -516,6 +599,8 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") + beta1 = 0.85 beta2 = 0.95 @@ -529,6 +614,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -537,13 +623,16 @@ def setUp(self): "EpsilonTensor": np.array([epsilon]).astype("float32"), } - attributes = {'epsilon': epsilon} + attributes = {'epsilon': epsilon, 'amsgrad': False} - param_out, moment1_out, moment2_out = adam_step(self.inputs, attributes) + param_out, moment1_out, moment2_out, moment2_max_out = adam_step( + self.inputs, attributes + ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -564,6 +653,8 @@ 
def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") + beta1 = 0.85 beta2 = 0.95 @@ -577,6 +668,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -585,9 +677,11 @@ def setUp(self): "EpsilonTensor": np.array([epsilon]).astype("float32"), } - attributes = {'epsilon': epsilon} + attributes = {'epsilon': epsilon, 'amsgrad': False} - param_out, moment1_out, moment2_out = adam_step(self.inputs, attributes) + param_out, moment1_out, moment2_out, moment2_max_out = adam_step( + self.inputs, attributes + ) self.attrs = {'use_global_beta_pow': True} @@ -595,6 +689,7 @@ def setUp(self): self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([]), 'Beta2PowOut': np.array([]), @@ -615,6 +710,8 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") + beta1 = 0.85 beta2 = 0.95 @@ -628,6 +725,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -637,14 +735,15 @@ def setUp(self): "SkipUpdate": np.array([True]).astype("bool"), } - attributes = {'epsilon': epsilon} + attributes = {'epsilon': epsilon, 'amsgrad': False} - self.attrs = {'use_global_beta_pow': True} + self.attrs = {'use_global_beta_pow': True, 'amsgrad': False} # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty. 
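# Editor's note, not part of the patch: the adam_step/adamw_step helpers above
# are the NumPy references these tests check against. The sketch below is a
# minimal, self-contained illustration of the AMSGrad variant they model; all
# names and constants here are illustrative, not taken from the test data.

import numpy as np

def amsgrad_reference_step(param, grad, m1, m2, m2_max, lr, beta1, beta2, eps, t):
    # first/second moment updates are the plain Adam recurrences
    m1 = beta1 * m1 + (1 - beta1) * grad
    m2 = beta2 * m2 + (1 - beta2) * np.square(grad)
    # AMSGrad additionally keeps the running maximum of the second moment ...
    m2_max = np.maximum(m2, m2_max)
    # ... and uses that maximum (instead of m2) in the bias-corrected update
    lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
    param = param - lr_t * m1 / (np.sqrt(m2_max) + eps)
    return param, m1, m2, m2_max

p = np.full(4, 0.5)
g = np.full(4, 0.1)
m1, m2, m2_max = np.zeros(4), np.zeros(4), np.zeros(4)
for t in range(1, 4):
    p, m1, m2, m2_max = amsgrad_reference_step(
        p, g, m1, m2, m2_max, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, t=t
    )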
self.outputs = { 'Moment1Out': moment1, 'Moment2Out': moment2, + 'Moment2MaxOut': moment2_max, 'ParamOut': param, 'Beta1PowOut': np.array([]), 'Beta2PowOut': np.array([]), @@ -1056,6 +1155,8 @@ def test_pir_main(self): self._check_with_place_amp(place, use_amp) +# TODO(megemini): AMSGrad + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py index d59cb53ab2a391..7d1512a6ef0e67 100644 --- a/test/legacy_test/test_adamw_op.py +++ b/test/legacy_test/test_adamw_op.py @@ -30,6 +30,7 @@ def adamw_step(inputs, attributes): grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -54,11 +55,20 @@ def adamw_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] + amsgrad = attributes['amsgrad'] + moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) - denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon + + if amsgrad: + moment2_max_out = np.maximum(moment2_out, moment2_max) + denom = (np.sqrt(moment2_max_out) / np.sqrt(1.0 - beta2_pow)) + epsilon + else: + moment2_max_out = np.zeros_like(moment2_out) + denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon + param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow)))) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max_out def adamw_wrapper( @@ -67,6 +77,7 @@ def adamw_wrapper( lr, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_weight=None, @@ -78,13 +89,15 @@ def adamw_wrapper( weight_decay=0.01, with_decay=True, lazy_mode=False, + amsgrad=False, ): - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( param, grad, lr, moment1, moment2, + moment2_max, beta1_pow, beta2_pow, master_weight, @@ -99,6 +112,7 @@ def adamw_wrapper( 1000, False, False, + amsgrad, ) @@ -113,6 +127,7 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") # The second moment is positive moment2 = np.random.random((102, 105)).astype("float32") + moment2_max = np.zeros((102, 105)).astype("float32") learning_rate = 0.004 beta1 = 0.78 @@ -126,6 +141,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -137,15 +153,17 @@ def setUp(self): 'beta2': beta2, "coeff": 0.5, "with_decay": True, + "amsgrad": False, } - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( self.inputs, self.attrs ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -169,6 +187,7 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, (2, 2)).astype("float32") # The second moment is positive moment2 = np.random.random((2, 2)).astype("float32") + moment2_max = np.zeros((2, 2)).astype("float32") learning_rate = 0.004 beta1 = 0.78 @@ -182,6 +201,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': 
moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -194,15 +214,17 @@ def setUp(self): "lr_ratio": 0.1, "coeff": 0.5, "with_decay": True, + "amsgrad": False, } - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( self.inputs, self.attrs ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -448,6 +470,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( main_grad = grad.astype(paddle.float32) moment1 = paddle.randn(shape).astype(paddle.float32) moment2 = paddle.randn(shape).astype(paddle.float32).abs() + moment2_max = paddle.zeros(shape).astype(paddle.float32) lr = paddle.zeros([1]).astype(paddle.float32) lr[0] = lr_rate beta1_pow_acc = paddle.ones([1]).astype(paddle.float32) @@ -460,14 +483,16 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( ref_beta2_pow_acc = beta2_pow_acc.astype(paddle.float32) ref_moment_1 = moment1.astype(paddle.float32) ref_moment_2 = moment2.astype(paddle.float32) + ref_moment_2_max = moment2.astype(paddle.float32) # reference code - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( ref_param, main_grad, lr, ref_moment_1, ref_moment_2, + ref_moment_2_max, ref_beta1_pow_acc, ref_beta2_pow_acc, master_weight, @@ -482,15 +507,17 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, False, False, + False, ) if use_main_grad: - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( param, main_grad, lr, moment1, moment2, + moment2_max, beta1_pow_acc, beta2_pow_acc, master_weight, @@ -505,6 +532,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, find_master, False, + False, ) np.testing.assert_allclose( param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2 @@ -513,12 +541,13 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( master_weight.numpy(), ref_param.numpy(), rtol=1e-6 ) else: - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( param, grad, lr, moment1, moment2, + moment2_max, beta1_pow_acc, beta2_pow_acc, master_weight, @@ -533,6 +562,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, find_master, False, + False, ) np.testing.assert_allclose( param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2 @@ -754,16 +784,20 @@ def test_adamw_op_dygraph(self): fc1_w = np.array(linear1.weight) fc1_w_mon1 = np.zeros_like(fc1_w) fc1_w_mon2 = np.zeros_like(fc1_w) + fc1_w_mon2_max = np.zeros_like(fc1_w) fc1_b = np.array(linear1.bias) fc1_b_mon1 = np.zeros_like(fc1_b) fc1_b_mon2 = np.zeros_like(fc1_b) + fc1_b_mon2_max = np.zeros_like(fc1_b) fc2_w = np.array(linear2.weight) fc2_w_mon1 = np.zeros_like(fc2_w) fc2_w_mon2 = np.zeros_like(fc2_w) + fc2_w_mon2_max = np.zeros_like(fc2_w) fc2_b = np.array(linear2.bias) fc2_b_mon1 = np.zeros_like(fc2_b) fc2_b_mon2 = np.zeros_like(fc2_b) + fc2_b_mon2_max = np.zeros_like(fc2_b) simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2) learning_rate = 0.001 @@ -784,12 +818,15 @@ def test_adamw_op_dygraph(self): lr_ratio=simple_lr_fun, ) - def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): + def get_numpy_output( + param, grad, moment1, moment2, 
moment2_max, lr_ratio, t + ): np_inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -802,11 +839,12 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, + "amsgrad": False, } - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( np_inputs, np_attrs ) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max_out for i in range(5): a = paddle.to_tensor( @@ -817,35 +855,39 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): out = paddle.mean(out) out.backward() - fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output( + fc1_w, fc1_w_mon1, fc1_w_mon2, fc1_w_mon2_max = get_numpy_output( fc1_w, np.array(linear1.weight.grad), fc1_w_mon1, fc1_w_mon2, + fc1_w_mon2_max, simple_lr_fun(linear1.weight), i + 1, ) - fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output( + fc1_b, fc1_b_mon1, fc1_b_mon2, fc1_b_mon2_max = get_numpy_output( fc1_b, np.array(linear1.bias.grad), fc1_b_mon1, fc1_b_mon2, + fc1_b_mon2_max, simple_lr_fun(linear1.bias), i + 1, ) - fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output( + fc2_w, fc2_w_mon1, fc2_w_mon2, fc2_w_mon2_max = get_numpy_output( fc2_w, np.array(linear2.weight.grad), fc2_w_mon1, fc2_w_mon2, + fc2_w_mon2_max, simple_lr_fun(linear2.weight), i + 1, ) - fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output( + fc2_b, fc2_b_mon1, fc2_b_mon2, fc2_b_mon2_max = get_numpy_output( fc2_b, np.array(linear2.bias.grad), fc2_b_mon1, fc2_b_mon2, + fc2_b_mon2_max, simple_lr_fun(linear2.bias), i + 1, ) @@ -910,16 +952,28 @@ def test_adamw_op(self): fc1_w_mon2 = np.zeros(linear1.weight.shape).astype( "float32" ) + fc1_w_mon2_max = np.zeros(linear1.weight.shape).astype( + "float32" + ) fc1_b_mon1 = np.zeros(linear1.bias.shape).astype("float32") fc1_b_mon2 = np.zeros(linear1.bias.shape).astype("float32") + fc1_b_mon2_max = np.zeros(linear1.bias.shape).astype( + "float32" + ) fc2_w_mon1 = np.zeros(linear2.weight.shape).astype( "float32" ) fc2_w_mon2 = np.zeros(linear2.weight.shape).astype( "float32" ) + fc2_w_mon2_max = np.zeros(linear2.weight.shape).astype( + "float32" + ) fc2_b_mon1 = np.zeros(linear2.bias.shape).astype("float32") fc2_b_mon2 = np.zeros(linear2.bias.shape).astype("float32") + fc2_b_mon2_max = np.zeros(linear2.bias.shape).astype( + "float32" + ) cost = paddle.nn.functional.square_error_cost( input=out, label=y @@ -940,12 +994,15 @@ def test_adamw_op(self): ) opt.minimize(avg_cost) - def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): + def get_numpy_output( + param, grad, moment1, moment2, moment2_max, lr_ratio, t + ): np_inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -958,11 +1015,12 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, + "amsgrad": False, } - param_out, moment1_out, moment2_out = adamw_step( - np_inputs, np_attrs + param_out, moment1_out, moment2_out, moment2_max_out = ( + adamw_step(np_inputs, np_attrs) ) - return 
param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max_out fetch_list1 = [ "linear_0.w_0", @@ -1009,37 +1067,49 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): fc2_b = param[3] fc2_b_grad = params_and_gras[7] - fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output( - fc1_w, - fc1_w_grad, - fc1_w_mon1, - fc1_w_mon2, - simple_lr_fun(linear1.weight), - i + 1, + fc1_w, fc1_w_mon1, fc1_w_mon2, fc1_w_mon2_max = ( + get_numpy_output( + fc1_w, + fc1_w_grad, + fc1_w_mon1, + fc1_w_mon2, + fc1_w_mon2_max, + simple_lr_fun(linear1.weight), + i + 1, + ) ) - fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output( - fc1_b, - fc1_b_grad, - fc1_b_mon1, - fc1_b_mon2, - simple_lr_fun(linear1.bias), - i + 1, + fc1_b, fc1_b_mon1, fc1_b_mon2, fc1_b_mon2_max = ( + get_numpy_output( + fc1_b, + fc1_b_grad, + fc1_b_mon1, + fc1_b_mon2, + fc1_b_mon2_max, + simple_lr_fun(linear1.bias), + i + 1, + ) ) - fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output( - fc2_w, - fc2_w_grad, - fc2_w_mon1, - fc2_w_mon2, - simple_lr_fun(linear2.weight), - i + 1, + fc2_w, fc2_w_mon1, fc2_w_mon2, fc2_w_mon2_max = ( + get_numpy_output( + fc2_w, + fc2_w_grad, + fc2_w_mon1, + fc2_w_mon2, + fc2_w_mon2_max, + simple_lr_fun(linear2.weight), + i + 1, + ) ) - fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output( - fc2_b, - fc2_b_grad, - fc2_b_mon1, - fc2_b_mon2, - simple_lr_fun(linear2.bias), - i + 1, + fc2_b, fc2_b_mon1, fc2_b_mon2, fc2_b_mon2_max = ( + get_numpy_output( + fc2_b, + fc2_b_grad, + fc2_b_mon1, + fc2_b_mon2, + fc2_b_mon2_max, + simple_lr_fun(linear2.bias), + i + 1, + ) ) np.testing.assert_allclose(params_and_gras[0], fc1_w, rtol=1e-6) @@ -1101,16 +1171,28 @@ def test_adamw_op_with_pir(self): fc1_w_mon2 = np.zeros(linear1.weight.shape).astype( "float32" ) + fc1_w_mon2_max = np.zeros(linear1.weight.shape).astype( + "float32" + ) fc1_b_mon1 = np.zeros(linear1.bias.shape).astype("float32") fc1_b_mon2 = np.zeros(linear1.bias.shape).astype("float32") + fc1_b_mon2_max = np.zeros(linear1.bias.shape).astype( + "float32" + ) fc2_w_mon1 = np.zeros(linear2.weight.shape).astype( "float32" ) fc2_w_mon2 = np.zeros(linear2.weight.shape).astype( "float32" ) + fc2_w_mon2_max = np.zeros(linear2.weight.shape).astype( + "float32" + ) fc2_b_mon1 = np.zeros(linear2.bias.shape).astype("float32") fc2_b_mon2 = np.zeros(linear2.bias.shape).astype("float32") + fc2_b_mon2_max = np.zeros(linear2.bias.shape).astype( + "float32" + ) cost = paddle.nn.functional.square_error_cost( input=out, label=y @@ -1131,12 +1213,15 @@ def test_adamw_op_with_pir(self): ) _, params_grads = opt.minimize(avg_cost) - def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): + def get_numpy_output( + param, grad, moment1, moment2, moment2_max, lr_ratio, t + ): np_inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -1149,11 +1234,12 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, + "amsgrad": False, } - param_out, moment1_out, moment2_out = adamw_step( - np_inputs, np_attrs + param_out, moment1_out, moment2_out, moment2_out_max = ( + adamw_step(np_inputs, np_attrs) ) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_out_max exe = base.Executor(place) exe.run(train_startup) @@ 
-1242,37 +1328,49 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): fc2_b = param[3] fc2_b_grad = params_and_gras[1] - fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output( - fc1_w, - fc1_w_grad, - fc1_w_mon1, - fc1_w_mon2, - simple_lr_fun(linear1.weight), - i + 1, + fc1_w, fc1_w_mon1, fc1_w_mon2, fc1_w_mon2_max = ( + get_numpy_output( + fc1_w, + fc1_w_grad, + fc1_w_mon1, + fc1_w_mon2, + fc1_w_mon2_max, + simple_lr_fun(linear1.weight), + i + 1, + ) ) - fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output( - fc1_b, - fc1_b_grad, - fc1_b_mon1, - fc1_b_mon2, - simple_lr_fun(linear1.bias), - i + 1, + fc1_b, fc1_b_mon1, fc1_b_mon2, fc1_b_mon2_max = ( + get_numpy_output( + fc1_b, + fc1_b_grad, + fc1_b_mon1, + fc1_b_mon2, + fc1_b_mon2_max, + simple_lr_fun(linear1.bias), + i + 1, + ) ) - fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output( - fc2_w, - fc2_w_grad, - fc2_w_mon1, - fc2_w_mon2, - simple_lr_fun(linear2.weight), - i + 1, + fc2_w, fc2_w_mon1, fc2_w_mon2, fc2_w_mon2_max = ( + get_numpy_output( + fc2_w, + fc2_w_grad, + fc2_w_mon1, + fc2_w_mon2, + fc2_w_mon2_max, + simple_lr_fun(linear2.weight), + i + 1, + ) ) - fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output( - fc2_b, - fc2_b_grad, - fc2_b_mon1, - fc2_b_mon2, - simple_lr_fun(linear2.bias), - i + 1, + fc2_b, fc2_b_mon1, fc2_b_mon2, fc2_b_mon2_max = ( + get_numpy_output( + fc2_b, + fc2_b_grad, + fc2_b_mon1, + fc2_b_mon2, + fc2_b_mon2_max, + simple_lr_fun(linear2.bias), + i + 1, + ) ) np.testing.assert_allclose(params_and_gras[6], fc1_w, rtol=1e-6) @@ -1283,5 +1381,7 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): paddle.disable_static() +# TODO(megemini): AMSGrad + if __name__ == "__main__": unittest.main() diff --git a/test/xpu/test_adam_op_xpu.py b/test/xpu/test_adam_op_xpu.py index 54f8d36a187a4a..dc47654f7dcb96 100644 --- a/test/xpu/test_adam_op_xpu.py +++ b/test/xpu/test_adam_op_xpu.py @@ -246,13 +246,14 @@ def adam_step(inputs, attributes): Simulate one step of the adam optimizer :param inputs: dict of inputs :param attributes: dict of attributes - :return tuple: tuple of output param, moment1, moment2, + :return tuple: tuple of output param, moment1, moment2, moment2_max beta1 power accumulator and beta2 power accumulator ''' param = inputs['Param'] grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -268,13 +269,27 @@ def adam_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] + amsgrad = attributes['amsgrad'] + moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) - param_out = param - lr_t * ( - moment1_out / (np.sqrt(moment2_out) + epsilon * np.sqrt(1 - beta2_pow)) - ) - return param_out, moment1_out, moment2_out + + if amsgrad: + moment2_max_out = np.maximum(moment2_out, moment2_max) + param_out = param - lr_t * ( + moment1_out + / (np.sqrt(moment2_max_out) + epsilon * np.sqrt(1 - beta2_pow)) + ) + else: + moment2_max_out = np.zeros_like(moment2_out) + param_out = param - lr_t * ( + moment1_out + / (np.sqrt(moment2_out) + epsilon * np.sqrt(1 - beta2_pow)) + ) + + return param_out, moment1_out, moment2_out, moment2_max_out def adam_step_sparse( @@ -291,6 +306,7 @@ def adam_step_sparse( # grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = 
inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -298,9 +314,11 @@ def adam_step_sparse( beta1 = attributes['beta1'] beta2 = attributes['beta2'] epsilon = attributes['epsilon'] + amsgrad = attributes['amsgrad'] moment1_out = np.zeros(shape=[height, row_numel]) moment2_out = np.zeros(shape=[height, row_numel]) + moment2_max_out = np.zeros(shape=[height, row_numel]) param_out = np.zeros(shape=[height, row_numel]) def update_row(row_id, update_value): @@ -311,9 +329,19 @@ def update_row(row_id, update_value): update_value ) lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) - param_out[row_id] = param[row_id] - lr_t * ( - moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) - ) + + if amsgrad: + moment2_max_out[row_id] = np.maximum( + moment2_out[row_id], moment2_max[row_id] + ) + param_out[row_id] = param[row_id] - lr_t * ( + moment1_out[row_id] + / (np.sqrt(moment2_max_out[row_id]) + epsilon) + ) + else: + param_out[row_id] = param[row_id] - lr_t * ( + moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) + ) if lazy_mode: for idx, row_id in enumerate(rows): @@ -325,7 +353,7 @@ def update_row(row_id, update_value): update_value = np_grad[rows.index(row_id)] update_row(row_id, update_value) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max_out class TestSparseAdamOp(unittest.TestCase): @@ -345,6 +373,7 @@ def setup(self, scope, place, lazy_mode): "Param": np.full((height, row_numel), 5.0).astype("float32"), "Moment1": np.full((height, row_numel), 5.0).astype("float32"), "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2Max": np.zeros((height, row_numel)).astype("float32"), 'Beta1Pow': beta1_pow, 'Beta2Pow': beta2_pow, "LearningRate": np.full((1), 2.0).astype("float32"), @@ -355,6 +384,7 @@ def setup(self, scope, place, lazy_mode): 'beta1': beta1, 'beta2': beta2, 'min_row_size_to_use_multithread': 2, + 'amsgrad': False, # Currently, xpu NOT support amsgrad. } grad_selected_rows = scope.var('Grad').get_selected_rows() @@ -369,7 +399,7 @@ def setup(self, scope, place, lazy_mode): self.sparse_inputs = ["Grad"] - param_out, mom1, mom2 = adam_step_sparse( + param_out, mom1, mom2, mom2_max = adam_step_sparse( self.dense_inputs, self.attrs, height, @@ -382,6 +412,7 @@ def setup(self, scope, place, lazy_mode): "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2, + "Moment2MaxOut": mom2_max, 'Beta1PowOut': beta1_pow * beta1, 'Beta2PowOut': beta2_pow * beta2, } @@ -442,6 +473,7 @@ def setup(self, scope, place, lazy_mode): "Param": np.full((height, row_numel), 5.0).astype("float16"), "Moment1": np.full((height, row_numel), 5.0).astype("float16"), "Moment2": np.full((height, row_numel), 5.0).astype("float16"), + "Moment2Max": np.zeros((height, row_numel)).astype("float16"), 'Beta1Pow': beta1_pow, 'Beta2Pow': beta2_pow, "LearningRate": np.full((1), 2.0).astype("float16"), @@ -452,6 +484,7 @@ def setup(self, scope, place, lazy_mode): 'beta1': beta1, 'beta2': beta2, 'min_row_size_to_use_multithread': 2, + 'amsgrad': False, # Currently, xpu NOT support amsgrad. 
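The adam_step and adam_step_sparse helpers above are the ground truth these XPU tests compare against. As a minimal standalone restatement (function and variable names here are illustrative and not part of the patch), the dense Adam step with the amsgrad switch looks like this; note that the max accumulator is still produced when amsgrad is off so the four-tuple output shape never changes:

import numpy as np

def adam_amsgrad_reference(param, grad, m1, m2, m2_max,
                           lr, beta1, beta2, epsilon,
                           beta1_pow, beta2_pow, amsgrad):
    # Moment estimates, identical with or without AMSGrad.
    m1_out = beta1 * m1 + (1 - beta1) * grad
    m2_out = beta2 * m2 + (1 - beta2) * np.square(grad)
    # Bias-corrected step size, the same lr_t used by adam_step above.
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    if amsgrad:
        # AMSGrad divides by the running maximum of the second moment.
        m2_max_out = np.maximum(m2_out, m2_max)
        denom = np.sqrt(m2_max_out) + epsilon * np.sqrt(1 - beta2_pow)
    else:
        # Plain Adam: the max slot is still returned so outputs stay aligned.
        m2_max_out = np.zeros_like(m2_out)
        denom = np.sqrt(m2_out) + epsilon * np.sqrt(1 - beta2_pow)
    param_out = param - lr_t * (m1_out / denom)
    return param_out, m1_out, m2_out, m2_max_out

Note that the device kernels in the later patches of this series carry the previous moment2_max value through unchanged when amsgrad is false, whereas this NumPy reference returns zeros; the two agree in the tests only because Moment2Max is initialized to zeros.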
} grad_selected_rows = scope.var('Grad').get_selected_rows() @@ -466,7 +499,7 @@ def setup(self, scope, place, lazy_mode): self.sparse_inputs = ["Grad"] - param_out, mom1, mom2 = adam_step_sparse( + param_out, mom1, mom2, mom2_max = adam_step_sparse( self.dense_inputs, self.attrs, height, @@ -479,6 +512,7 @@ def setup(self, scope, place, lazy_mode): "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2, + "Moment2MaxOut": mom2_max, 'Beta1PowOut': beta1_pow * beta1, 'Beta2PowOut': beta2_pow * beta2, } diff --git a/test/xpu/test_adamw_op_xpu.py b/test/xpu/test_adamw_op_xpu.py index ae7a0a5434cb80..c4723c136f3e27 100644 --- a/test/xpu/test_adamw_op_xpu.py +++ b/test/xpu/test_adamw_op_xpu.py @@ -36,6 +36,7 @@ def adamw_step(inputs, attributes): grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] + moment2_max = inputs['Moment2Max'] lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -60,11 +61,20 @@ def adamw_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] + amsgrad = attributes['amsgrad'] + moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) - denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon + + if amsgrad: + moment2_max_out = np.maximum(moment2_out, moment2_max) + denom = (np.sqrt(moment2_max_out) / np.sqrt(1.0 - beta2_pow)) + epsilon + else: + moment2_max_out = np.zeros_like(moment2_out) + denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon + param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow)))) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max_out def simple_lr_setting(param, decay_rate, n_layers): @@ -94,6 +104,7 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, self.shape).astype("float32") # The second moment is positive moment2 = np.random.random(self.shape).astype("float32") + moment2_max = np.zeros(self.shape).astype("float32") learning_rate = 0.004 beta1 = 0.78 @@ -109,6 +120,7 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -120,15 +132,17 @@ def setUp(self): 'beta2': beta2, "coeff": 0.5, "with_decay": True, + "amsgrad": False, # Currently, xpu NOT support amsgrad. 
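The adamw_step helper above uses a slightly different denominator and folds in decoupled weight decay. A minimal sketch of that variant follows (illustrative names; the decay form p *= 1 - lr * coeff is an assumption taken from the AdamW CUDA kernels later in this series, since the with_decay branch is not visible in this hunk):

import numpy as np

def adamw_amsgrad_reference(param, grad, m1, m2, m2_max,
                            lr, beta1, beta2, epsilon,
                            beta1_pow, beta2_pow, coeff,
                            with_decay, amsgrad):
    if with_decay:
        # Decoupled weight decay applied to the parameter before the update.
        param = param * (1.0 - lr * coeff)
    m1_out = beta1 * m1 + (1 - beta1) * grad
    m2_out = beta2 * m2 + (1 - beta2) * np.square(grad)
    if amsgrad:
        m2_max_out = np.maximum(m2_out, m2_max)
        denom = np.sqrt(m2_max_out) / np.sqrt(1.0 - beta2_pow) + epsilon
    else:
        m2_max_out = np.zeros_like(m2_out)
        denom = np.sqrt(m2_out) / np.sqrt(1.0 - beta2_pow) + epsilon
    # Same update form as adamw_step: bias correction folded into the step size.
    param_out = param + (m1_out / denom) * (-(lr / (1.0 - beta1_pow)))
    return param_out, m1_out, m2_out, m2_max_out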
} - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( self.inputs, self.attrs ) self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, + 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -357,16 +371,20 @@ def test_adamw_op_dygraph(self): fc1_w = np.array(linear1.weight) fc1_w_mon1 = np.zeros_like(fc1_w) fc1_w_mon2 = np.zeros_like(fc1_w) + fc1_w_mon2_max = np.zeros_like(fc1_w) fc1_b = np.array(linear1.bias) fc1_b_mon1 = np.zeros_like(fc1_b) fc1_b_mon2 = np.zeros_like(fc1_b) + fc1_b_mon2_max = np.zeros_like(fc1_b) fc2_w = np.array(linear2.weight) fc2_w_mon1 = np.zeros_like(fc2_w) fc2_w_mon2 = np.zeros_like(fc2_w) + fc2_w_mon2_max = np.zeros_like(fc2_w) fc2_b = np.array(linear2.bias) fc2_b_mon1 = np.zeros_like(fc2_b) fc2_b_mon2 = np.zeros_like(fc2_b) + fc2_b_mon2_max = np.zeros_like(fc2_b) simple_lr_fun = partial( simple_lr_setting, decay_rate=0.8, n_layers=2 @@ -389,12 +407,15 @@ def test_adamw_op_dygraph(self): lr_ratio=simple_lr_fun, ) - def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): + def get_numpy_output( + param, grad, moment1, moment2, moment2_max, lr_ratio, t + ): np_inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -407,11 +428,12 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, + "amsgrad": False, # Currently, xpu NOT support amsgrad. 
} - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_max = adamw_step( np_inputs, np_attrs ) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max for i in range(5): a = paddle.to_tensor( @@ -422,37 +444,49 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): out = paddle.mean(out) out.backward() - fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output( - fc1_w, - np.array(linear1.weight.grad), - fc1_w_mon1, - fc1_w_mon2, - simple_lr_fun(linear1.weight), - i + 1, + fc1_w, fc1_w_mon1, fc1_w_mon2, fc1_w_mon2_max = ( + get_numpy_output( + fc1_w, + np.array(linear1.weight.grad), + fc1_w_mon1, + fc1_w_mon2, + fc1_w_mon2_max, + simple_lr_fun(linear1.weight), + i + 1, + ) ) - fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output( - fc1_b, - np.array(linear1.bias.grad), - fc1_b_mon1, - fc1_b_mon2, - simple_lr_fun(linear1.bias), - i + 1, + fc1_b, fc1_b_mon1, fc1_b_mon2, fc1_b_mon2_max = ( + get_numpy_output( + fc1_b, + np.array(linear1.bias.grad), + fc1_b_mon1, + fc1_b_mon2, + fc1_b_mon2_max, + simple_lr_fun(linear1.bias), + i + 1, + ) ) - fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output( - fc2_w, - np.array(linear2.weight.grad), - fc2_w_mon1, - fc2_w_mon2, - simple_lr_fun(linear2.weight), - i + 1, + fc2_w, fc2_w_mon1, fc2_w_mon2, fc2_w_mon2_max = ( + get_numpy_output( + fc2_w, + np.array(linear2.weight.grad), + fc2_w_mon1, + fc2_w_mon2, + fc2_w_mon2_max, + simple_lr_fun(linear2.weight), + i + 1, + ) ) - fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output( - fc2_b, - np.array(linear2.bias.grad), - fc2_b_mon1, - fc2_b_mon2, - simple_lr_fun(linear2.bias), - i + 1, + fc2_b, fc2_b_mon1, fc2_b_mon2, fc2_b_mon2_max = ( + get_numpy_output( + fc2_b, + np.array(linear2.bias.grad), + fc2_b_mon1, + fc2_b_mon2, + fc2_b_mon2_max, + simple_lr_fun(linear2.bias), + i + 1, + ) ) opt.step() @@ -523,16 +557,28 @@ def test_adamw_op(self): fc1_w_mon2 = np.zeros(linear1.weight.shape).astype( "float32" ) + fc1_w_mon2_max = np.zeros(linear1.weight.shape).astype( + "float32" + ) fc1_b_mon1 = np.zeros(linear1.bias.shape).astype("float32") fc1_b_mon2 = np.zeros(linear1.bias.shape).astype("float32") + fc1_b_mon2_max = np.zeros(linear1.bias.shape).astype( + "float32" + ) fc2_w_mon1 = np.zeros(linear2.weight.shape).astype( "float32" ) fc2_w_mon2 = np.zeros(linear2.weight.shape).astype( "float32" ) + fc2_w_mon2_max = np.zeros(linear2.weight.shape).astype( + "float32" + ) fc2_b_mon1 = np.zeros(linear2.bias.shape).astype("float32") fc2_b_mon2 = np.zeros(linear2.bias.shape).astype("float32") + fc2_b_mon2_max = np.zeros(linear2.bias.shape).astype( + "float32" + ) cost = paddle.nn.functional.square_error_cost( input=out, label=y @@ -553,12 +599,15 @@ def test_adamw_op(self): ) _, params_grads = opt.minimize(avg_cost) - def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): + def get_numpy_output( + param, grad, moment1, moment2, moment2_max, lr_ratio, t + ): np_inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -571,11 +620,12 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, + "amsgrad": False, # Currently, xpu NOT support amsgrad. 
} - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_max = adamw_step( np_inputs, np_attrs ) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_max if paddle.framework.in_pir_mode(): fetch_list1 = [ @@ -640,37 +690,49 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): fc2_b = param[3] fc2_b_grad = params_and_gras[7] - fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output( - fc1_w, - fc1_w_grad, - fc1_w_mon1, - fc1_w_mon2, - simple_lr_fun(linear1.weight), - i + 1, + fc1_w, fc1_w_mon1, fc1_w_mon2, fc1_w_mon2_max = ( + get_numpy_output( + fc1_w, + fc1_w_grad, + fc1_w_mon1, + fc1_w_mon2, + fc1_w_mon2_max, + simple_lr_fun(linear1.weight), + i + 1, + ) ) - fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output( - fc1_b, - fc1_b_grad, - fc1_b_mon1, - fc1_b_mon2, - simple_lr_fun(linear1.bias), - i + 1, + fc1_b, fc1_b_mon1, fc1_b_mon2, fc1_b_mon2_max = ( + get_numpy_output( + fc1_b, + fc1_b_grad, + fc1_b_mon1, + fc1_b_mon2, + fc1_b_mon2_max, + simple_lr_fun(linear1.bias), + i + 1, + ) ) - fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output( - fc2_w, - fc2_w_grad, - fc2_w_mon1, - fc2_w_mon2, - simple_lr_fun(linear2.weight), - i + 1, + fc2_w, fc2_w_mon1, fc2_w_mon2, fc2_w_mon2_max = ( + get_numpy_output( + fc2_w, + fc2_w_grad, + fc2_w_mon1, + fc2_w_mon2, + fc2_w_mon2_max, + simple_lr_fun(linear2.weight), + i + 1, + ) ) - fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output( - fc2_b, - fc2_b_grad, - fc2_b_mon1, - fc2_b_mon2, - simple_lr_fun(linear2.bias), - i + 1, + fc2_b, fc2_b_mon1, fc2_b_mon2, fc2_b_mon2_max = ( + get_numpy_output( + fc2_b, + fc2_b_grad, + fc2_b_mon1, + fc2_b_mon2, + fc2_b_mon2_max, + simple_lr_fun(linear2.bias), + i + 1, + ) ) np.testing.assert_allclose( @@ -718,6 +780,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( main_grad = grad.astype(paddle.float32) moment1 = paddle.randn(shape).astype(paddle.float32) moment2 = paddle.randn(shape).astype(paddle.float32).abs() + moment2_max = paddle.zeros(shape).astype(paddle.float32) lr = paddle.zeros([1]).astype(paddle.float32) lr[0] = lr_rate beta1_pow_acc = paddle.ones([1]).astype(paddle.float32) @@ -730,14 +793,16 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( ref_beta2_pow_acc = beta2_pow_acc.astype(paddle.float32) ref_moment_1 = moment1.astype(paddle.float32) ref_moment_2 = moment2.astype(paddle.float32) + ref_moment_2_max = moment2_max.astype(paddle.float32) # reference code - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( ref_param, main_grad, lr, ref_moment_1, ref_moment_2, + ref_moment_2_max, ref_beta1_pow_acc, ref_beta2_pow_acc, master_weight, @@ -752,15 +817,17 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, False, False, + False, # Currently, xpu NOT support amsgrad. ) if use_main_grad: - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( param, main_grad, lr, moment1, moment2, + moment2_max, beta1_pow_acc, beta2_pow_acc, master_weight, @@ -775,6 +842,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, find_master, False, + False, # Currently, xpu NOT support amsgrad. 
) np.testing.assert_allclose( param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2 @@ -783,12 +851,13 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( master_weight.numpy(), ref_param.numpy(), rtol=1e-6 ) else: - _, _, _, _, _, _ = paddle._C_ops.adamw_( + _, _, _, _, _, _, _ = paddle._C_ops.adamw_( param, grad, lr, moment1, moment2, + moment2_max, beta1_pow_acc, beta2_pow_acc, master_weight, @@ -803,6 +872,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, find_master, False, + False, # Currently, xpu NOT support amsgrad. ) np.testing.assert_allclose( param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2 diff --git a/test/xpu/test_merged_adam_op_xpu.py b/test/xpu/test_merged_adam_op_xpu.py index 5848db0aabfe66..b8bdda757e6b74 100644 --- a/test/xpu/test_merged_adam_op_xpu.py +++ b/test/xpu/test_merged_adam_op_xpu.py @@ -34,6 +34,7 @@ def run_adam_op( lrs, moment1s, moment2s, + moment2s_max, beta1_pows, beta2_pows, master_params, @@ -43,6 +44,7 @@ def run_adam_op( place, multi_precision=False, use_merged=False, + amsgrad=False, ): assert len(params) == len(grads) assert len(params) == len(lrs) @@ -59,24 +61,27 @@ def run_adam_op( lr_vars = [paddle.to_tensor(l) for l in lrs] moment1_vars = [paddle.to_tensor(m) for m in moment1s] moment2_vars = [paddle.to_tensor(m) for m in moment2s] + moment2_max_vars = [paddle.to_tensor(m) for m in moment2s_max] beta1_pow_vars = [paddle.to_tensor(b) for b in beta1_pows] beta2_pow_vars = [paddle.to_tensor(b) for b in beta2_pows] master_param_vars = [paddle.to_tensor(m_p) for m_p in master_params] if not use_merged: for i in range(len(param_vars)): - _, _, _, _, _, _ = _legacy_C_ops.adam( + _, _, _, _, _, _, _ = _legacy_C_ops.adam( param_vars[i], grad_vars[i], lr_vars[i], moment1_vars[i], moment2_vars[i], + moment2_max_vars[i], beta1_pow_vars[i], beta2_pow_vars[i], master_param_vars[i], param_vars[i], moment1_vars[i], moment2_vars[i], + moment2_max_vars[i], beta1_pow_vars[i], beta2_pow_vars[i], master_param_vars[i], @@ -88,14 +93,16 @@ def run_adam_op( beta2, 'multi_precision', False, + amsgrad, ) else: - _, _, _, _, _, _ = _C_ops.merged_adam_( + _, _, _, _, _, _, _ = _C_ops.merged_adam_( param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars, + moment2_max_vars, beta1_pow_vars, beta2_pow_vars, master_param_vars, @@ -104,12 +111,14 @@ def run_adam_op( epsilon, False, False, + amsgrad, ) outputs = { 'ParamOut': param_vars, 'Moment1Out': moment1_vars, 'Moment2Out': moment2_vars, + 'Moment2MaxOut': moment2_max_vars, 'Beta1PowOut': beta1_pow_vars, 'Beta2PowOut': beta2_pow_vars, 'MasterParamOut': master_param_vars, @@ -131,6 +140,9 @@ def setUp(self): def gen_rand_data(self, shapes, dtype): return [np.random.random(s).astype(dtype) for s in shapes] + def gen_zero_data(self, shapes, dtype): + return [np.zeros(s).astype(dtype) for s in shapes] + def prepare_data(self, shapes, seed): np.random.seed(seed) mp_dtype = np.float32 @@ -141,6 +153,7 @@ def prepare_data(self, shapes, seed): lrs = [learning_rate.copy() for _ in shapes] moment1s = self.gen_rand_data(shapes, mp_dtype) moment2s = self.gen_rand_data(shapes, mp_dtype) + moment2s_max = self.gen_zero_data(shapes, mp_dtype) beta1_pow = self.gen_rand_data([[1]], mp_dtype) beta2_pow = self.gen_rand_data([[1]], mp_dtype) beta1_pows = [beta1_pow.copy() for _ in shapes] @@ -152,6 +165,7 @@ def prepare_data(self, shapes, seed): lrs, moment1s, moment2s, + moment2s_max, beta1_pows, beta2_pows, master_params, @@ -164,6 +178,7 @@ def check_with_place(self): lrs, moment1s, moment2s, + 
moment2s_max, beta1_pows, beta2_pows, master_params, @@ -176,6 +191,7 @@ def run_op(use_merged, place): lrs=lrs, moment1s=moment1s, moment2s=moment2s, + moment2s_max=moment2s_max, beta1_pows=beta1_pows, beta2_pows=beta2_pows, master_params=master_params, @@ -185,6 +201,7 @@ def run_op(use_merged, place): place=place, multi_precision=False, use_merged=use_merged, + amsgrad=False, # Currently, xpu NOT support amsgrad. ) outs1 = run_op(True, "xpu") From eb5de54ec8c1b2754495e13816c107898f3ba4c6 Mon Sep 17 00:00:00 2001 From: megemini Date: Sat, 7 Sep 2024 15:05:17 +0800 Subject: [PATCH 09/33] [Fix] moment2 max out settting values without amsgrad --- paddle/phi/kernels/funcs/adam_functors.h | 56 +++++++++++-------- paddle/phi/kernels/funcs/jit/refer/refer.h | 33 ++++++----- paddle/phi/kernels/gpu/adam_kernel.cu | 29 ++++++++-- paddle/phi/kernels/gpu/adamw_kernel.cu | 20 +++++-- paddle/phi/kernels/gpu/fused_adam_kernel.cu | 4 +- .../kernels/selected_rows/gpu/adam_kernel.cu | 10 +++- .../kernels/selected_rows/gpu/adamw_kernel.cu | 10 +++- paddle/phi/ops/yaml/op_compat.yaml | 4 +- .../fleet/hybrid_parallel_sharding_model.py | 4 ++ 9 files changed, 112 insertions(+), 58 deletions(-) diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index 3143cf65def218..c3d2f6619baa4a 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -221,6 +221,7 @@ class AdamFunctor { T g = grad_[i]; T mom1 = moment1_[i]; T mom2 = moment2_[i]; + T mom2_max = moment2_max_[i]; T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -232,18 +233,19 @@ class AdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + T mom2_max_; if (amsgrad_) { - T mom2_max = moment2_max_[i]; - T mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out_[i] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); } else { + mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); } // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; + moment2_max_out_[i] = mom2_max_; param_out_[i] = p; } }; @@ -310,6 +312,8 @@ class AdamFunctor { moment1_, static_cast(numel)}; Eigen::Map> mom2{ moment2_, static_cast(numel)}; + Eigen::Map> mom2_max{ + moment2_max_, static_cast(numel)}; Eigen::Map> param{ param_, static_cast(numel)}; @@ -319,6 +323,8 @@ class AdamFunctor { moment1_out_, static_cast(numel)}; Eigen::Map> moment2_out{ moment2_out_, static_cast(numel)}; + Eigen::Map> moment2_max_out{ + moment2_max_out_, static_cast(numel)}; T lr = *lr_; T beta1_pow = *beta1_pow_; @@ -331,15 +337,11 @@ class AdamFunctor { moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g; if (amsgrad_) { - Eigen::Map> mom2_max{ - moment2_max_, static_cast(numel)}; - Eigen::Map> moment2_max_out{ - moment2_max_out_, static_cast(numel)}; - moment2_max_out = moment2_out.cwiseMax(mom2_max); param_out = param - lr * (moment1_out / (moment2_max_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow))); } else { + moment2_max_out = mom2_max; param_out = param - lr * (moment1_out / (moment2_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow))); } @@ -427,6 +429,7 @@ class SparseAdamFunctor { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; + MT mom2_max = moment2_max_[i]; MT lr = *lr_; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; @@ -439,15 +442,13 @@ class SparseAdamFunctor { mom1 = beta1_ * 
mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; + MT mom2_max_; if (amsgrad_) { - MT mom2_max = moment2_max_[i]; - MT mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out_[i] = mom2_max_; - + mom2_max_ = std::max(mom2, mom2_max); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); - } else { + mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } @@ -455,6 +456,7 @@ class SparseAdamFunctor { // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; + moment2_max_out_[i] = mom2_max_; param_out_[i] = static_cast(p); if (master_param_out_) { master_param_out_[i] = p; @@ -545,6 +547,7 @@ class SparseAdamFunctor { // The following code is the same as dense T mom1 = moment1_[i]; T mom2 = moment2_[i]; + T mom2_max = moment2_max_[i]; T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -556,18 +559,19 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + T mom2_max_; if (amsgrad_) { - T mom2_max = moment2_max_[i]; - T mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out_[i] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); } else { + mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); } // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; + moment2_max_out_[i] = mom2_max_; param_out_[i] = p; } @@ -590,24 +594,26 @@ class SparseAdamFunctor { for (int64_t k = 0; k != row_numel_; ++k) { T mom1 = moment1_[i * row_numel_ + k]; T mom2 = moment2_[i * row_numel_ + k]; + T mom2_max = moment2_max_[i * row_numel_ + k]; T p = param_[i * row_numel_ + k]; mom1 = beta1_ * mom1; mom2 = beta2_ * mom2; + T mom2_max_; if (amsgrad_) { - T mom2_max = moment2_max_[i * row_numel_ + k]; - T mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out_[i * row_numel_ + k] = mom2_max_; - p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); + mom2_max_ = std::max(mom2, mom2_max); } else { - p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + mom2_max_ = mom2_max; } + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); + // Write back to global memory moment1_out_[i * row_numel_ + k] = mom1; moment2_out_[i * row_numel_ + k] = mom2; + moment2_max_out_[i * row_numel_ + k] = mom2_max_; param_out_[i * row_numel_ + k] = p; } } @@ -731,6 +737,7 @@ class SparseAdamWFunctor { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; + MT mom2_max = moment2_max_[i]; MT lr = *lr_ * lr_ratio_; MT lr_orig = lr; MT beta1_pow = *beta1_pow_; @@ -746,13 +753,13 @@ class SparseAdamWFunctor { p -= lr_orig * coeff_ * p; + MT mom2_max_; if (amsgrad_) { - MT mom2_max = moment2_max_[i]; - MT mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out_[i] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } else { + mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } @@ -760,6 +767,7 @@ class SparseAdamWFunctor { // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; + moment2_max_out_[i] = mom2_max_; param_out_[i] = static_cast(p); if (master_param_out_) { master_param_out_[i] = p; diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.h b/paddle/phi/kernels/funcs/jit/refer/refer.h index 
82c17350e7d438..3f194f9b8782b7 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.h +++ b/paddle/phi/kernels/funcs/jit/refer/refer.h @@ -536,16 +536,19 @@ void Adam(T beta1, mom2_out_ptr[i] = beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; - T mom2; if (amsgrad) { - mom2 = std::max(mom2_out_ptr[i], mom2_max_out_ptr[i]); - mom2_max_out_ptr[i] = mom2; + T mom2_max_ = std::max(mom2_out_ptr[i], mom2_max_ptr[i]); + mom2_max_out_ptr[i] = mom2_max_; + + param_out_ptr[i] = + param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_max_) + eps)); } else { - mom2 = mom2_out_ptr[i]; - } + mom2_max_out_ptr[i] = mom2_max_ptr[i]; - param_out_ptr[i] = - param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2) + eps)); + T mom2_ = mom2_out_ptr[i]; + param_out_ptr[i] = + param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_) + eps)); + } } } @@ -574,15 +577,19 @@ void AdamW(T beta1, mom2_out_ptr[i] = beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; - T mom2; if (amsgrad) { - mom2 = std::max(mom2_out_ptr[i], mom2_max_out_ptr[i]); - mom2_max_out_ptr[i] = mom2; + T mom2_max_ = std::max(mom2_out_ptr[i], mom2_max_ptr[i]); + mom2_max_out_ptr[i] = mom2_max_; + + param_out_ptr[i] = + param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_max_) + eps)); } else { - mom2 = mom2_out_ptr[i]; - } + mom2_max_out_ptr[i] = mom2_max_ptr[i]; - param_out_ptr[i] = param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2) + eps)); + T mom2_ = mom2_out_ptr[i]; + param_out_ptr[i] = + param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_) + eps)); + } } } diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 2b68d917085ccf..aa0002a4dee6e3 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -61,18 +61,21 @@ __global__ void AdamKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT mom2_max_; MT denom; if (amsgrad) { - MT mom2_max = static_cast(moment2_max[id]); - MT mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out[id] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); + denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { + mom2_max_ = mom2_max; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -80,6 +83,7 @@ __global__ void AdamKernelREG(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; + moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -118,18 +122,21 @@ __global__ void AdamKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT mom2_max_; MT denom; if (amsgrad) { - MT mom2_max = static_cast(moment2_max[id]); - MT mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out[id] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); + denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { + mom2_max_ = mom2_max; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -137,6 +144,7 @@ __global__ void AdamKernelMEM(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; + moment2_max_out[id] = mom2_max_; param_out[id] = 
static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -186,6 +194,7 @@ void AdamDenseKernel(const Context& dev_ctx, const auto grad_type = grad.dtype(); VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + VLOG(4) << "amsgrad: " << amsgrad; bool skip_update_ = false; if (skip_update.is_initialized()) { @@ -246,6 +255,8 @@ void AdamDenseKernel(const Context& dev_ctx, if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) { // Compute with betapow in REG if (grad_type == phi::DataType::FLOAT32) { + VLOG(3) << "--> AdamKernelREG grad_type == phi::DataType::FLOAT32"; + AdamKernelREG <<>>( beta1_, @@ -268,6 +279,8 @@ void AdamDenseKernel(const Context& dev_ctx, param.numel(), amsgrad); } else { + VLOG(3) << "--> AdamKernelREG"; + AdamKernelREG<<>>( beta1_, beta2_, @@ -298,6 +311,8 @@ void AdamDenseKernel(const Context& dev_ctx, } } else { if (grad_type == phi::DataType::FLOAT32) { + VLOG(3) << "--> AdamKernelMEM grad_type == phi::DataType::FLOAT32"; + AdamKernelMEM <<>>( beta1_, @@ -320,6 +335,8 @@ void AdamDenseKernel(const Context& dev_ctx, param.numel(), amsgrad); } else { + VLOG(3) << "--> AdamKernelMEM"; + AdamKernelMEM<<>>( beta1_, beta2_, diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 5df2568d4f12a3..322bd9491e10eb 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -64,20 +64,23 @@ __global__ void AdamWKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT mom2_max_; MT denom; if (amsgrad) { - MT mom2_max = static_cast(moment2_max[id]); - MT mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out[id] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); + denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { + mom2_max_ = mom2_max; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -85,6 +88,7 @@ __global__ void AdamWKernelREG(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; + moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -125,20 +129,23 @@ __global__ void AdamWKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT mom2_max_; MT denom; if (amsgrad) { - MT mom2_max = static_cast(moment2_max[id]); - MT mom2_max_ = std::max(mom2, mom2_max); - moment2_max_out[id] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); + denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { + mom2_max_ = mom2_max; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -146,6 +153,7 @@ __global__ void AdamWKernelMEM(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; + moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index 094768c3e7caff..9c5ecc7eec17d4 100644 --- 
a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -217,6 +217,7 @@ struct FusedAdamFunctor { MT beta2) { MT mom1 = static_cast(mom1_ptr[0]); MT mom2 = static_cast(mom2_ptr[0]); + MT mom2_max = static_cast(mom2_max_ptr[0]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; @@ -225,8 +226,9 @@ struct FusedAdamFunctor { mom2_ptr[0] = mom2; if (AMSGrad) { - MT mom2_max = static_cast(mom2_max_ptr[0]); mom2_max_ptr[0] = std::max(mom2, mom2_max); + } else { + mom2_max_ptr[0] = mom2_max; } } diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index c7c3a4e14c1a03..4c90e7711d147f 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -74,6 +74,7 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, } else { MT mom1 = mom1_[id]; MT mom2 = mom2_[id]; + MT mom2_max = mom2_max_[id]; MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 ? static_cast(grad_[row_idx * row_numel + id % row_numel]) @@ -81,14 +82,16 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT moment2_max_; MT denom; if (amsgrad) { - MT mom2_max = mom2_max_[id]; - MT moment2_max_ = std::max(mom2, mom2_max); - mom2_max_out_[id] = moment2_max_; + moment2_max_ = std::max(mom2, mom2_max); + denom = (sqrt(moment2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { + moment2_max_ = mom2_max; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -97,6 +100,7 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, // Write back to global memory mom1_out_[id] = mom1; mom2_out_[id] = mom2; + mom2_max_out_[id] = moment2_max_; param_out_[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index a428d98d2c01a0..73869d1146cf77 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -80,6 +80,7 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, } else { MT mom1 = static_cast(mom1_[id]); MT mom2 = static_cast(mom2_[id]); + MT mom2_max = static_cast(mom2_max_[id]); MT p = master_param ? 
master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 @@ -91,14 +92,16 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT mom2_max_; MT denom; if (amsgrad) { - MT mom2_max = static_cast(mom2_max_[id]); - MT mom2_max_ = std::max(mom2, mom2_max); - mom2_max_out_[id] = mom2_max_; + mom2_max_ = std::max(mom2, mom2_max); + denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { + mom2_max_ = mom2_max; + denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -107,6 +110,7 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, // Write back to global memory mom1_out_[id] = mom1; mom2_out_[id] = mom2; + mom2_max_out_[id] = mom2_max_; param_out_[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index f9a14db20273d9..3ada85199c4bc9 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -81,9 +81,9 @@ - op : adamw_ (adamw) inputs : - {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam, skip_update: SkipUpdate} + {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, moment2_max: Moment2Max, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam, skip_update: SkipUpdate} outputs : - {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} + {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, moment2_max_out: Moment2MaxOut, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} scalar : beta1 : data_type : float diff --git a/test/collective/fleet/hybrid_parallel_sharding_model.py b/test/collective/fleet/hybrid_parallel_sharding_model.py index ce7c518a39c741..51ab3a43aad93a 100644 --- a/test/collective/fleet/hybrid_parallel_sharding_model.py +++ b/test/collective/fleet/hybrid_parallel_sharding_model.py @@ -321,19 +321,23 @@ def test_sharding_adam(self): sharded_accumulators = { 'embedding_0.w_0_beta2_pow_acc_0', 'linear_1.b_0_moment2_0', + 'linear_1.b_0_moment2_max_0', 'linear_0.b_0_beta1_pow_acc_0', 'linear_0.b_0_beta2_pow_acc_0', 'linear_1.b_0_moment1_0', 'linear_2.b_0_beta2_pow_acc_0', 'linear_2.b_0_moment2_0', + 'linear_2.b_0_moment2_max_0', 'embedding_0.w_0_moment1_0', 'embedding_0.w_0_beta1_pow_acc_0', 'linear_0.b_0_moment2_0', + 'linear_0.b_0_moment2_max_0', 'linear_2.b_0_moment1_0', 'linear_0.b_0_moment1_0', 'linear_1.b_0_beta2_pow_acc_0', 'linear_1.b_0_beta1_pow_acc_0', 'embedding_0.w_0_moment2_0', + 'embedding_0.w_0_moment2_max_0', 'linear_2.b_0_beta1_pow_acc_0', } self.sharding_model( From 7aa9d60d1955fde59b58c67a957d65795ef59b23 Mon Sep 17 00:00:00 2001 From: megemini Date: Sat, 7 Sep 2024 20:10:07 +0800 Subject: [PATCH 10/33] [Update] unittest passed for adam and adamw --- paddle/phi/kernels/gpu/adam_kernel.cu | 8 - paddle/phi/kernels/gpu/adamw_kernel.cu | 1 + test/auto_parallel/test_api_dist_branch.py | 11 ++ test/legacy_test/test_adam_op.py | 188 +++++++++++++++++---- test/legacy_test/test_adamw_op.py | 150 +++++++++++++--- test/xpu/test_adam_op_xpu.py | 1 + 6 files changed, 299 insertions(+), 60 deletions(-) diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu 
b/paddle/phi/kernels/gpu/adam_kernel.cu index aa0002a4dee6e3..d04d3ef1bd228b 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -255,8 +255,6 @@ void AdamDenseKernel(const Context& dev_ctx, if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) { // Compute with betapow in REG if (grad_type == phi::DataType::FLOAT32) { - VLOG(3) << "--> AdamKernelREG grad_type == phi::DataType::FLOAT32"; - AdamKernelREG <<>>( beta1_, @@ -279,8 +277,6 @@ void AdamDenseKernel(const Context& dev_ctx, param.numel(), amsgrad); } else { - VLOG(3) << "--> AdamKernelREG"; - AdamKernelREG<<>>( beta1_, beta2_, @@ -311,8 +307,6 @@ void AdamDenseKernel(const Context& dev_ctx, } } else { if (grad_type == phi::DataType::FLOAT32) { - VLOG(3) << "--> AdamKernelMEM grad_type == phi::DataType::FLOAT32"; - AdamKernelMEM <<>>( beta1_, @@ -335,8 +329,6 @@ void AdamDenseKernel(const Context& dev_ctx, param.numel(), amsgrad); } else { - VLOG(3) << "--> AdamKernelMEM"; - AdamKernelMEM<<>>( beta1_, beta2_, diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 322bd9491e10eb..141b23216097cd 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -208,6 +208,7 @@ void AdamwDenseKernel(const Context& dev_ctx, VLOG(4) << "multi_precision: " << multi_precision; VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + VLOG(4) << "amsgrad:" << amsgrad; MPDType coeff_ = static_cast(coeff); MPDType lr_ratio_ = static_cast(lr_ratio); diff --git a/test/auto_parallel/test_api_dist_branch.py b/test/auto_parallel/test_api_dist_branch.py index f01bf2171fc637..997699d956518a 100644 --- a/test/auto_parallel/test_api_dist_branch.py +++ b/test/auto_parallel/test_api_dist_branch.py @@ -307,6 +307,7 @@ def test_merged_adam_for_dist_tensor(self): lrs = [np.random.random(s).astype(mp_dtype) for s in lr_shape] moment1s = [np.random.random(s).astype(mp_dtype) for s in shapes] moment2s = [np.random.random(s).astype(mp_dtype) for s in shapes] + moment2s_max = [np.zeros(s).astype(mp_dtype) for s in shapes] beta1_pows = [np.random.random(s).astype(mp_dtype) for s in lr_shape] beta2_pows = [np.random.random(s).astype(mp_dtype) for s in lr_shape] master_params = [p.astype(mp_dtype) for p in params] @@ -326,6 +327,10 @@ def test_merged_adam_for_dist_tensor(self): local_moment2s, dist_moment2s, ) = self.create_local_and_dist_tensor_list_pair(moment2s) + ( + local_moment2s_max, + dist_moment2s_max, + ) = self.create_local_and_dist_tensor_list_pair(moment2s_max) ( local_beta1_pows, dist_beta1_pows, @@ -343,6 +348,7 @@ def test_merged_adam_for_dist_tensor(self): local_param_out, local_moment1s_out, local_moment2s_out, + local_moment2s_max_out, local_beta1_pow_out, local_beta2_pow_out, local_master_param_out, @@ -352,6 +358,7 @@ def test_merged_adam_for_dist_tensor(self): local_lrs, local_moment1s, local_moment2s, + local_moment2s_max, local_beta1_pows, local_beta2_pows, local_master_params, @@ -360,12 +367,14 @@ def test_merged_adam_for_dist_tensor(self): epsilon, True, False, + False, ) ( dist_param_out, dist_moment1s_out, dist_moment2s_out, + dist_moment2s_max_out, dist_beta1_pow_out, dist_beta2_pow_out, dist_master_param_out, @@ -375,6 +384,7 @@ def test_merged_adam_for_dist_tensor(self): dist_lrs, dist_moment1s, dist_moment2s, + dist_moment2s_max, dist_beta1_pows, dist_beta2_pows, dist_master_params, @@ -383,6 +393,7 @@ def test_merged_adam_for_dist_tensor(self): epsilon, True, False, + False, ) for i in 
range(len(local_param_out)): self.check_tensor_eq(local_param_out[i], dist_param_out[i]) diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py index b781dd3bf1263b..0ef28710d29a0f 100644 --- a/test/legacy_test/test_adam_op.py +++ b/test/legacy_test/test_adam_op.py @@ -64,6 +64,9 @@ def adam_wrapper( class TestAdamOp1(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Op with supplied attributes''' self.op_type = "adam" @@ -82,7 +85,7 @@ def setUp(self): epsilon = 1e-4 beta1_pow = beta1**10 beta2_pow = beta2**10 - amsgrad = False + self.set_amsgrad() self.inputs = { 'Param': param, @@ -99,7 +102,7 @@ def setUp(self): 'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2, - 'amsgrad': amsgrad, + 'amsgrad': self.amsgrad, } param_out, moment1_out, moment2_out, moment2_max_out = adam_step( @@ -119,10 +122,18 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestAdamOp1AMSGrad(TestAdamOp1): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOp2(OpTest): def set_shape(self): self.shape = (102, 105) + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Op with supplied attributes''' self.op_type = "adam" @@ -142,7 +153,7 @@ def setUp(self): epsilon = 1e-8 beta1_pow = beta1**10 beta2_pow = beta2**10 - amsgrad = False + self.set_amsgrad() self.inputs = { 'Param': param, @@ -155,15 +166,15 @@ def setUp(self): 'Beta2Pow': np.array([beta2_pow]).astype("float32"), } - attributes = { + self.attrs = { 'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2, - 'amsgrad': amsgrad, + 'amsgrad': self.amsgrad, } param_out, moment1_out, moment2_out, moment2_max_out = adam_step( - self.inputs, attributes + self.inputs, self.attrs ) self.outputs = { @@ -184,7 +195,15 @@ def set_shape(self): self.shape = 3 +class TestAdamOp2AMSGrad(TestAdamOp2): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOpMultipleSteps(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Operator with supplied attributes''' self.op_type = "adam" @@ -205,7 +224,7 @@ def setUp(self): epsilon = 1e-8 self.beta1_pow = self.beta1**10 self.beta2_pow = self.beta2**10 - self.amsgrad = False + self.set_amsgrad() self.inputs = { 'Param': param, @@ -261,12 +280,17 @@ def test_check_output(self): ) +class TestAdamOpMultipleStepsAMSGrad(TestAdamOpMultipleSteps): + def set_amsgrad(self): + self.amsgrad = True + + def adam_step(inputs, attributes): ''' Simulate one step of the adam optimizer :param inputs: dict of inputs :param attributes: dict of attributes - :return tuple: tuple of output param, moment1, moment2, + :return tuple: tuple of output param, moment1, moment2, moment2_max beta1 power accumulator and beta2 power accumulator ''' param = inputs['Param'] @@ -315,7 +339,7 @@ def adamw_step(inputs, attributes): Simulate one step of the adam optimizer :param inputs: dict of inputs :param attributes: dict of attributes - :return tuple: tuple of output param, moment1, moment2, + :return tuple: tuple of output param, moment1, moment2, moment2_max, beta1 power accumulator and beta2 power accumulator ''' param = inputs['Param'] @@ -370,7 +394,7 @@ def adam_step_sparse( Simulate one step of the adam optimizer :param inputs: dict of inputs :param attributes: dict of attributes - :return tuple: tuple of output param, moment1, moment2, + :return tuple: tuple of output param, moment1, moment2, moment2_max, beta1 power accumulator and beta2 power accumulator ''' param = inputs['Param'] @@ 
-410,6 +434,7 @@ def update_row(row_id, update_value): / (np.sqrt(moment2_max_out[row_id]) + epsilon) ) else: + moment2_max_out[row_id] = np.zeros_like(moment2_out[row_id]) param_out[row_id] = param[row_id] - lr_t * ( moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) ) @@ -428,12 +453,16 @@ def update_row(row_id, update_value): class TestSparseAdamOp(unittest.TestCase): + def set_amsgrad(self): + self.amsgrad = False + def setup(self, scope, place, lazy_mode): beta1 = 0.78 beta2 = 0.836 epsilon = 1e-4 beta1_pow = np.array([beta1**10]).astype("float32") beta2_pow = np.array([beta2**10]).astype("float32") + self.set_amsgrad() height = 10 rows = [0, 4, 7] @@ -455,7 +484,7 @@ def setup(self, scope, place, lazy_mode): 'beta1': beta1, 'beta2': beta2, 'min_row_size_to_use_multithread': 2, - 'amsgrad': False, + 'amsgrad': self.amsgrad, } grad_selected_rows = scope.var('Grad').get_selected_rows() @@ -535,7 +564,15 @@ def test_sparse_adam(self): self.check_with_place(place, lazy_mode) +class TestSparseAdamOpAMSGrad(TestSparseAdamOp): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOpBetaVariable(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Op with beta as Variable''' self.op_type = "adam" @@ -555,6 +592,7 @@ def setUp(self): epsilon = 1e-8 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.inputs = { 'Param': param, @@ -569,10 +607,10 @@ def setUp(self): "Beta2Tensor": np.array([beta2]).astype("float32"), } - attributes = {'epsilon': epsilon, 'amsgrad': False} + self.attrs = {'epsilon': epsilon, 'amsgrad': self.amsgrad} param_out, moment1_out, moment2_out, moment2_max_out = adam_step( - self.inputs, attributes + self.inputs, self.attrs ) self.outputs = { @@ -588,7 +626,15 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestAdamOpBetaVariableAMSGrad(TestAdamOpBetaVariable): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOpBetaEpsilonVariable(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Op with beta/epsilon as Variable''' self.op_type = "adam" @@ -608,6 +654,7 @@ def setUp(self): epsilon = 1e-8 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.inputs = { 'Param': param, @@ -623,10 +670,10 @@ def setUp(self): "EpsilonTensor": np.array([epsilon]).astype("float32"), } - attributes = {'epsilon': epsilon, 'amsgrad': False} + self.attrs = {'epsilon': epsilon, 'amsgrad': self.amsgrad} param_out, moment1_out, moment2_out, moment2_max_out = adam_step( - self.inputs, attributes + self.inputs, self.attrs ) self.outputs = { @@ -642,7 +689,15 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestAdamOpBetaEpsilonVariableAMSGrad(TestAdamOpBetaEpsilonVariable): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOpWithGlobalBetaPow(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Op with global_beta_pow''' self.op_type = "adam" @@ -662,6 +717,7 @@ def setUp(self): epsilon = 1e-8 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.inputs = { 'Param': param, @@ -677,14 +733,16 @@ def setUp(self): "EpsilonTensor": np.array([epsilon]).astype("float32"), } - attributes = {'epsilon': epsilon, 'amsgrad': False} + self.attrs = { + 'use_global_beta_pow': True, + 'epsilon': epsilon, + 'amsgrad': self.amsgrad, + } param_out, moment1_out, moment2_out, moment2_max_out = adam_step( - self.inputs, attributes + self.inputs, self.attrs ) - 
self.attrs = {'use_global_beta_pow': True} - # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty. self.outputs = { 'Moment1Out': moment1_out, @@ -699,7 +757,15 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestAdamOpWithGlobalBetaPowAMSGrad(TestAdamOpWithGlobalBetaPow): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOpWithSkipUpdate(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test Adam Op with global_beta_pow''' self.op_type = "adam" @@ -719,6 +785,7 @@ def setUp(self): epsilon = 1e-8 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.inputs = { 'Param': param, @@ -735,9 +802,11 @@ def setUp(self): "SkipUpdate": np.array([True]).astype("bool"), } - attributes = {'epsilon': epsilon, 'amsgrad': False} - - self.attrs = {'use_global_beta_pow': True, 'amsgrad': False} + self.attrs = { + 'use_global_beta_pow': True, + 'epsilon': epsilon, + 'amsgrad': self.amsgrad, + } # use_global_beta_pow=True, Beta1PowOut and Beta2PowOut are empty. self.outputs = { @@ -753,7 +822,15 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestAdamOpWithSkipUpdateAMSGrad(TestAdamOpWithSkipUpdate): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamOpV2(unittest.TestCase): + def setUp(self): + self.amsgrad = False + def test_pir_adam_op(self): with paddle.pir_utils.IrGuard(): place = base.CPUPlace() @@ -785,6 +862,7 @@ def test_pir_adam_op(self): beta2=beta2, weight_decay=0.01, epsilon=1e-8, + amsgrad=self.amsgrad, ) opt.minimize(loss) @@ -802,7 +880,9 @@ def test_adam_op_dygraph(self): linear = paddle.nn.Linear(13, 5) adam = paddle.optimizer.Adam( - learning_rate=0.01, parameters=linear.parameters() + learning_rate=0.01, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) out = linear(a) out.backward() @@ -814,7 +894,9 @@ def test_adam_op_with_state_dict(self): paddle.disable_static() emb = paddle.nn.Embedding(10, 10) - adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) + adam = paddle.optimizer.Adam( + 0.001, parameters=emb.parameters(), amsgrad=self.amsgrad + ) state_dict = adam.state_dict() adam.set_state_dict(state_dict) @@ -826,6 +908,7 @@ def test_adam_op_with_state_dict(self): learning_rate=learning_rate, weight_decay=paddle.regularizer.L2Decay(0.001), parameters=emb.parameters(), + amsgrad=self.amsgrad, ) lr = adam.get_lr() state_dict = adam.state_dict() @@ -836,7 +919,9 @@ def test_adam_op_with_state_dict(self): learning_rate = np.array([0.01]).astype("float32") learning_rate = paddle.to_tensor(learning_rate) adam = paddle.optimizer.Adam( - learning_rate=learning_rate, parameters=emb.parameters() + learning_rate=learning_rate, + parameters=emb.parameters(), + amsgrad=self.amsgrad, ) params = adam.get_opti_var_name_list() @@ -850,7 +935,10 @@ def test_adam_with_grad_clip(self): linear = paddle.nn.Linear(13, 5) clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) adam = paddle.optimizer.Adam( - 0.1, parameters=linear.parameters(), grad_clip=clip + 0.1, + parameters=linear.parameters(), + grad_clip=clip, + amsgrad=self.amsgrad, ) out = linear(a) out.backward() @@ -861,7 +949,9 @@ def test_adam_with_grad_clip(self): def test_adam_op_with_set_lr(self): paddle.disable_static() linear = paddle.nn.Linear(10, 10) - adam = paddle.optimizer.Adam(0.1, parameters=linear.parameters()) + adam = paddle.optimizer.Adam( + 0.1, parameters=linear.parameters(), amsgrad=self.amsgrad + ) lr = 0.01 adam.set_lr(lr) @@ -879,15 +969,24 @@ def 
test_adam_op_invalid_input(self): linear = paddle.nn.Linear(10, 10) with self.assertRaises(ValueError): adam = paddle.optimizer.Adam( - 0.1, beta1=-1, parameters=linear.parameters() + 0.1, + beta1=-1, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) with self.assertRaises(ValueError): adam = paddle.optimizer.Adam( - 0.1, beta2=-1, parameters=linear.parameters() + 0.1, + beta2=-1, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) with self.assertRaises(ValueError): adam = paddle.optimizer.Adam( - 0.1, epsilon=-1, parameters=linear.parameters() + 0.1, + epsilon=-1, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) paddle.enable_static() @@ -897,7 +996,10 @@ def test_adam_op_with_sparse_input_and_weight_decay(self): x = paddle.to_tensor(x_data, stop_gradient=False) emb = paddle.nn.Embedding(10, 10, sparse=True) adam = paddle.optimizer.Adam( - 0.001, parameters=emb.parameters(), weight_decay=0.01 + 0.001, + parameters=emb.parameters(), + weight_decay=0.01, + amsgrad=self.amsgrad, ) with self.assertRaises(RuntimeError): @@ -907,6 +1009,11 @@ def test_adam_op_with_sparse_input_and_weight_decay(self): paddle.enable_static() +class TestAdamOpV2AMSGrad(TestAdamOpV2): + def setUp(self): + self.amsgrad = True + + class TestAdamOpV2Group(TestAdamOpV2): def test_adam_op(self): paddle.disable_static() @@ -935,7 +1042,15 @@ def test_adam_op(self): adam.clear_gradients() +class TestAdamOpV2GroupAMSGrad(TestAdamOpV2Group): + def setUp(self): + self.amsgrad = True + + class TestMultiTensorAdam(unittest.TestCase): + def setUp(self): + self.amsgrad = False + def _adam_optimize_dygraph( self, place, @@ -965,6 +1080,7 @@ def _adam_optimize_dygraph( parameters=model.parameters(), use_multi_tensor=use_multi_tensor, multi_precision=use_amp, + amsgrad=self.amsgrad, ) else: parameters = list(model.parameters()) @@ -986,6 +1102,7 @@ def _adam_optimize_dygraph( ], use_multi_tensor=use_multi_tensor, multi_precision=use_amp, + amsgrad=self.amsgrad, ) for idx in range(2): @@ -1022,7 +1139,9 @@ def _adam_optimize_static( train_program = paddle.static.Program() startup_program = paddle.static.Program() optimizer = paddle.optimizer.Adam( - multi_precision=use_amp, use_multi_tensor=use_multi_tensor + multi_precision=use_amp, + use_multi_tensor=use_multi_tensor, + amsgrad=self.amsgrad, ) with paddle.static.program_guard(train_program, startup_program): @@ -1155,7 +1274,10 @@ def test_pir_main(self): self._check_with_place_amp(place, use_amp) -# TODO(megemini): AMSGrad +class TestMultiTensorAdamAMSGrad(TestMultiTensorAdam): + def setUp(self): + self.amsgrad = True + if __name__ == "__main__": paddle.enable_static() diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py index 7d1512a6ef0e67..49d6517080175f 100644 --- a/test/legacy_test/test_adamw_op.py +++ b/test/legacy_test/test_adamw_op.py @@ -117,6 +117,9 @@ def adamw_wrapper( class TestAdamW(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test AdamW Op with supplied attributes''' self.op_type = "adamw" @@ -135,6 +138,7 @@ def setUp(self): epsilon = 1e-4 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.inputs = { 'Param': param, @@ -153,7 +157,7 @@ def setUp(self): 'beta2': beta2, "coeff": 0.5, "with_decay": True, - "amsgrad": False, + "amsgrad": self.amsgrad, } param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( @@ -173,10 +177,18 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestAdamWAMSGrad(TestAdamW): + def 
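At the Python API level, the flag threaded through these cases is just an extra keyword on the optimizer constructor introduced by this series. A minimal dygraph run assuming that keyword (hyperparameters are illustrative):

import paddle

linear = paddle.nn.Linear(13, 5)
opt = paddle.optimizer.Adam(
    learning_rate=0.01,
    parameters=linear.parameters(),
    amsgrad=True,  # keyword added by this patch series
)
x = paddle.rand([4, 13])
loss = linear(x).mean()
loss.backward()
opt.step()
opt.clear_grad()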
set_amsgrad(self): + self.amsgrad = True + + @unittest.skipIf( not core.is_compiled_with_cuda(), "core is not compiled with CUDA" ) class TestAdamW2(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): '''Test AdamW Op with supplied attributes''' self.op_type = "adamw" @@ -195,6 +207,7 @@ def setUp(self): epsilon = 1e-4 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.inputs = { 'Param': param, @@ -214,7 +227,7 @@ def setUp(self): "lr_ratio": 0.1, "coeff": 0.5, "with_decay": True, - "amsgrad": False, + "amsgrad": self.amsgrad, } param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( @@ -234,7 +247,15 @@ def test_check_output(self): self.check_output_with_place(core.CUDAPlace(0), check_pir=True) +class TestAdamW2AMSGrad(TestAdamW2): + def set_amsgrad(self): + self.amsgrad = True + + class TestAdamWOp(unittest.TestCase): + def setUp(self): + self.amsgrad = False + def test_adamw_op_dygraph(self): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") @@ -245,6 +266,7 @@ def test_adamw_op_dygraph(self): parameters=linear.parameters(), apply_decay_param_fun=lambda name: True, weight_decay=0.01, + amsgrad=self.amsgrad, ) for _ in range(2): @@ -294,6 +316,7 @@ def test_adamw_op(self): beta2=beta2, weight_decay=0.01, epsilon=1e-8, + amsgrad=self.amsgrad, ) opt.minimize(loss) @@ -313,6 +336,7 @@ def test_adamw_op_dygraph_bypassing_step(self): parameters=linear.parameters(), apply_decay_param_fun=lambda name: True, weight_decay=0.01, + amsgrad=self.amsgrad, ) os.environ["FLAGS_shard_bypass_dygraph_optimizer"] = "1" for _ in range(2): @@ -331,6 +355,7 @@ def test_adamw_op_coverage(self): parameters=linear.parameters(), apply_decay_param_fun=lambda name: True, weight_decay=0.01, + amsgrad=self.amsgrad, ) assert adam.__str__() is not None @@ -365,6 +390,7 @@ def test_pir_adam_op(self): beta2=beta2, weight_decay=0.01, epsilon=1e-8, + amsgrad=self.amsgrad, ) opt.minimize(loss) @@ -380,18 +406,32 @@ def test_adamw_op_invalid_input(self): linear = paddle.nn.Linear(10, 10) with self.assertRaises(ValueError): adam = paddle.optimizer.AdamW( - 0.1, beta1=-1, parameters=linear.parameters() + 0.1, + beta1=-1, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) with self.assertRaises(ValueError): adam = paddle.optimizer.AdamW( - 0.1, beta2=-1, parameters=linear.parameters() + 0.1, + beta2=-1, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) with self.assertRaises(ValueError): adam = paddle.optimizer.AdamW( - 0.1, epsilon=-1, parameters=linear.parameters() + 0.1, + epsilon=-1, + parameters=linear.parameters(), + amsgrad=self.amsgrad, ) +class TestAdamWOpAMSGrad(TestAdamWOp): + def setUp(self): + self.amsgrad = True + + class TestAdamWOpGroup(TestAdamWOp): def test_adamw_op_dygraph(self): paddle.disable_static() @@ -407,6 +447,7 @@ def test_adamw_op_dygraph(self): ], apply_decay_param_fun=lambda name: True, weight_decay=0.01, + amsgrad=self.amsgrad, ) for _ in range(2): @@ -430,6 +471,7 @@ def test_adamw_op_dygraph_bypassing_step(self): ], apply_decay_param_fun=lambda name: True, weight_decay=0.01, + amsgrad=self.amsgrad, ) os.environ["FLAGS_shard_bypass_dygraph_optimizer"] = "1" @@ -441,7 +483,15 @@ def test_adamw_op_dygraph_bypassing_step(self): adam.clear_gradients() +class TestAdamWOpGroupAMSGrad(TestAdamWOpGroup): + def setUp(self): + self.amsgrad = True + + class TestAdamWOpMultiPrecisionWithMainGrad(unittest.TestCase): + def setUp(self): + self.amsgrad = False + def 
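AdamW gets the same switch; a minimal decoupled-weight-decay run mirroring the dygraph tests above (a sketch, hyperparameters illustrative):

import paddle

linear = paddle.nn.Linear(13, 5)
opt = paddle.optimizer.AdamW(
    learning_rate=0.01,
    parameters=linear.parameters(),
    weight_decay=0.01,
    apply_decay_param_fun=lambda name: True,
    amsgrad=True,
)
loss = linear(paddle.rand([2, 13])).mean()
loss.backward()
opt.step()
opt.clear_grad()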
_test_adamw_op_dygraph_place_amp_with_maingrad( self, place, shape, use_main_grad ): @@ -507,7 +557,7 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, False, False, - False, + self.amsgrad, ) if use_main_grad: @@ -532,14 +582,21 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, find_master, False, - False, + self.amsgrad, ) np.testing.assert_allclose( param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2 ) - np.testing.assert_allclose( - master_weight.numpy(), ref_param.numpy(), rtol=1e-6 - ) + + if self.amsgrad: + np.testing.assert_allclose( + master_weight.numpy(), ref_param.numpy(), rtol=1e-4 + ) + else: + np.testing.assert_allclose( + master_weight.numpy(), ref_param.numpy(), rtol=1e-6 + ) + else: _, _, _, _, _, _, _ = paddle._C_ops.adamw_( param, @@ -562,14 +619,20 @@ def _test_adamw_op_dygraph_place_amp_with_maingrad( 1000, find_master, False, - False, + self.amsgrad, ) np.testing.assert_allclose( param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2 ) - np.testing.assert_allclose( - master_weight.numpy(), ref_param.numpy(), rtol=1e-6 - ) + + if self.amsgrad: + np.testing.assert_allclose( + master_weight.numpy(), ref_param.numpy(), rtol=1e-4 + ) + else: + np.testing.assert_allclose( + master_weight.numpy(), ref_param.numpy(), rtol=1e-6 + ) def _get_places(self): places = [] @@ -588,7 +651,17 @@ def test_main(self): ) +class TestAdamWOpMultiPrecisionWithMainGradAMSGrad( + TestAdamWOpMultiPrecisionWithMainGrad +): + def setUp(self): + self.amsgrad = True + + class TestAdamWOpMultiPrecision(unittest.TestCase): + def setUp(self): + self.amsgrad = False + def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False): paddle.disable_static() paddle.seed(10) @@ -608,6 +681,7 @@ def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False): } ], multi_precision=use_amp, + amsgrad=self.amsgrad, ) for idx in range(2): @@ -649,7 +723,15 @@ def test_main(self): self._test_adamw_op_dygraph_place_amp(place, use_amp) +class TestAdamWOpMultiPrecisionAMSGrad(TestAdamWOpMultiPrecision): + def setUp(self): + self.amsgrad = True + + class TestAdamWOpError(unittest.TestCase): + def setUp(self): + self.amsgrad = False + def test_api_errors(self): def test_weight_decay_dtype(): linear = paddle.nn.Linear(13, 5) @@ -657,6 +739,7 @@ def test_weight_decay_dtype(): learning_rate=0.01, parameters=linear.parameters(), weight_decay=1, + amsgrad=self.amsgrad, ) def test_parameters_dtype1(): @@ -664,6 +747,7 @@ def test_parameters_dtype1(): learning_rate=0.01, parameters=paddle.randn((5, 5)), weight_decay=0.1, + amsgrad=self.amsgrad, ) def test_parameters_dtype2(): @@ -672,11 +756,15 @@ def test_parameters_dtype2(): learning_rate=0.01, parameters={'params': linear.parameters()}, weight_decay=0.1, + amsgrad=self.amsgrad, ) def test_parameters_dtype3(): adam = paddle.optimizer.AdamW( - learning_rate=0.01, parameters=None, weight_decay=0.1 + learning_rate=0.01, + parameters=None, + weight_decay=0.1, + amsgrad=self.amsgrad, ) def test_parameters_dtype4(): @@ -685,6 +773,7 @@ def test_parameters_dtype4(): learning_rate=0.01, parameters={'params': set(linear.parameters())}, weight_decay=0.1, + amsgrad=self.amsgrad, ) def test_learning_rate_dtype(): @@ -693,6 +782,7 @@ def test_learning_rate_dtype(): learning_rate=1, parameters=linear.parameters(), weight_decay=0.1, + amsgrad=self.amsgrad, ) def test_grad_clip_dtype(): @@ -702,6 +792,7 @@ def test_grad_clip_dtype(): parameters=linear.parameters(), weight_decay=0.1, grad_clip=0.1, + amsgrad=self.amsgrad, ) 
self.assertRaises(TypeError, test_weight_decay_dtype) @@ -713,6 +804,11 @@ def test_grad_clip_dtype(): self.assertRaises(TypeError, test_grad_clip_dtype) +class TestAdamWOpErrorAMSGrad(TestAdamWOpError): + def setUp(self): + self.amsgrad = True + + class TestAdamWOpGroupWithLR(TestAdamWOp): def test_adamw_op_dygraph(self): paddle.disable_static() @@ -736,6 +832,7 @@ def test_adamw_op_dygraph(self): ], apply_decay_param_fun=lambda name: True, weight_decay=0.01, + amsgrad=self.amsgrad, ) for _ in range(2): @@ -746,6 +843,11 @@ def test_adamw_op_dygraph(self): adam.clear_gradients() +class TestAdamWOpGroupWithLRAMSGrad(TestAdamWOpGroupWithLR): + def setUp(self): + self.amsgrad = True + + def simple_lr_setting(param, decay_rate, n_layers): if "fc_0" in param.name or "linear_1" in param.name: depth = int(param.name.split("_")[2]) + 1 @@ -765,6 +867,7 @@ def setUp(self): random.seed(2022) np.random.seed(2022) paddle.seed(2022) + self.amsgrad = False def test_adamw_op_dygraph(self): paddle.disable_static() @@ -816,6 +919,7 @@ def test_adamw_op_dygraph(self): apply_decay_param_fun=lambda name: True, weight_decay=weight_decay, lr_ratio=simple_lr_fun, + amsgrad=self.amsgrad, ) def get_numpy_output( @@ -839,7 +943,7 @@ def get_numpy_output( "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, - "amsgrad": False, + "amsgrad": self.amsgrad, } param_out, moment1_out, moment2_out, moment2_max_out = adamw_step( np_inputs, np_attrs @@ -991,6 +1095,7 @@ def test_adamw_op(self): weight_decay=weight_decay, epsilon=epsilon, lr_ratio=simple_lr_fun, + amsgrad=self.amsgrad, ) opt.minimize(avg_cost) @@ -1015,7 +1120,7 @@ def get_numpy_output( "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, - "amsgrad": False, + "amsgrad": self.amsgrad, } param_out, moment1_out, moment2_out, moment2_max_out = ( adamw_step(np_inputs, np_attrs) @@ -1210,6 +1315,7 @@ def test_adamw_op_with_pir(self): weight_decay=weight_decay, epsilon=epsilon, lr_ratio=simple_lr_fun, + amsgrad=self.amsgrad, ) _, params_grads = opt.minimize(avg_cost) @@ -1234,7 +1340,7 @@ def get_numpy_output( "lr_ratio": lr_ratio, "coeff": weight_decay, "with_decay": True, - "amsgrad": False, + "amsgrad": self.amsgrad, } param_out, moment1_out, moment2_out, moment2_out_max = ( adamw_step(np_inputs, np_attrs) @@ -1381,7 +1487,13 @@ def get_numpy_output( paddle.disable_static() -# TODO(megemini): AMSGrad +class TestAdamWOpLayerwiseLRAMSGrad(TestAdamWOpLayerwiseLR): + def setUp(self): + random.seed(2022) + np.random.seed(2022) + paddle.seed(2022) + self.amsgrad = True + if __name__ == "__main__": unittest.main() diff --git a/test/xpu/test_adam_op_xpu.py b/test/xpu/test_adam_op_xpu.py index dc47654f7dcb96..8f5c771cdfa6b9 100644 --- a/test/xpu/test_adam_op_xpu.py +++ b/test/xpu/test_adam_op_xpu.py @@ -339,6 +339,7 @@ def update_row(row_id, update_value): / (np.sqrt(moment2_max_out[row_id]) + epsilon) ) else: + moment2_max_out[row_id] = np.zeros_like(moment2_out[row_id]) param_out[row_id] = param[row_id] - lr_t * ( moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) ) From 96216e49f992b19fac04ea2ed8308a972c36a738 Mon Sep 17 00:00:00 2001 From: megemini Date: Sun, 8 Sep 2024 13:54:53 +0800 Subject: [PATCH 11/33] [Update] unittest passed for merged and fused amda --- test/legacy_test/test_fused_adam_op.py | 59 +++++++++++++++++++++---- test/legacy_test/test_merged_adam_op.py | 32 +++++++++++++- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/test/legacy_test/test_fused_adam_op.py b/test/legacy_test/test_fused_adam_op.py 
index 8bbc1fafef05b7..1a3af7cb0d0101 100644 --- a/test/legacy_test/test_fused_adam_op.py +++ b/test/legacy_test/test_fused_adam_op.py @@ -25,12 +25,13 @@ def fused_adam_step(inputs, attributes, num): Simulate one step of the fused_adam optimizer :param inputs: dict of inputs :param attributes: dict of attributes - :return tuple: tuple of output params, moments1, moments2, beta1_pows, beta2_pows + :return tuple: tuple of output params, moments1, moments2, moments2_max, beta1_pows, beta2_pows ''' params = inputs['Params'] grads = inputs['Grads'] moments1 = inputs['Moments1'] moments2 = inputs['Moments2'] + moments2_max = inputs['Moments2Max'] lr = inputs['LearningRate'] beta1_pows = inputs['Beta1Pows'] beta2_pows = inputs['Beta2Pows'] @@ -38,6 +39,7 @@ def fused_adam_step(inputs, attributes, num): params_out = [] moments1_out = [] moments2_out = [] + moments2_max_out = [] beta1_pows_out = [] beta2_pows_out = [] @@ -52,16 +54,37 @@ def fused_adam_step(inputs, attributes, num): else: beta2 = inputs['Beta2Tensor'][0][0] + amsgrad = attributes['amsgrad'] + for i in range(num): - moments1_out.append(beta1 * moments1[i][1] + (1 - beta1) * grads[i][1]) - moments2_out.append( - beta2 * moments2[i][1] + (1 - beta2) * np.square(grads[i][1]) + _moment1_out = beta1 * moments1[i][1] + (1 - beta1) * grads[i][1] + _moment2_out = beta2 * moments2[i][1] + (1 - beta2) * np.square( + grads[i][1] ) + + moments1_out.append(_moment1_out) + moments2_out.append(_moment2_out) + lr_t = lr * np.sqrt(1 - beta2_pows[i][1]) / (1 - beta1_pows[i][1]) - params_out.append( - params[i][1] - - lr_t * (moments1_out[i] / (np.sqrt(moments2_out[i]) + epsilon)) - ) + + if amsgrad: + _moment2_max = np.maximum(_moment2_out, moments2_max[i][1]) + moments2_max_out.append(_moment2_max) + + params_out.append( + params[i][1] + - lr_t + * (moments1_out[i] / (np.sqrt(moments2_max_out[i]) + epsilon)) + ) + else: + _moment2_max = np.zeros_like(_moment2_out) + moments2_max_out.append(_moment2_max) + + params_out.append( + params[i][1] + - lr_t + * (moments1_out[i] / (np.sqrt(moments2_out[i]) + epsilon)) + ) for i in range(num): beta1_pows_out.append(beta1_pows[i][1] * beta1) @@ -71,12 +94,16 @@ def fused_adam_step(inputs, attributes, num): params_out, moments1_out, moments2_out, + moments2_max_out, beta1_pows_out, beta2_pows_out, ) class TestFusedAdamOp(OpTest): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): paddle.enable_static() @@ -91,12 +118,14 @@ def setUp(self): epsilon = 1e-4 beta1_pow = beta1**10 beta2_pow = beta2**10 + self.set_amsgrad() self.attrs = { 'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2, "chunk_size": 32 * 2048, + "amsgrad": self.amsgrad, } for i in range(num): @@ -126,6 +155,10 @@ def setUp(self): 'Moments2': [ ("moments2" + str(i), inputs_list[3][i]) for i in range(num) ], + 'Moments2Max': [ + ("moments2_max" + str(i), np.zeros_like(inputs_list[0][i])) + for i in range(num) + ], 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pows': [ ("beta1_pows" + str(i), inputs_list[4][i]) for i in range(num) @@ -139,6 +172,7 @@ def setUp(self): params_out, moments1_out, moments2_out, + moments2_max_out, beta1_pows_out, beta2_pows_out, ) = fused_adam_step(self.inputs, self.attrs, num) @@ -150,6 +184,10 @@ def setUp(self): 'Moments2Out': [ ("moments2_out" + str(i), moments2_out[i]) for i in range(num) ], + 'Moments2MaxOut': [ + ("moments2_max_out" + str(i), moments2_max_out[i]) + for i in range(num) + ], 'ParamsOut': [ ("params_out" + str(i), params_out[i]) for i in range(num) ], @@ -169,6 
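The fused reference above applies the same update to a list of parameters in one pass; its AMSGrad branch reduces to the following sketch, which assumes a shared beta-power accumulator for brevity (names are illustrative):

import numpy as np

def fused_amsgrad_pass(params, grads, m1s, m2s, m2_maxes,
                       lr, beta1, beta2, eps, beta1_pow, beta2_pow):
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    for i in range(len(params)):
        m1s[i] = beta1 * m1s[i] + (1 - beta1) * grads[i]
        m2s[i] = beta2 * m2s[i] + (1 - beta2) * np.square(grads[i])
        # Per-parameter running maximum of the second moment.
        m2_maxes[i] = np.maximum(m2s[i], m2_maxes[i])
        params[i] = params[i] - lr_t * m1s[i] / (np.sqrt(m2_maxes[i]) + eps)
    return params, m1s, m2s, m2_maxes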
+207,11 @@ def test_check_output(self): self.check_output(check_dygraph=False) +class TestFusedAdamOpAMSGrad(TestFusedAdamOp): + def set_amsgrad(self): + self.amsgrad = True + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_merged_adam_op.py b/test/legacy_test/test_merged_adam_op.py index 8d1295d6a33412..29c21f4561256a 100644 --- a/test/legacy_test/test_merged_adam_op.py +++ b/test/legacy_test/test_merged_adam_op.py @@ -27,6 +27,7 @@ def run_adam_op( lrs, moment1s, moment2s, + moment2s_max, beta1_pows, beta2_pows, master_params, @@ -36,11 +37,13 @@ def run_adam_op( place, multi_precision=False, use_merged=False, + amsgrad=False, ): assert len(params) == len(grads) assert len(params) == len(lrs) assert len(params) == len(moment1s) assert len(params) == len(moment2s) + assert len(params) == len(moment2s_max) assert len(params) == len(beta1_pows) assert len(params) == len(beta1_pows) assert len(params) == len(master_params) @@ -52,24 +55,27 @@ def run_adam_op( lr_vars = [paddle.to_tensor(l) for l in lrs] moment1_vars = [paddle.to_tensor(m) for m in moment1s] moment2_vars = [paddle.to_tensor(m) for m in moment2s] + moment2_max_vars = [paddle.to_tensor(m) for m in moment2s_max] beta1_pow_vars = [paddle.to_tensor(b) for b in beta1_pows] beta2_pow_vars = [paddle.to_tensor(b) for b in beta2_pows] master_param_vars = [paddle.to_tensor(m_p) for m_p in master_params] if not use_merged: for i in range(len(param_vars)): - _, _, _, _, _, _ = _legacy_C_ops.adam( + _, _, _, _, _, _, _ = _legacy_C_ops.adam( param_vars[i], grad_vars[i], lr_vars[i], moment1_vars[i], moment2_vars[i], + moment2_max_vars[i], beta1_pow_vars[i], beta2_pow_vars[i], master_param_vars[i], param_vars[i], moment1_vars[i], moment2_vars[i], + moment2_max_vars[i], beta1_pow_vars[i], beta2_pow_vars[i], master_param_vars[i], @@ -81,14 +87,17 @@ def run_adam_op( beta2, 'multi_precision', multi_precision, + 'amsgrad', + amsgrad, ) else: - _, _, _, _, _, _ = _C_ops.merged_adam_( + _, _, _, _, _, _, _ = _C_ops.merged_adam_( param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars, + moment2_max_vars, beta1_pow_vars, beta2_pow_vars, master_param_vars, @@ -97,12 +106,14 @@ def run_adam_op( epsilon, multi_precision, False, + amsgrad, ) outputs = { 'ParamOut': param_vars, 'Moment1Out': moment1_vars, 'Moment2Out': moment2_vars, + 'Moment2MaxOut': moment2_max_vars, 'Beta1PowOut': beta1_pow_vars, 'Beta2PowOut': beta2_pow_vars, 'MasterParamOut': master_param_vars, @@ -112,14 +123,21 @@ def run_adam_op( class TestMergedAdam(unittest.TestCase): + def set_amsgrad(self): + self.amsgrad = False + def setUp(self): paddle.disable_static() self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] self.seed = 10 + self.set_amsgrad() def gen_rand_data(self, shapes, dtype): return [np.random.random(s).astype(dtype) for s in shapes] + def gen_zero_data(self, shapes, dtype): + return [np.zeros(s).astype(dtype) for s in shapes] + def prepare_data(self, shapes, multi_precision, seed, place): np.random.seed(seed) mp_dtype = np.float32 @@ -129,6 +147,7 @@ def prepare_data(self, shapes, multi_precision, seed, place): lrs = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) moment1s = self.gen_rand_data(shapes, mp_dtype) moment2s = self.gen_rand_data(shapes, mp_dtype) + moment2s_max = self.gen_zero_data(shapes, mp_dtype) beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) master_params = [p.astype(mp_dtype) for p in params] @@ -138,6 +157,7 @@ def 
prepare_data(self, shapes, multi_precision, seed, place): lrs, moment1s, moment2s, + moment2s_max, beta1_pows, beta2_pows, master_params, @@ -150,6 +170,7 @@ def check_with_place(self, place, multi_precision): lrs, moment1s, moment2s, + moment2s_max, beta1_pows, beta2_pows, master_params, @@ -162,6 +183,7 @@ def run_op(use_merged): lrs=lrs, moment1s=moment1s, moment2s=moment2s, + moment2s_max=moment2s_max, beta1_pows=beta1_pows, beta2_pows=beta2_pows, master_params=master_params, @@ -171,6 +193,7 @@ def run_op(use_merged): place=place, multi_precision=multi_precision, use_merged=use_merged, + amsgrad=self.amsgrad, ) outs1 = run_op(True) @@ -206,5 +229,10 @@ def test_main(self): self.check_with_place(place, multi_precision) +class TestMergedAdamAMSGrad(TestMergedAdam): + def set_amsgrad(self): + self.amsgrad = True + + if __name__ == "__main__": unittest.main() From 98abe719bb737374658620e015f0689b9ab731ab Mon Sep 17 00:00:00 2001 From: megemini Date: Tue, 10 Sep 2024 23:45:46 +0800 Subject: [PATCH 12/33] [Update] make moment2_max optional --- paddle/fluid/operators/fused/fused_adam_op.cc | 2 + paddle/phi/infermeta/multiary.cc | 16 ++- paddle/phi/infermeta/multiary.h | 4 +- paddle/phi/infermeta/spmd_rules/optimizer.cc | 112 +++++++++++------- paddle/phi/infermeta/spmd_rules/optimizer.h | 80 +++++++------ paddle/phi/kernels/adam_kernel.h | 4 +- paddle/phi/kernels/adamw_kernel.h | 2 +- paddle/phi/kernels/cpu/adam_kernel.cc | 40 +++++-- paddle/phi/kernels/cpu/adamw_kernel.cc | 21 ++-- paddle/phi/kernels/cpu/fused_adam_kernel.cc | 32 ++--- paddle/phi/kernels/funcs/adam_functors.h | 65 +++++----- paddle/phi/kernels/funcs/jit/refer/refer.h | 5 - paddle/phi/kernels/funcs/multi_tensor_apply.h | 2 +- paddle/phi/kernels/fused_adam_kernel.h | 2 +- paddle/phi/kernels/gpu/adam_kernel.cu | 73 +++++++----- paddle/phi/kernels/gpu/adamw_kernel.cu | 49 ++++---- paddle/phi/kernels/gpu/fused_adam_kernel.cu | 52 +++++--- .../phi/kernels/selected_rows/adam_kernel.h | 2 +- .../phi/kernels/selected_rows/adamw_kernel.h | 2 +- .../kernels/selected_rows/cpu/adam_kernel.cc | 14 ++- .../kernels/selected_rows/cpu/adamw_kernel.cc | 2 +- .../kernels/selected_rows/gpu/adam_kernel.cu | 33 ++++-- .../kernels/selected_rows/gpu/adamw_kernel.cu | 32 +++-- .../kernels/selected_rows/xpu/adam_kernel.cc | 2 +- paddle/phi/kernels/xpu/adam_kernel.cc | 6 +- paddle/phi/kernels/xpu/adamw_kernel.cc | 4 +- .../ops/yaml/inconsistent/dygraph_ops.yaml | 2 +- .../phi/ops/yaml/inconsistent/static_ops.yaml | 2 +- paddle/phi/ops/yaml/ops.yaml | 6 +- python/paddle/optimizer/adam.py | 72 +++++++---- python/paddle/optimizer/adamw.py | 40 ++++--- test/legacy_test/test_adam_op.py | 37 ++++-- test/legacy_test/test_adamw_op.py | 15 ++- test/legacy_test/test_fused_adam_op.py | 9 +- test/white_list/no_check_set_white_list.py | 3 + 35 files changed, 512 insertions(+), 332 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_adam_op.cc b/paddle/fluid/operators/fused/fused_adam_op.cc index 3649410a6459fd..932bdbfd90a6c2 100644 --- a/paddle/fluid/operators/fused/fused_adam_op.cc +++ b/paddle/fluid/operators/fused/fused_adam_op.cc @@ -58,6 +58,7 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Moments1", "(Tensor) Input first moments").AsDuplicable(); AddInput("Moments2", "(Tensor) Input second moments").AsDuplicable(); AddInput("Moments2Max", "(Tensor) Input second moments max for amsgrad") + .AsDispensable() .AsDuplicable(); AddInput("Beta1Pows", "(Tensor, default Tensor) Input beta1 power accumulator") @@ -76,6 +77,7 
@@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Moments2Out", "(Tensor) Output second moments").AsDuplicable(); AddOutput("Moments2MaxOut", "(Tensor) Output second moments max for amsgrad") + .AsDispensable() .AsDuplicable(); AddOutput("Beta1PowsOut", "(Tensor) Output beta1 power accumulator") .AsDuplicable(); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 5d6ba06710afe5..15c7487dd56627 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -235,8 +235,10 @@ void AdamInferMeta(const MetaTensor& param, moment1_out->set_dtype(moment1.dtype()); moment2_out->set_dims(param_dims); moment2_out->set_dtype(moment2.dtype()); - moment2_max_out->set_dims(param_dims); - moment2_max_out->set_dtype(moment2_max.dtype()); + if (amsgrad) { + moment2_max_out->set_dims(param_dims); + moment2_max_out->set_dtype(moment2.dtype()); + } beta1_pow_out->set_dims(beta1_pow_dims); beta1_pow_out->set_dtype(beta1_pow.dtype()); @@ -3865,7 +3867,7 @@ void MergedAdamInferMeta( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max, + const paddle::optional>& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -5796,7 +5798,7 @@ void FusedAdamInferMeta( const MetaTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, - const std::vector& moments2_max, + const paddle::optional>& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -5825,8 +5827,10 @@ void FusedAdamInferMeta( moments1_out[i]->set_dtype(moments1[i]->dtype()); moments2_out[i]->set_dims(moments2[i]->dims()); moments2_out[i]->set_dtype(moments2[i]->dtype()); - moments2_max_out[i]->set_dims(moments2_max[i]->dims()); - moments2_max_out[i]->set_dtype(moments2_max[i]->dtype()); + if (amsgrad) { + moments2_max_out[i]->set_dims(moments2_max.get()[i]->dims()); + moments2_max_out[i]->set_dtype(moments2_max.get()[i]->dtype()); + } beta1_pows_out[i]->set_dims(beta1_pows[i]->dims()); beta1_pows_out[i]->set_dtype(beta1_pows[i]->dtype()); beta2_pows_out[i]->set_dims(beta2_pows[i]->dims()); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 96ad06ce7ea258..7c93a31a55b56a 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -716,7 +716,7 @@ void MergedAdamInferMeta( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max, + const paddle::optional>& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -1125,7 +1125,7 @@ void FusedAdamInferMeta( const MetaTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, - const std::vector& moments2_max, + const paddle::optional>& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, diff --git a/paddle/phi/infermeta/spmd_rules/optimizer.cc b/paddle/phi/infermeta/spmd_rules/optimizer.cc index 60382b1ffefad0..a9303e9278bce0 100644 --- a/paddle/phi/infermeta/spmd_rules/optimizer.cc +++ b/paddle/phi/infermeta/spmd_rules/optimizer.cc @@ -25,24 +25,25 @@ limitations under the License. 
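The point of [PATCH 12/33] is that the `Moments2Max` buffer becomes dispensable: it is only read, allocated, and written when `amsgrad` is set. The bookkeeping idea, reduced to a Python sketch (illustrative only, not Paddle's actual accumulator code):

import numpy as np

class TinyAdamState:
    def __init__(self, param, amsgrad=False):
        self.m1 = np.zeros_like(param)
        self.m2 = np.zeros_like(param)
        # Dispensable slot: materialized only when AMSGrad needs it.
        self.m2_max = np.zeros_like(param) if amsgrad else None

state = TinyAdamState(np.zeros((3, 4), dtype="float32"), amsgrad=False)
assert state.m2_max is None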
*/ namespace phi { namespace distributed { -SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, - const DistMetaTensor& grad, - const DistMetaTensor& learning_rate, - const DistMetaTensor& moment1, - const DistMetaTensor& moment2, - const DistMetaTensor& moment2_max, - const DistMetaTensor& beta1_pow, - const DistMetaTensor& beta2_pow, - const DistMetaTensor& master_param, - const DistMetaTensor& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow, - bool amsgrad) { +SpmdInfo AdamInferSpmdDynamic( + const DistMetaTensor& param, + const DistMetaTensor& grad, + const DistMetaTensor& learning_rate, + const DistMetaTensor& moment1, + const DistMetaTensor& moment2, + const paddle::optional& moment2_max, + const DistMetaTensor& beta1_pow, + const DistMetaTensor& beta2_pow, + const DistMetaTensor& master_param, + const DistMetaTensor& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + bool amsgrad) { // shape check PADDLE_ENFORCE( param.dims().size() == grad.dims().size() && @@ -82,7 +83,8 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, TensorDistAttr moment2_dist_attr = CopyTensorDistAttrForOutput(moment2.dist_attr()); TensorDistAttr moment2_max_dist_attr = - CopyTensorDistAttrForOutput(moment2_max.dist_attr()); + amsgrad ? CopyTensorDistAttrForOutput(moment2_max.get().dist_attr()) + : TensorDistAttr(); TensorDistAttr beta1_pow_dist_attr = CopyTensorDistAttrForOutput(beta1_pow.dist_attr()); TensorDistAttr beta2_pow_dist_attr = @@ -119,7 +121,12 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, auto grad_spmd_dims_mapping = grad_dist_attr_spmd.dims_mapping(); auto momentum1_src_dims_mapping = moment1.dist_attr().dims_mapping(); auto momentum2_src_dims_mapping = moment2.dist_attr().dims_mapping(); - auto momentum2_max_src_dims_mapping = moment2_max.dist_attr().dims_mapping(); + + std::vector momentum2_max_src_dims_mapping; + if (amsgrad) { + momentum2_max_src_dims_mapping = + moment2_max.get().dist_attr().dims_mapping(); + } // Get the final dist attr for param, master_param, grad and momentum. // Whatever the input dist attrs are, the output dist attr should be same. @@ -134,10 +141,20 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, // and the unshard tensors should keep unshard status. 
std::vector dst_dims_mapping; for (int64_t i = 0; i < param.dims().size(); ++i) { - std::vector shard_status{param_spmd_dims_mapping[i], - grad_spmd_dims_mapping[i], - momentum1_src_dims_mapping[i], - momentum2_src_dims_mapping[i]}; + std::vector shard_status; + if (amsgrad) { + shard_status.assign({param_spmd_dims_mapping[i], + grad_spmd_dims_mapping[i], + momentum1_src_dims_mapping[i], + momentum2_src_dims_mapping[i], + momentum2_max_src_dims_mapping[i]}); + + } else { + shard_status.assign({param_spmd_dims_mapping[i], + grad_spmd_dims_mapping[i], + momentum1_src_dims_mapping[i], + momentum2_src_dims_mapping[i]}); + } int64_t dst_shard_status = -1; for (auto status : shard_status) { if (status == -1) { @@ -177,7 +194,9 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, } moment1_dist_attr.set_dims_mapping(dst_dims_mapping); moment2_dist_attr.set_dims_mapping(dst_dims_mapping); - moment2_max_dist_attr.set_dims_mapping(dst_dims_mapping); + if (amsgrad) { + moment2_max_dist_attr.set_dims_mapping(dst_dims_mapping); + } return {{param_dist_attr, grad_dist_attr, @@ -198,27 +217,28 @@ SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, master_param_dist_attr}}; } -SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, - const DistMetaTensor& grad, - const DistMetaTensor& learning_rate, - const DistMetaTensor& moment1, - const DistMetaTensor& moment2, - const DistMetaTensor& moment2_max, - const DistMetaTensor& beta1_pow, - const DistMetaTensor& beta2_pow, - const DistMetaTensor& master_param, - const DistMetaTensor& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - float lr_ratio, - float coeff, - bool with_decay, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow, - bool amsgrad) { +SpmdInfo AdamwInferSpmdDynamic( + const DistMetaTensor& param, + const DistMetaTensor& grad, + const DistMetaTensor& learning_rate, + const DistMetaTensor& moment1, + const DistMetaTensor& moment2, + const paddle::optional& moment2_max, + const DistMetaTensor& beta1_pow, + const DistMetaTensor& beta2_pow, + const DistMetaTensor& master_param, + const DistMetaTensor& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + float lr_ratio, + float coeff, + bool with_decay, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + bool amsgrad) { return AdamInferSpmdDynamic(param, grad, learning_rate, diff --git a/paddle/phi/infermeta/spmd_rules/optimizer.h b/paddle/phi/infermeta/spmd_rules/optimizer.h index 3fd825a2e14965..3a372e8a7f7d94 100644 --- a/paddle/phi/infermeta/spmd_rules/optimizer.h +++ b/paddle/phi/infermeta/spmd_rules/optimizer.h @@ -23,46 +23,48 @@ limitations under the License. 
*/ namespace phi { namespace distributed { -SpmdInfo AdamInferSpmdDynamic(const DistMetaTensor& param, - const DistMetaTensor& grad, - const DistMetaTensor& learning_rate, - const DistMetaTensor& moment1, - const DistMetaTensor& moment2, - const DistMetaTensor& moment2_max, - const DistMetaTensor& beta1_pow, - const DistMetaTensor& beta2_pow, - const DistMetaTensor& master_param, - const DistMetaTensor& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow, - bool amsgrad); +SpmdInfo AdamInferSpmdDynamic( + const DistMetaTensor& param, + const DistMetaTensor& grad, + const DistMetaTensor& learning_rate, + const DistMetaTensor& moment1, + const DistMetaTensor& moment2, + const paddle::optional& moment2_max, + const DistMetaTensor& beta1_pow, + const DistMetaTensor& beta2_pow, + const DistMetaTensor& master_param, + const DistMetaTensor& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + bool amsgrad); -SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param, - const DistMetaTensor& grad, - const DistMetaTensor& learning_rate, - const DistMetaTensor& moment1, - const DistMetaTensor& moment2, - const DistMetaTensor& moment2_max, - const DistMetaTensor& beta1_pow, - const DistMetaTensor& beta2_pow, - const DistMetaTensor& master_param, - const DistMetaTensor& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - float lr_ratio, - float coeff, - bool with_decay, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow, - bool amsgrad); +SpmdInfo AdamwInferSpmdDynamic( + const DistMetaTensor& param, + const DistMetaTensor& grad, + const DistMetaTensor& learning_rate, + const DistMetaTensor& moment1, + const DistMetaTensor& moment2, + const paddle::optional& moment2_max, + const DistMetaTensor& beta1_pow, + const DistMetaTensor& beta2_pow, + const DistMetaTensor& master_param, + const DistMetaTensor& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + float lr_ratio, + float coeff, + bool with_decay, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + bool amsgrad); SpmdInfo SgdInferSpmd(const DistMetaTensor& param, const DistMetaTensor& learning_rate, diff --git a/paddle/phi/kernels/adam_kernel.h b/paddle/phi/kernels/adam_kernel.h index a7d1033e00f854..dd6ee99794e605 100644 --- a/paddle/phi/kernels/adam_kernel.h +++ b/paddle/phi/kernels/adam_kernel.h @@ -26,7 +26,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -55,7 +55,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max, + const paddle::optional>& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, diff --git a/paddle/phi/kernels/adamw_kernel.h b/paddle/phi/kernels/adamw_kernel.h index ea34b9ca289855..3393c9a7027d41 100644 --- 
a/paddle/phi/kernels/adamw_kernel.h +++ b/paddle/phi/kernels/adamw_kernel.h @@ -26,7 +26,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc index 0d30dc28a8220f..7aab5d6c8bab0b 100644 --- a/paddle/phi/kernels/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/cpu/adam_kernel.cc @@ -35,7 +35,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -75,7 +75,13 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); + if (amsgrad) { + phi::Copy(dev_ctx, + moment2_max.get(), + dev_ctx.GetPlace(), + false, + moment2_max_out); + } if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -116,7 +122,8 @@ void AdamDenseKernel(const Context& dev_ctx, T* param_out_ptr = dev_ctx.template Alloc(param_out); T* mom1_out_ptr = dev_ctx.template Alloc(moment1_out); T* mom2_out_ptr = dev_ctx.template Alloc(moment2_out); - T* mom2_max_out_ptr = dev_ctx.template Alloc(moment2_max_out); + T* mom2_max_out_ptr = + amsgrad ? dev_ctx.template Alloc(moment2_max_out) : nullptr; T learning_rate_ = learning_rate.data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); @@ -128,7 +135,7 @@ void AdamDenseKernel(const Context& dev_ctx, const T* param_ptr = param.data(); const T* mom1_ptr = moment1.data(); const T* mom2_ptr = moment2.data(); - const T* mom2_max_ptr = moment2_max.data(); + const T* mom2_max_ptr = amsgrad ? moment2_max.get().data() : nullptr; const T* grad_ptr = grad.data(); auto adam = @@ -142,6 +149,9 @@ void AdamDenseKernel(const Context& dev_ctx, #endif for (int64_t i = 0; i < numel / chunk_size; ++i) { const int64_t offset = i * chunk_size; + const T* mom2_max_in_data = amsgrad ? mom2_max_ptr + offset : nullptr; + T* mom2_max_out_data = amsgrad ? mom2_max_out_ptr + offset : nullptr; + adam(beta1_, beta2_, -learning_rate_, @@ -150,11 +160,11 @@ void AdamDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, - mom2_max_ptr + offset, + mom2_max_in_data, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - mom2_max_out_ptr + offset, + mom2_max_out_data, param_out_ptr + offset, amsgrad); } @@ -162,6 +172,9 @@ void AdamDenseKernel(const Context& dev_ctx, if (numel % chunk_size != 0) { const int64_t offset = (numel / chunk_size) * chunk_size; const int64_t tail_numel = numel % chunk_size; + const T* mom2_max_in_data = amsgrad ? mom2_max_ptr + offset : nullptr; + T* mom2_max_out_data = amsgrad ? 
mom2_max_out_ptr + offset : nullptr; + adam(beta1_, beta2_, -learning_rate_, @@ -170,11 +183,11 @@ void AdamDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, - mom2_max_ptr + offset, + mom2_max_in_data, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - mom2_max_out_ptr + offset, + mom2_max_out_data, param_out_ptr + offset, amsgrad); } @@ -188,7 +201,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max, + const paddle::optional>& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -260,6 +273,11 @@ void MergedAdamKernel( T epsilon_ = epsilon.to(); for (size_t idx = 0; idx < param_num; idx++) { + const T* mom2_max_in_data = + amsgrad ? moment2_max.get()[idx]->data() : nullptr; + T* mom2_max_out_data = + amsgrad ? dev_ctx.template Alloc(moment2_max_out[idx]) : nullptr; + phi::funcs::AdamFunctor functor( beta1_, beta2_, @@ -270,8 +288,8 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), - moment2_max[idx]->data(), - dev_ctx.template Alloc(moment2_max_out[idx]), + mom2_max_in_data, + mom2_max_out_data, learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), diff --git a/paddle/phi/kernels/cpu/adamw_kernel.cc b/paddle/phi/kernels/cpu/adamw_kernel.cc index 97a5d44cfab4f7..ede1986189473a 100644 --- a/paddle/phi/kernels/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/cpu/adamw_kernel.cc @@ -35,7 +35,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -136,7 +136,8 @@ void AdamwDenseKernel(const Context& dev_ctx, T* param_out_ptr = dev_ctx.template Alloc(param_out); T* mom1_out_ptr = dev_ctx.template Alloc(moment1_out); T* mom2_out_ptr = dev_ctx.template Alloc(moment2_out); - T* mom2_max_out_ptr = dev_ctx.template Alloc(moment2_max_out); + T* mom2_max_out_ptr = + amsgrad ? dev_ctx.template Alloc(moment2_max_out) : nullptr; T old_lr = learning_rate.data()[0]; T learning_rate_ = learning_rate.data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); @@ -147,7 +148,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const T* param_ptr = param.data(); const T* mom1_ptr = moment1.data(); const T* mom2_ptr = moment2.data(); - const T* mom2_max_ptr = moment2_max.data(); + const T* mom2_max_ptr = amsgrad ? moment2_max.get().data() : nullptr; const T* grad_ptr = grad.data(); auto adamw = @@ -161,6 +162,9 @@ void AdamwDenseKernel(const Context& dev_ctx, #endif for (int64_t i = 0; i < numel / chunk_size; ++i) { const int64_t offset = i * chunk_size; + const T* mom2_max_in_data = amsgrad ? mom2_max_ptr + offset : nullptr; + T* mom2_max_out_data = amsgrad ? 
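The CPU kernels above run the JIT-generated update over the flat parameter in fixed-size chunks and then once more over the remainder; the control flow, in Python terms (the `chunk_size` value is illustrative):

numel, chunk_size = 10_000, 4096
updated = 0
for i in range(numel // chunk_size):
    offset = i * chunk_size
    updated += chunk_size            # adam(...) on [offset, offset + chunk_size)
if numel % chunk_size:
    offset = (numel // chunk_size) * chunk_size
    updated += numel - offset        # adam(...) on the tail [offset, numel)
assert updated == numel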
mom2_max_out_ptr + offset : nullptr; + adamw(beta1_, beta2_, -learning_rate_, @@ -172,11 +176,11 @@ void AdamwDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, - mom2_max_ptr + offset, + mom2_max_in_data, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - mom2_max_out_ptr + offset, + mom2_max_out_data, param_out_ptr + offset, amsgrad); } @@ -184,6 +188,9 @@ void AdamwDenseKernel(const Context& dev_ctx, if (numel % chunk_size != 0) { const int64_t offset = (numel / chunk_size) * chunk_size; const int64_t tail_numel = numel % chunk_size; + const T* mom2_max_in_data = amsgrad ? mom2_max_ptr + offset : nullptr; + T* mom2_max_out_data = amsgrad ? mom2_max_out_ptr + offset : nullptr; + adamw(beta1_, beta2_, -learning_rate_, @@ -195,11 +202,11 @@ void AdamwDenseKernel(const Context& dev_ctx, grad_ptr + offset, mom1_ptr + offset, mom2_ptr + offset, - mom2_max_ptr + offset, + mom2_max_in_data, param_ptr + offset, mom1_out_ptr + offset, mom2_out_ptr + offset, - mom2_max_out_ptr + offset, + mom2_max_out_data, param_out_ptr + offset, amsgrad); } diff --git a/paddle/phi/kernels/cpu/fused_adam_kernel.cc b/paddle/phi/kernels/cpu/fused_adam_kernel.cc index 66712300080f6d..865188b37669ab 100644 --- a/paddle/phi/kernels/cpu/fused_adam_kernel.cc +++ b/paddle/phi/kernels/cpu/fused_adam_kernel.cc @@ -36,7 +36,7 @@ void FusedAdamKernel( const DenseTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, - const std::vector& moments2_max, + const paddle::optional>& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -82,15 +82,17 @@ void FusedAdamKernel( "is %d, the size of Input(params) is %d.", moments2.size(), params_num)); - PADDLE_ENFORCE_EQ( - params_num, - moments2_max.size(), - errors::InvalidArgument( - "The size of Input(moments2 max) must be equal to " - "Input(params), but got the size of Input(moments2 max) " - "is %d, the size of Input(params) is %d.", - moments2_max.size(), - params_num)); + if (amsgrad) { + PADDLE_ENFORCE_EQ( + params_num, + moments2_max.get().size(), + errors::InvalidArgument( + "The size of Input(moments2 max) must be equal to " + "Input(params), but got the size of Input(moments2 max) " + "is %d, the size of Input(params) is %d.", + moments2_max.get().size(), + params_num)); + } PADDLE_ENFORCE_EQ(params_num, beta1_pows.size(), errors::InvalidArgument( @@ -110,6 +112,8 @@ void FusedAdamKernel( for (size_t idx = 0; idx < params_num; idx++) { auto master_params_tmp = TensorPtrToOptionalTensor(master_params, idx); + auto moments2_max_tmp = TensorPtrToOptionalTensor(moments2_max, idx); + if (!use_adamw) { AdamDenseKernel( dev_ctx, @@ -118,7 +122,7 @@ void FusedAdamKernel( learning_rate, *moments1[idx], *moments2[idx], - *moments2_max[idx], + moments2_max_tmp, *beta1_pows[idx], *beta2_pows[idx], master_params_tmp, @@ -134,7 +138,7 @@ void FusedAdamKernel( params_out[idx], moments1_out[idx], moments2_out[idx], - moments2_max_out[idx], + amsgrad ? moments2_max_out[idx] : nullptr, beta1_pows_out[idx], beta2_pows_out[idx], master_params_out.empty() ? nullptr : master_params_out[idx]); @@ -146,7 +150,7 @@ void FusedAdamKernel( learning_rate, *moments1[idx], *moments2[idx], - *moments2_max[idx], + moments2_max_tmp, *beta1_pows[idx], *beta2_pows[idx], master_params_tmp, @@ -165,7 +169,7 @@ void FusedAdamKernel( params_out[idx], moments1_out[idx], moments2_out[idx], - moments2_max_out[idx], + amsgrad ? 
moments2_max_out[idx] : nullptr, beta1_pows_out[idx], beta2_pows_out[idx], master_params_out.empty() ? nullptr : master_params_out[idx]); diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index c3d2f6619baa4a..5d674f36fe836b 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -221,7 +221,7 @@ class AdamFunctor { T g = grad_[i]; T mom1 = moment1_[i]; T mom2 = moment2_[i]; - T mom2_max = moment2_max_[i]; + T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -233,19 +233,19 @@ class AdamFunctor { mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - T mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + T mom2_max_ = std::max(mom2, moment2_max_[i]); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); + + // Write back to global memory + moment2_max_out_[i] = mom2_max_; } else { - mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); } // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; - moment2_max_out_[i] = mom2_max_; param_out_[i] = p; } }; @@ -312,8 +312,6 @@ class AdamFunctor { moment1_, static_cast(numel)}; Eigen::Map> mom2{ moment2_, static_cast(numel)}; - Eigen::Map> mom2_max{ - moment2_max_, static_cast(numel)}; Eigen::Map> param{ param_, static_cast(numel)}; @@ -323,8 +321,6 @@ class AdamFunctor { moment1_out_, static_cast(numel)}; Eigen::Map> moment2_out{ moment2_out_, static_cast(numel)}; - Eigen::Map> moment2_max_out{ - moment2_max_out_, static_cast(numel)}; T lr = *lr_; T beta1_pow = *beta1_pow_; @@ -337,11 +333,15 @@ class AdamFunctor { moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g; if (amsgrad_) { + Eigen::Map> mom2_max{ + moment2_max_, static_cast(numel)}; + Eigen::Map> moment2_max_out{ + moment2_max_out_, static_cast(numel)}; + moment2_max_out = moment2_out.cwiseMax(mom2_max); param_out = param - lr * (moment1_out / (moment2_max_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow))); } else { - moment2_max_out = mom2_max; param_out = param - lr * (moment1_out / (moment2_out.sqrt() + epsilon_ * sqrt(1 - beta2_pow))); } @@ -429,7 +429,7 @@ class SparseAdamFunctor { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; - MT mom2_max = moment2_max_[i]; + MT lr = *lr_; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; @@ -442,13 +442,14 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; - MT mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max_ = std::max(mom2, moment2_max_[i]); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); + + // Write back to global memory + moment2_max_out_[i] = mom2_max_; } else { - mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } @@ -456,7 +457,6 @@ class SparseAdamFunctor { // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; - moment2_max_out_[i] = mom2_max_; param_out_[i] = static_cast(p); if (master_param_out_) { master_param_out_[i] = p; @@ -547,7 +547,7 @@ class SparseAdamFunctor { // The following code is the same as dense T mom1 = moment1_[i]; T mom2 = moment2_[i]; - T mom2_max = moment2_max_[i]; + T lr = *lr_; T beta1_pow = *beta1_pow_; T beta2_pow = *beta2_pow_; @@ -559,19 +559,19 @@ class SparseAdamFunctor { mom1 = beta1_ * mom1 
+ (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - T mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + T mom2_max_ = std::max(mom2, moment2_max_[i]); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(1 - beta2_pow))); + + // Write back to global memory + moment2_max_out_[i] = mom2_max_; } else { - mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); } // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; - moment2_max_out_[i] = mom2_max_; param_out_[i] = p; } @@ -594,26 +594,25 @@ class SparseAdamFunctor { for (int64_t k = 0; k != row_numel_; ++k) { T mom1 = moment1_[i * row_numel_ + k]; T mom2 = moment2_[i * row_numel_ + k]; - T mom2_max = moment2_max_[i * row_numel_ + k]; - T p = param_[i * row_numel_ + k]; mom1 = beta1_ * mom1; mom2 = beta2_ * mom2; - T mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + T mom2_max = moment2_max_[i * row_numel_ + k]; + T mom2_max_ = std::max(mom2, mom2_max); + p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); + + // Write back to global memory + moment2_max_out_[i * row_numel_ + k] = mom2_max_; } else { - mom2_max_ = mom2_max; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); } - p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_)); - // Write back to global memory moment1_out_[i * row_numel_ + k] = mom1; moment2_out_[i * row_numel_ + k] = mom2; - moment2_max_out_[i * row_numel_ + k] = mom2_max_; param_out_[i * row_numel_ + k] = p; } } @@ -737,7 +736,7 @@ class SparseAdamWFunctor { // The following code is the same as dense MT mom1 = moment1_[i]; MT mom2 = moment2_[i]; - MT mom2_max = moment2_max_[i]; + MT lr = *lr_ * lr_ratio_; MT lr_orig = lr; MT beta1_pow = *beta1_pow_; @@ -753,13 +752,14 @@ class SparseAdamWFunctor { p -= lr_orig * coeff_ * p; - MT mom2_max_; if (amsgrad_) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max_ = std::max(mom2, moment2_max_[i]); p -= lr * (mom1 / (sqrt(mom2_max_) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); + + // Write back to global memory + moment2_max_out_[i] = mom2_max_; } else { - mom2_max_ = mom2_max; p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); } @@ -767,7 +767,6 @@ class SparseAdamWFunctor { // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; - moment2_max_out_[i] = mom2_max_; param_out_[i] = static_cast(p); if (master_param_out_) { master_param_out_[i] = p; diff --git a/paddle/phi/kernels/funcs/jit/refer/refer.h b/paddle/phi/kernels/funcs/jit/refer/refer.h index 3f194f9b8782b7..86bd1dae898a72 100644 --- a/paddle/phi/kernels/funcs/jit/refer/refer.h +++ b/paddle/phi/kernels/funcs/jit/refer/refer.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include @@ -543,8 +542,6 @@ void Adam(T beta1, param_out_ptr[i] = param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_max_) + eps)); } else { - mom2_max_out_ptr[i] = mom2_max_ptr[i]; - T mom2_ = mom2_out_ptr[i]; param_out_ptr[i] = param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_) + eps)); @@ -584,8 +581,6 @@ void AdamW(T beta1, param_out_ptr[i] = param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_max_) + eps)); } else { - mom2_max_out_ptr[i] = mom2_max_ptr[i]; - T mom2_ = mom2_out_ptr[i]; param_out_ptr[i] = param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_) + eps)); diff --git a/paddle/phi/kernels/funcs/multi_tensor_apply.h b/paddle/phi/kernels/funcs/multi_tensor_apply.h index 6fe90864881381..bf64752d9bdfbb 100644 --- a/paddle/phi/kernels/funcs/multi_tensor_apply.h +++ 
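For reference, the update these functors implement is the AMSGrad variant of Adam; as in the code above, the stored second moment stays v_t and only the separate max buffer takes the element-wise maximum. When amsgrad is off, v_t takes the place of \hat v_t and no maximum is stored.

\begin{aligned}
m_t      &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t      &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat v_t &= \max\left(\hat v_{t-1},\, v_t\right) \\
\theta_t &= \theta_{t-1} - \alpha\,\frac{m_t/(1-\beta_1^t)}{\sqrt{\hat v_t/(1-\beta_2^t)} + \epsilon}
\end{aligned}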
b/paddle/phi/kernels/funcs/multi_tensor_apply.h @@ -76,7 +76,7 @@ void LaunchMultiTensorApplyKernel( errors::InvalidArgument( "input_vector.size() != InputNum - 1, the input vector's size is " "unequal to InputNum - 1, please cheack grads, params, momemts1, " - "moments2, moments2_max, and, master_params.")); + "moments2, moments2_max(if use amsgrad), and, master_params.")); size_t length = input_vector[0].size(); PADDLE_ENFORCE_GT( length, diff --git a/paddle/phi/kernels/fused_adam_kernel.h b/paddle/phi/kernels/fused_adam_kernel.h index 16944abdb8b1a1..e908962251f065 100644 --- a/paddle/phi/kernels/fused_adam_kernel.h +++ b/paddle/phi/kernels/fused_adam_kernel.h @@ -27,7 +27,7 @@ void FusedAdamKernel( const DenseTensor &learning_rate, const std::vector &moments1, const std::vector &moments2, - const std::vector &moments2_max, + const paddle::optional> &moments2_max, const std::vector &beta1_pows, const std::vector &beta2_pows, const paddle::optional> &master_params, diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index d04d3ef1bd228b..e6528f92f530c3 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -61,21 +61,19 @@ __global__ void AdamKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2_max; - denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -83,7 +81,6 @@ __global__ void AdamKernelREG(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; - moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -122,21 +119,19 @@ __global__ void AdamKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2_max; - denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -144,7 +139,6 @@ __global__ void AdamKernelMEM(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; - moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -170,7 +164,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -214,7 +208,13 @@ void AdamDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, 
dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); + if (amsgrad) { + phi::Copy(dev_ctx, + moment2_max.get(), + dev_ctx.GetPlace(), + false, + moment2_max_out); + } if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -248,6 +248,11 @@ void AdamDenseKernel(const Context& dev_ctx, multi_precision ? dev_ctx.template Alloc(master_param_outs) : nullptr; + const MPDType* moment2_max_in_data = + amsgrad ? moment2_max.get().data() : nullptr; + MPDType* moment2_max_out_data = + amsgrad ? dev_ctx.template Alloc(moment2_max_out) : nullptr; + // update param and moment int threads = 512; int blocks = (param.numel() + threads - 1) / threads; @@ -266,8 +271,8 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -287,8 +292,8 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -318,8 +323,8 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -339,8 +344,8 @@ void AdamDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -371,7 +376,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max, + const paddle::optional>& moment2_max, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -403,6 +408,12 @@ void MergedAdamKernel( multi_precision ? dev_ctx.template Alloc(master_param_out[idx]) : nullptr; + const MPDType* moment2_max_in_data = + amsgrad ? moment2_max.get()[idx]->data() : nullptr; + MPDType* moment2_max_out_data = + amsgrad ? 
dev_ctx.template Alloc(moment2_max_out[idx]) + : nullptr; + // update param and moment int threads = 512; int blocks = (param[idx]->numel() + threads - 1) / threads; @@ -423,8 +434,8 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), - moment2_max[idx]->data(), - dev_ctx.template Alloc(moment2_max_out[idx]), + moment2_max_in_data, + moment2_max_out_data, learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), @@ -444,8 +455,8 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), - moment2_max[idx]->data(), - dev_ctx.template Alloc(moment2_max_out[idx]), + moment2_max_in_data, + moment2_max_out_data, learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), @@ -475,8 +486,8 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), - moment2_max[idx]->data(), - dev_ctx.template Alloc(moment2_max_out[idx]), + moment2_max_in_data, + moment2_max_out_data, learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), @@ -496,8 +507,8 @@ void MergedAdamKernel( dev_ctx.template Alloc(moment1_out[idx]), moment2[idx]->data(), dev_ctx.template Alloc(moment2_out[idx]), - moment2_max[idx]->data(), - dev_ctx.template Alloc(moment2_max_out[idx]), + moment2_max_in_data, + moment2_max_out_data, learning_rate[idx]->data(), grad[idx]->data(), param[idx]->data(), diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index 141b23216097cd..df2715c269fdc0 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -64,23 +64,21 @@ __global__ void AdamWKernelREG(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2_max; - denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -88,7 +86,6 @@ __global__ void AdamWKernelREG(MT beta1, moment1_out[id] = mom1; moment2_out[id] = mom2; - moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -129,23 +126,21 @@ __global__ void AdamWKernelMEM(MT beta1, MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); - MT mom2_max = static_cast(moment2_max[id]); p *= (static_cast(1.0) - lr * coeff); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(moment2_max[id]); + MT mom2_max_ = std::max(mom2, mom2_max); + moment2_max_out[id] = mom2_max_; denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2_max; - denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -153,7 +148,6 @@ __global__ void AdamWKernelMEM(MT beta1, moment1_out[id] = mom1; 
moment2_out[id] = mom2; - moment2_max_out[id] = mom2_max_; param_out[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -179,7 +173,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -232,7 +226,13 @@ void AdamwDenseKernel(const Context& dev_ctx, phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); + if (amsgrad) { + phi::Copy(dev_ctx, + moment2_max.get(), + dev_ctx.GetPlace(), + false, + moment2_max_out); + } if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -271,6 +271,11 @@ void AdamwDenseKernel(const Context& dev_ctx, multi_precision ? dev_ctx.template Alloc(master_param_outs) : nullptr; + const MPDType* moment2_max_in_data = + amsgrad ? moment2_max.get().data() : nullptr; + MPDType* moment2_max_out_data = + amsgrad ? dev_ctx.template Alloc(moment2_max_out) : nullptr; + // update param and moment int threads = 512; int blocks = (param.numel() + threads - 1) / threads; @@ -291,8 +296,8 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -314,8 +319,8 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -347,8 +352,8 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), @@ -370,8 +375,8 @@ void AdamwDenseKernel(const Context& dev_ctx, dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad.data(), param.data(), diff --git a/paddle/phi/kernels/gpu/fused_adam_kernel.cu b/paddle/phi/kernels/gpu/fused_adam_kernel.cu index 9c5ecc7eec17d4..4fd72aee0ddd4f 100644 --- a/paddle/phi/kernels/gpu/fused_adam_kernel.cu +++ b/paddle/phi/kernels/gpu/fused_adam_kernel.cu @@ -106,10 +106,13 @@ struct FusedAdamFunctor { mom1_ptr = static_cast(t_info.tensor_addrs[1][tensor_id]) + offset; mom2_ptr = static_cast(t_info.tensor_addrs[2][tensor_id]) + offset; mom2_max_ptr = - static_cast(t_info.tensor_addrs[3][tensor_id]) + offset; + AMSGrad ? static_cast(t_info.tensor_addrs[3][tensor_id]) + offset + : nullptr; mp_ptr = IsMultiPrecision - ? static_cast(t_info.tensor_addrs[4][tensor_id]) + offset + ? static_cast( + t_info.tensor_addrs[3 + (AMSGrad ? 
1 : 0)][tensor_id]) + + offset : nullptr; n -= offset; @@ -137,7 +140,9 @@ struct FusedAdamFunctor { phi::Load(g_ptr + idx, &g_vec); phi::Load(mom1_ptr + idx, &mom1_vec); phi::Load(mom2_ptr + idx, &mom2_vec); - phi::Load(mom2_max_ptr + idx, &mom2_max_vec); + if (AMSGrad) { + phi::Load(mom2_max_ptr + idx, &mom2_max_vec); + } } else { int size = n - idx; for (int j = 0; j < size; j++) { @@ -149,7 +154,9 @@ struct FusedAdamFunctor { g_vec[j] = g_ptr[idx + j]; mom1_vec[j] = static_cast(mom1_ptr[idx + j]); mom2_vec[j] = static_cast(mom2_ptr[idx + j]); - mom2_max_vec[j] = static_cast(mom2_max_ptr[idx + j]); + if (AMSGrad) { + mom2_max_vec[j] = static_cast(mom2_max_ptr[idx + j]); + } } #pragma unroll for (int j = size; j < VecSize; j++) { @@ -158,7 +165,9 @@ struct FusedAdamFunctor { mp_vec[j] = MT(0); mom1_vec[j] = MT(0); mom2_vec[j] = MT(0); - mom2_max_vec[j] = MT(0); + if (AMSGrad) { + mom2_max_vec[j] = MT(0); + } } } @@ -167,14 +176,14 @@ struct FusedAdamFunctor { MT p = IsMultiPrecision ? mp_vec[j] : static_cast(p_vec[j]); UpdateMoments(&mom1_vec[j], &mom2_vec[j], - &mom2_max_vec[j], + AMSGrad ? &mom2_max_vec[j] : nullptr, static_cast(g_vec[j]), beta1, beta2); mp_vec[j] = UpdateParameter(p, mom1_vec[j], mom2_vec[j], - mom2_max_vec[j], + AMSGrad ? mom2_max_vec[j] : MT(0), beta1_pow, beta2_pow, lr, @@ -185,7 +194,9 @@ struct FusedAdamFunctor { if (idx <= n - VecSize) { phi::Store(mom1_vec, mom1_ptr + idx); phi::Store(mom2_vec, mom2_ptr + idx); - phi::Store(mom2_max_vec, mom2_max_ptr + idx); + if (AMSGrad) { + phi::Store(mom2_max_vec, mom2_max_ptr + idx); + } if (IsMultiPrecision) { phi::Store(mp_vec, mp_ptr + idx); } @@ -201,7 +212,9 @@ struct FusedAdamFunctor { p_ptr[idx + j] = static_cast(mp_vec[j]); mom1_ptr[idx + j] = mom1_vec[j]; mom2_ptr[idx + j] = mom2_vec[j]; - mom2_max_ptr[idx + j] = mom2_max_vec[j]; + if (AMSGrad) { + mom2_max_ptr[idx + j] = mom2_max_vec[j]; + } } } } @@ -217,7 +230,6 @@ struct FusedAdamFunctor { MT beta2) { MT mom1 = static_cast(mom1_ptr[0]); MT mom2 = static_cast(mom2_ptr[0]); - MT mom2_max = static_cast(mom2_max_ptr[0]); mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; @@ -226,9 +238,8 @@ struct FusedAdamFunctor { mom2_ptr[0] = mom2; if (AMSGrad) { + MT mom2_max = static_cast(mom2_max_ptr[0]); mom2_max_ptr[0] = std::max(mom2, mom2_max); - } else { - mom2_max_ptr[0] = mom2_max; } } @@ -299,7 +310,7 @@ void FusedAdamKernel( const DenseTensor& learning_rate, const std::vector& moments1, const std::vector& moments2, - const std::vector& moments2_max, + const paddle::optional>& moments2_max, const std::vector& beta1_pows, const std::vector& beta2_pows, const paddle::optional>& master_params, @@ -350,7 +361,9 @@ void FusedAdamKernel( CopyTensorIfDifferent(dev_ctx, params, params_out); CopyTensorIfDifferent(dev_ctx, moments1, moments1_out); CopyTensorIfDifferent(dev_ctx, moments2, moments2_out); - CopyTensorIfDifferent(dev_ctx, moments2_max, moments2_max_out); + if (amsgrad) { + CopyTensorIfDifferent(dev_ctx, moments2_max.get(), moments2_max_out); + } CopyTensorIfDifferent(dev_ctx, beta1_pows, beta1_pows_out, true); CopyTensorIfDifferent(dev_ctx, beta2_pows, beta2_pows_out, true); if (master_params) { @@ -386,7 +399,9 @@ void FusedAdamKernel( input_vector.push_back(params_out); input_vector.push_back(moments1_out); input_vector.push_back(moments2_out); - input_vector.push_back(moments2_max_out); + if (amsgrad) { + input_vector.push_back(moments2_max_out); + } if (multi_precision) { 
input_vector.push_back(master_params_out); } @@ -397,7 +412,8 @@ void FusedAdamKernel( #define PD_LAUNCH_MULTI_TENSOR_APPLY_ADAM_KERNEL_BASE( \ __multi_precision, __is_cpu_betapow, __use_adamw, __amsgrad, __vec_size) \ do { \ - constexpr int kInputNum = __multi_precision ? 6 : 5; \ + constexpr int kInputNum = \ + (__multi_precision ? 5 : 4) + (__amsgrad ? 1 : 0); \ constexpr int kMaxTensorSize = __multi_precision ? 48 : 60; \ constexpr int kMaxBlockSize = __multi_precision ? 320 : 320; \ constexpr int kBlockSize = 512; \ @@ -515,7 +531,9 @@ void FusedAdamKernel( int vec_size = GetVecSizeFromTensors(params_out); vec_size = GetVecSizeFromTensors(moments1_out, vec_size); vec_size = GetVecSizeFromTensors(moments2_out, vec_size); - vec_size = GetVecSizeFromTensors(moments2_max_out, vec_size); + if (amsgrad) { + vec_size = GetVecSizeFromTensors(moments2_max_out, vec_size); + } if (master_params) { vec_size = GetVecSizeFromTensors(master_params_out, vec_size); } diff --git a/paddle/phi/kernels/selected_rows/adam_kernel.h b/paddle/phi/kernels/selected_rows/adam_kernel.h index 2ac909903a4089..3d7167fd69b4e8 100644 --- a/paddle/phi/kernels/selected_rows/adam_kernel.h +++ b/paddle/phi/kernels/selected_rows/adam_kernel.h @@ -29,7 +29,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, diff --git a/paddle/phi/kernels/selected_rows/adamw_kernel.h b/paddle/phi/kernels/selected_rows/adamw_kernel.h index 25321c87b321dd..5ca1dd62369029 100644 --- a/paddle/phi/kernels/selected_rows/adamw_kernel.h +++ b/paddle/phi/kernels/selected_rows/adamw_kernel.h @@ -29,7 +29,7 @@ void AdamwDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, diff --git a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc index ab98fd298e7475..5a6b2d01148a58 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adam_kernel.cc @@ -37,7 +37,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param UNUSED, @@ -77,7 +77,13 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); + if (amsgrad) { + phi::Copy(dev_ctx, + moment2_max.get(), + dev_ctx.GetPlace(), + false, + moment2_max_out); + } if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, beta2_pow_out); @@ -151,8 +157,8 @@ void AdamDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template 
Alloc(moment2_max_out), + amsgrad ? moment2_max.get().data() : nullptr, + amsgrad ? dev_ctx.template Alloc(moment2_max_out) : nullptr, learning_rate.data(), grad_data, param.data(), diff --git a/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc index 9b7197a3e95e9d..3b62d9520424d7 100644 --- a/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/selected_rows/cpu/adamw_kernel.cc @@ -34,7 +34,7 @@ void AdamwDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, diff --git a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu index 4c90e7711d147f..338d3dacb2138e 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu @@ -74,7 +74,7 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, } else { MT mom1 = mom1_[id]; MT mom2 = mom2_[id]; - MT mom2_max = mom2_max_[id]; + MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 ? static_cast(grad_[row_idx * row_numel + id % row_numel]) @@ -82,16 +82,15 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT moment2_max_; MT denom; if (amsgrad) { - moment2_max_ = std::max(mom2, mom2_max); + MT mom2_max = mom2_max_[id]; + MT moment2_max_ = std::max(mom2, mom2_max); + mom2_max_out_[id] = moment2_max_; denom = (sqrt(moment2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - moment2_max_ = mom2_max; - denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -100,7 +99,6 @@ __global__ void SparseAdamCUDAKernelREG(MT beta1, // Write back to global memory mom1_out_[id] = mom1; mom2_out_[id] = mom2; - mom2_max_out_[id] = moment2_max_; param_out_[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -117,7 +115,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -159,7 +157,13 @@ void AdamDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); + if (amsgrad) { + phi::Copy(dev_ctx, + moment2_max.get(), + dev_ctx.GetPlace(), + false, + moment2_max_out); + } if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -193,6 +197,11 @@ void AdamDenseParamSparseGradKernel( multi_precision ? dev_ctx.template Alloc(master_param_outs) : nullptr; + const MPDType* moment2_max_in_data = + amsgrad ? moment2_max.get().data() : nullptr; + MPDType* moment2_max_out_data = + amsgrad ? 
dev_ctx.template Alloc(moment2_max_out) : nullptr; + if (grad.rows().size() == 0) { VLOG(3) << "grad row size is 0!!"; return; @@ -242,8 +251,8 @@ void AdamDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad_data, param.data(), @@ -274,8 +283,8 @@ void AdamDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad_data, param.data(), diff --git a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu index 73869d1146cf77..01a81c10b3e766 100644 --- a/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -80,7 +80,6 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, } else { MT mom1 = static_cast(mom1_[id]); MT mom2 = static_cast(mom2_[id]); - MT mom2_max = static_cast(mom2_max_[id]); MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 @@ -92,16 +91,15 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - MT mom2_max_; MT denom; if (amsgrad) { - mom2_max_ = std::max(mom2, mom2_max); + MT mom2_max = static_cast(mom2_max_[id]); + MT mom2_max_ = std::max(mom2, mom2_max); + mom2_max_out_[id] = mom2_max_; denom = (sqrt(mom2_max_) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } else { - mom2_max_ = mom2_max; - denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; } @@ -110,7 +108,6 @@ __global__ void SparseAdamWCUDAKernelREG(MT beta1, // Write back to global memory mom1_out_[id] = mom1; mom2_out_[id] = mom2; - mom2_max_out_[id] = mom2_max_; param_out_[id] = static_cast(p); if (master_param_out) { master_param_out[id] = p; @@ -127,7 +124,7 @@ void AdamwDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max, + const paddle::optional& moment2_max, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -176,7 +173,13 @@ void AdamwDenseParamSparseGradKernel( phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); - phi::Copy(dev_ctx, moment2_max, dev_ctx.GetPlace(), false, moment2_max_out); + if (amsgrad) { + phi::Copy(dev_ctx, + moment2_max.get(), + dev_ctx.GetPlace(), + false, + moment2_max_out); + } if (!use_global_beta_pow) { phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); @@ -215,6 +218,11 @@ void AdamwDenseParamSparseGradKernel( multi_precision ? dev_ctx.template Alloc(master_param_outs) : nullptr; + const MPDType* moment2_max_in_data = + amsgrad ? moment2_max.get().data() : nullptr; + MPDType* moment2_max_out_data = + amsgrad ? 
dev_ctx.template Alloc(moment2_max_out) : nullptr; + if (grad.rows().size() == 0) { VLOG(3) << "grad row size is 0!!"; return; @@ -266,8 +274,8 @@ void AdamwDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad_data, param.data(), @@ -300,8 +308,8 @@ void AdamwDenseParamSparseGradKernel( dev_ctx.template Alloc(moment1_out), moment2.data(), dev_ctx.template Alloc(moment2_out), - moment2_max.data(), - dev_ctx.template Alloc(moment2_max_out), + moment2_max_in_data, + moment2_max_out_data, learning_rate.data(), grad_data, param.data(), diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc index 232a67e79ec454..31cd0d18d18f8d 100644 --- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc @@ -34,7 +34,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max UNUSED, + const paddle::optional& moment2_max UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc index 609b4133be079a..1a118507d1e553 100644 --- a/paddle/phi/kernels/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -32,7 +32,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max UNUSED, + const paddle::optional& moment2_max_ UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -264,7 +264,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max UNUSED, + const std::vector& moment2_max_ UNUSED, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -277,7 +277,7 @@ void MergedAdamKernel( std::vector param_out, std::vector moment1_out, std::vector moment2_out, - std::vector moment2_max_out UNUSED, + std::vector moment2_max_out_ UNUSED, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) { diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index efb0c19b11265a..fbc237ec6b749b 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -483,7 +483,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const DenseTensor& moment2_max UNUSED, + const paddle::optional& moment2_max_ UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -502,7 +502,7 @@ void AdamwDenseKernel(const Context& dev_ctx, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, - DenseTensor* moment2_max_out UNUSED, + DenseTensor* moment2_max_out_ UNUSED, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { diff --git a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml index e8b251b0484c11..1036b0eeb51541 100755 --- 
a/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml @@ -162,7 +162,7 @@ kernel : func : fused_adam data_type : params - optional : skip_update, master_params + optional : moments2_max, skip_update, master_params, moments2_max_out inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (moments2_max -> moments2_max_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) - op : fused_gemm_epilogue diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 9d195c3505bca1..8938ec6d3a4695 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -331,7 +331,7 @@ kernel : func : fused_adam data_type : params - optional : skip_update, master_params, master_params_out + optional : moments2_max, skip_update, master_params, moments2_max_out, master_params_out inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (moments2_max -> moments2_max_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) - op : fused_gate_attention diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 6d7329c7f0d146..3522f5c9788959 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -97,7 +97,7 @@ func : adam {dense, dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense, dense}, adam_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense, dense} data_type : param - optional : master_param, skip_update, master_param_out + optional : moment2_max, master_param, skip_update, moment2_max_out, master_param_out inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (moment2_max -> moment2_max_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) traits : pir::SideEffectTrait @@ -122,7 +122,7 @@ kernel : func : adamw data_type : param - optional : master_param, skip_update, master_param_out + optional : moment2_max, master_param, skip_update, moment2_max_out, master_param_out inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (moment2_max -> moment2_max_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) traits : pir::SideEffectTrait @@ -3275,7 +3275,7 @@ kernel : func : merged_adam data_type : param - optional: master_param, master_param_out + optional: moment2_max, master_param, moment2_max_out, master_param_out inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (moment2_max -> moment2_max_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_out) traits : pir::SideEffectTrait diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 6b4236f4b6cd8f..7faec0eb232015 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -117,7 +117,8 @@ class Adam(Optimizer): The default value is False. multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once . 
Default is false. - amsgrad (bool, optional): Whether to use the AMSGrad of this algorithm. Default is false. + amsgrad (bool, optional): Whether to use the AMSGrad variant of this algorithm from the paper + `On the Convergence of Adam and Beyond `_. Default is false. name (str|None, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -258,7 +259,9 @@ def __init__( self._param_dict = self._create_multi_tensor_dict() self._moment1_dict = self._create_multi_tensor_dict() self._moment2_dict = self._create_multi_tensor_dict() - self._moment2_max_dict = self._create_multi_tensor_dict() + self._moment2_max_dict = ( + self._create_multi_tensor_dict() if amsgrad else None + ) self._beta1_pow_acc_dict = self._create_multi_tensor_dict() self._beta2_pow_acc_dict = self._create_multi_tensor_dict() self._master_weight_dict = self._create_multi_tensor_dict() @@ -276,7 +279,8 @@ def _add_moments_pows(self, p): acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) - self._add_accumulator(self._moment2_acc_max_str, p, dtype=acc_dtype) + if self._amsgrad: + self._add_accumulator(self._moment2_acc_max_str, p, dtype=acc_dtype) self._add_accumulator( name=self._beta1_pow_acc_str, param=p, @@ -340,8 +344,12 @@ def _append_optimize_op(self, block, param_and_grad): moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) - moment2_max = self._get_accumulator_master( - self._moment2_acc_max_str, param_and_grad[0] + moment2_max = ( + self._get_accumulator_master( + self._moment2_acc_max_str, param_and_grad[0] + ) + if self._amsgrad + else None ) beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] @@ -404,7 +412,6 @@ def _append_optimize_op(self, block, param_and_grad): "LearningRate": [lr], "Moment1": [moment1], "Moment2": [moment2], - "Moment2Max": [moment2_max], "Beta1Pow": [beta1_pow_acc], "Beta2Pow": [beta2_pow_acc], } @@ -419,7 +426,6 @@ def _append_optimize_op(self, block, param_and_grad): "ParamOut": [param_and_grad[0]], "Moment1Out": [moment1], "Moment2Out": [moment2], - "Moment2MaxOut": [moment2_max], "Beta1PowOut": [beta1_pow_acc], "Beta2PowOut": [beta2_pow_acc], } @@ -443,6 +449,10 @@ def _append_optimize_op(self, block, param_and_grad): else: attrs['epsilon'] = self._epsilon + if self._amsgrad: + inputs['Moment2Max'] = [moment2_max] + outputs["Moment2MaxOut"] = [moment2_max] + if find_master: inputs["MasterParam"] = master_weight outputs["MasterParamOut"] = master_weight @@ -550,8 +560,10 @@ def _multi_tensor_init(self, target_block, parameters, param_group_idx): for param in parameters: moment1 = self._get_accumulator_master(self._moment1_acc_str, param) moment2 = self._get_accumulator_master(self._moment2_acc_str, param) - moment2_max = self._get_accumulator_master( - self._moment2_acc_max_str, param + moment2_max = ( + self._get_accumulator_master(self._moment2_acc_max_str, param) + if self._amsgrad + else None ) beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param @@ -570,9 +582,10 @@ def _multi_tensor_init(self, target_block, parameters, param_group_idx): self._moment2_dict['FP32_LODTensor'][param_group_idx].append( moment2 ) - self._moment2_max_dict['FP32_LODTensor'][ - param_group_idx - ].append(moment2_max) + if self._amsgrad: + self._moment2_max_dict['FP32_LODTensor'][ + param_group_idx + 
].append(moment2_max) self._beta1_pow_acc_dict['FP32_LODTensor'][ param_group_idx ].append(beta1_pow_acc) @@ -589,9 +602,10 @@ def _multi_tensor_init(self, target_block, parameters, param_group_idx): self._moment2_dict['FP16_LODTensor'][param_group_idx].append( moment2 ) - self._moment2_max_dict['FP16_LODTensor'][ - param_group_idx - ].append(moment2_max) + if self._amsgrad: + self._moment2_max_dict['FP16_LODTensor'][ + param_group_idx + ].append(moment2_max) self._beta1_pow_acc_dict['FP16_LODTensor'][ param_group_idx ].append(beta1_pow_acc) @@ -787,7 +801,11 @@ def _append_optimize_multi_tensor_op( lr_dict[key], self._moment1_dict[key][param_group_idx], self._moment2_dict[key][param_group_idx], - self._moment2_max_dict[key][param_group_idx], + ( + self._moment2_max_dict[key][param_group_idx] + if self._amsgrad + else None + ), self._beta1_pow_acc_dict[key][param_group_idx], self._beta2_pow_acc_dict[key][param_group_idx], master_weight, @@ -811,7 +829,11 @@ def _append_optimize_multi_tensor_op( lr_dict[key], self._moment1_dict[key][param_group_idx], self._moment2_dict[key][param_group_idx], - self._moment2_max_dict[key][param_group_idx], + ( + self._moment2_max_dict[key][param_group_idx] + if self._amsgrad + else None + ), self._beta1_pow_acc_dict[key][param_group_idx], self._beta2_pow_acc_dict[key][param_group_idx], master_weight, @@ -829,9 +851,6 @@ def _append_optimize_multi_tensor_op( "LearningRate": lr_dict[key], "Moment1": self._moment1_dict[key][param_group_idx], "Moment2": self._moment2_dict[key][param_group_idx], - "Moment2Max": self._moment2_max_dict[key][ - param_group_idx - ], "Beta1Pow": self._beta1_pow_acc_dict[key][ param_group_idx ], @@ -843,9 +862,6 @@ def _append_optimize_multi_tensor_op( "ParamOut": self._param_dict[key][param_group_idx], "Moment1Out": self._moment1_dict[key][param_group_idx], "Moment2Out": self._moment2_dict[key][param_group_idx], - "Moment2MaxOut": self._moment2_max_dict[key][ - param_group_idx - ], "Beta1PowOut": self._beta1_pow_acc_dict[key][ param_group_idx ], @@ -859,6 +875,15 @@ def _append_optimize_multi_tensor_op( "beta2": _beta2, "amsgrad": self._amsgrad, } + + if self._amsgrad: + inputs["Moment2Max"] = self._moment2_max_dict[key][ + param_group_idx + ] + outputs["Moment2MaxOut"] = self._moment2_max_dict[key][ + param_group_idx + ] + if find_master: inputs["MasterParam"] = self._master_weight_dict[key][ param_group_idx @@ -867,6 +892,7 @@ def _append_optimize_multi_tensor_op( key ][param_group_idx] attrs["multi_precision"] = find_master + target_block.append_op( type="merged_adam", inputs=inputs, diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index c5a4f8334c1e5f..a55a1f6204aa53 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -104,7 +104,8 @@ class AdamW(Optimizer): different semantics with the original Adam algorithm and may lead to different result. The default value is False. multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. - amsgrad (bool, optional): Whether to use the AMSGrad of this algorithm. Default is false. + amsgrad (bool, optional): Whether to use the AMSGrad variant of this algorithm from the paper + `On the Convergence of Adam and Beyond `_. Default is false. name (str|None, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. 
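As a quick illustration of the amsgrad flag documented above, a minimal usage sketch against the Python API added in this series (the layer, data, and hyperparameter values here are placeholders, not part of the patch):

    import paddle

    linear = paddle.nn.Linear(10, 10)
    opt = paddle.optimizer.AdamW(
        learning_rate=0.001,
        parameters=linear.parameters(),
        weight_decay=0.01,
        amsgrad=True,  # enables the max-of-second-moment update introduced here
    )
    out = linear(paddle.rand([4, 10]))
    out.mean().backward()
    opt.step()        # moment2_max accumulators are created/used only when amsgrad=True
    opt.clear_grad()
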
@@ -380,21 +381,26 @@ def _add_moments_pows(self, p): self._add_accumulator( self._moment2_acc_str, p, dtype=core.VarDesc.VarType.FP16 ) - self._add_accumulator( - self._moment2_acc_max_str, - p, - dtype=core.VarDesc.VarType.FP16, - ) + if self._amsgrad: + self._add_accumulator( + self._moment2_acc_max_str, + p, + dtype=core.VarDesc.VarType.FP16, + ) else: self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) - self._add_accumulator( - self._moment2_acc_max_str, p, dtype=acc_dtype - ) + if self._amsgrad: + self._add_accumulator( + self._moment2_acc_max_str, p, dtype=acc_dtype + ) else: self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) - self._add_accumulator(self._moment2_acc_max_str, p, dtype=acc_dtype) + if self._amsgrad: + self._add_accumulator( + self._moment2_acc_max_str, p, dtype=acc_dtype + ) self._add_accumulator( name=self._beta1_pow_acc_str, param=p, @@ -467,8 +473,12 @@ def _append_optimize_op(self, block, param_and_grad): moment2 = self._get_accumulator_master( self._moment2_acc_str, param_and_grad[0] ) - moment2_max = self._get_accumulator_master( - self._moment2_acc_max_str, param_and_grad[0] + moment2_max = ( + self._get_accumulator_master( + self._moment2_acc_max_str, param_and_grad[0] + ) + if self._amsgrad + else None ) beta1_pow_acc = self._get_accumulator_master( self._beta1_pow_acc_str, param_and_grad[0] @@ -540,7 +550,6 @@ def _append_optimize_op(self, block, param_and_grad): "LearningRate": [lr], "Moment1": [moment1], "Moment2": [moment2], - "Moment2Max": [moment2_max], "Beta1Pow": [beta1_pow_acc], "Beta2Pow": [beta2_pow_acc], } @@ -555,7 +564,6 @@ def _append_optimize_op(self, block, param_and_grad): "ParamOut": [param_and_grad[0]], "Moment1Out": [moment1], "Moment2Out": [moment2], - "Moment2MaxOut": [moment2_max], "Beta1PowOut": [beta1_pow_acc], "Beta2PowOut": [beta2_pow_acc], } @@ -586,6 +594,10 @@ def _append_optimize_op(self, block, param_and_grad): else: attrs['epsilon'] = self._epsilon + if self._amsgrad: + inputs["Moment2Max"] = [moment2_max] + outputs["Moment2MaxOut"] = [moment2_max] + if find_master: inputs["MasterParam"] = master_weight outputs["MasterParamOut"] = master_weight diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py index 0ef28710d29a0f..8007a59a71764a 100644 --- a/test/legacy_test/test_adam_op.py +++ b/test/legacy_test/test_adam_op.py @@ -66,6 +66,8 @@ def adam_wrapper( class TestAdamOp1(OpTest): def set_amsgrad(self): self.amsgrad = False + # no check `Moment2MaxOut` with amsgrad is False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Op with supplied attributes''' @@ -119,12 +121,13 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamOp1AMSGrad(TestAdamOp1): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOp2(OpTest): @@ -133,6 +136,7 @@ def set_shape(self): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Op with supplied attributes''' @@ -187,7 +191,7 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamOnlyTailOp(TestAdamOp2): @@ -198,11 +202,13 @@ def set_shape(self): class 
TestAdamOp2AMSGrad(TestAdamOp2): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOpMultipleSteps(OpTest): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Operator with supplied attributes''' @@ -262,7 +268,7 @@ def test_check_output(self): } # Verify output for this step - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) # Output of this step becomes input for next step self.inputs['Param'] = param_out @@ -283,6 +289,7 @@ def test_check_output(self): class TestAdamOpMultipleStepsAMSGrad(TestAdamOpMultipleSteps): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None def adam_step(inputs, attributes): @@ -326,7 +333,7 @@ def adam_step(inputs, attributes): moment1_out / (np.sqrt(moment2_max_out) + epsilon) ) else: - moment2_max_out = np.zeros_like(moment2_out) + moment2_max_out = np.empty_like(moment2_out) param_out = param - lr_t * ( moment1_out / (np.sqrt(moment2_out) + epsilon) ) @@ -379,7 +386,7 @@ def adamw_step(inputs, attributes): moment1_out / (np.sqrt(moment2_max_out) + epsilon) ) else: - moment2_max_out = np.zeros_like(moment2_out) + moment2_max_out = np.empty_like(moment2_out) param_out = param - lr_t * ( moment1_out / (np.sqrt(moment2_out) + epsilon) ) @@ -434,7 +441,7 @@ def update_row(row_id, update_value): / (np.sqrt(moment2_max_out[row_id]) + epsilon) ) else: - moment2_max_out[row_id] = np.zeros_like(moment2_out[row_id]) + moment2_max_out[row_id] = np.empty_like(moment2_out[row_id]) param_out[row_id] = param[row_id] - lr_t * ( moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) ) @@ -455,6 +462,7 @@ def update_row(row_id, update_value): class TestSparseAdamOp(unittest.TestCase): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setup(self, scope, place, lazy_mode): beta1 = 0.78 @@ -567,11 +575,13 @@ def test_sparse_adam(self): class TestSparseAdamOpAMSGrad(TestSparseAdamOp): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOpBetaVariable(OpTest): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Op with beta as Variable''' @@ -623,17 +633,19 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamOpBetaVariableAMSGrad(TestAdamOpBetaVariable): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOpBetaEpsilonVariable(OpTest): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Op with beta/epsilon as Variable''' @@ -686,17 +698,19 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamOpBetaEpsilonVariableAMSGrad(TestAdamOpBetaEpsilonVariable): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOpWithGlobalBetaPow(OpTest): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Op with global_beta_pow''' @@ -754,17 +768,19 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamOpWithGlobalBetaPowAMSGrad(TestAdamOpWithGlobalBetaPow): def 
set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOpWithSkipUpdate(OpTest): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test Adam Op with global_beta_pow''' @@ -819,12 +835,13 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamOpWithSkipUpdateAMSGrad(TestAdamOpWithSkipUpdate): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamOpV2(unittest.TestCase): diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py index 49d6517080175f..eeaf360e16566a 100644 --- a/test/legacy_test/test_adamw_op.py +++ b/test/legacy_test/test_adamw_op.py @@ -64,7 +64,7 @@ def adamw_step(inputs, attributes): moment2_max_out = np.maximum(moment2_out, moment2_max) denom = (np.sqrt(moment2_max_out) / np.sqrt(1.0 - beta2_pow)) + epsilon else: - moment2_max_out = np.zeros_like(moment2_out) + moment2_max_out = np.empty_like(moment2_out) denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow)))) @@ -119,6 +119,8 @@ def adamw_wrapper( class TestAdamW(OpTest): def set_amsgrad(self): self.amsgrad = False + # no check `Moment2MaxOut` with amsgrad is False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test AdamW Op with supplied attributes''' @@ -174,12 +176,13 @@ def setUp(self): } def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(no_check_set=self.no_check_set, check_pir=True) class TestAdamWAMSGrad(TestAdamW): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None @unittest.skipIf( @@ -188,6 +191,7 @@ def set_amsgrad(self): class TestAdamW2(OpTest): def set_amsgrad(self): self.amsgrad = False + self.no_check_set = ['Moment2MaxOut'] def setUp(self): '''Test AdamW Op with supplied attributes''' @@ -244,12 +248,17 @@ def setUp(self): } def test_check_output(self): - self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + self.check_output_with_place( + no_check_set=self.no_check_set, + place=core.CUDAPlace(0), + check_pir=True, + ) class TestAdamW2AMSGrad(TestAdamW2): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None class TestAdamWOp(unittest.TestCase): diff --git a/test/legacy_test/test_fused_adam_op.py b/test/legacy_test/test_fused_adam_op.py index 1a3af7cb0d0101..225d7c9ab68909 100644 --- a/test/legacy_test/test_fused_adam_op.py +++ b/test/legacy_test/test_fused_adam_op.py @@ -77,7 +77,7 @@ def fused_adam_step(inputs, attributes, num): * (moments1_out[i] / (np.sqrt(moments2_max_out[i]) + epsilon)) ) else: - _moment2_max = np.zeros_like(_moment2_out) + _moment2_max = np.empty_like(_moment2_out) moments2_max_out.append(_moment2_max) params_out.append( @@ -103,6 +103,8 @@ def fused_adam_step(inputs, attributes, num): class TestFusedAdamOp(OpTest): def set_amsgrad(self): self.amsgrad = False + # no check `Moment2MaxOut` with amsgrad is False + self.no_check_set = ['Moments2MaxOut'] def setUp(self): paddle.enable_static() @@ -204,12 +206,15 @@ def setUp(self): def test_check_output(self): paddle.enable_static() if paddle.is_compiled_with_cuda(): - self.check_output(check_dygraph=False) + self.check_output( + no_check_set=self.no_check_set, check_dygraph=False + ) class TestFusedAdamOpAMSGrad(TestFusedAdamOp): def set_amsgrad(self): self.amsgrad = True + self.no_check_set = None if __name__ == 
"__main__": diff --git a/test/white_list/no_check_set_white_list.py b/test/white_list/no_check_set_white_list.py index 16bf755eecf6ef..c244591490561e 100644 --- a/test/white_list/no_check_set_white_list.py +++ b/test/white_list/no_check_set_white_list.py @@ -40,4 +40,7 @@ 'rrelu', 'layer_norm', 'max_pool2d_v2', + 'adam', # AMSGrad variant no check moment2 max output + 'adamw', # AMSGrad variant no check moment2 max output + 'fused_adam', # AMSGrad variant no check moments2 max output ] From e2d2f9b4a39d67d64cb93177dc36bfd2f049e00b Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 11 Sep 2024 16:07:53 +0800 Subject: [PATCH 13/33] [Update] test_adamw_op.py with new test cast --- test/legacy_test/test_adamw_op.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py index e319210993d7bc..6b94e2cef7cf0e 100644 --- a/test/legacy_test/test_adamw_op.py +++ b/test/legacy_test/test_adamw_op.py @@ -1503,16 +1503,20 @@ def test_weight_decay_int(self): fc1_w = np.array(linear1.weight) fc1_w_mon1 = np.zeros_like(fc1_w) fc1_w_mon2 = np.zeros_like(fc1_w) + fc1_w_mon2_max = np.zeros_like(fc1_w) fc1_b = np.array(linear1.bias) fc1_b_mon1 = np.zeros_like(fc1_b) fc1_b_mon2 = np.zeros_like(fc1_b) + fc1_b_mon2_max = np.zeros_like(fc1_b) fc2_w = np.array(linear2.weight) fc2_w_mon1 = np.zeros_like(fc2_w) fc2_w_mon2 = np.zeros_like(fc2_w) + fc2_w_mon2_max = np.zeros_like(fc2_w) fc2_b = np.array(linear2.bias) fc2_b_mon1 = np.zeros_like(fc2_b) fc2_b_mon2 = np.zeros_like(fc2_b) + fc2_b_mon2_max = np.zeros_like(fc2_b) simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2) learning_rate = 0.001 @@ -1531,14 +1535,18 @@ def test_weight_decay_int(self): apply_decay_param_fun=lambda name: True, weight_decay=weight_decay, lr_ratio=simple_lr_fun, + amsgrad=self.amsgrad, ) - def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): + def get_numpy_output( + param, grad, moment1, moment2, moment2_max, lr_ratio, t + ): np_inputs = { 'Param': param, 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, + 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -1551,11 +1559,12 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): "lr_ratio": lr_ratio, "coeff": float(weight_decay), "with_decay": True, + "amsgrad": self.amsgrad, } - param_out, moment1_out, moment2_out = adamw_step( + param_out, moment1_out, moment2_out, moment2_out_max = adamw_step( np_inputs, np_attrs ) - return param_out, moment1_out, moment2_out + return param_out, moment1_out, moment2_out, moment2_out_max for i in range(5): a = paddle.to_tensor( @@ -1566,35 +1575,39 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t): out = paddle.mean(out) out.backward() - fc1_w, fc1_w_mon1, fc1_w_mon2 = get_numpy_output( + fc1_w, fc1_w_mon1, fc1_w_mon2, fc1_w_mon2_max = get_numpy_output( fc1_w, np.array(linear1.weight.grad), fc1_w_mon1, fc1_w_mon2, + fc1_w_mon2_max, simple_lr_fun(linear1.weight), i + 1, ) - fc1_b, fc1_b_mon1, fc1_b_mon2 = get_numpy_output( + fc1_b, fc1_b_mon1, fc1_b_mon2, fc1_b_mon2_max = get_numpy_output( fc1_b, np.array(linear1.bias.grad), fc1_b_mon1, fc1_b_mon2, + fc1_b_mon2_max, simple_lr_fun(linear1.bias), i + 1, ) - fc2_w, fc2_w_mon1, fc2_w_mon2 = get_numpy_output( + fc2_w, fc2_w_mon1, fc2_w_mon2, fc2_w_mon2_max = get_numpy_output( fc2_w, 
np.array(linear2.weight.grad), fc2_w_mon1, fc2_w_mon2, + fc2_w_mon2_max, simple_lr_fun(linear2.weight), i + 1, ) - fc2_b, fc2_b_mon1, fc2_b_mon2 = get_numpy_output( + fc2_b, fc2_b_mon1, fc2_b_mon2, fc2_b_mon2_max = get_numpy_output( fc2_b, np.array(linear2.bias.grad), fc2_b_mon1, fc2_b_mon2, + fc2_b_mon2_max, simple_lr_fun(linear2.bias), i + 1, ) From 7d7ddb148a1e99cb3c4bd05875af829088f4a94c Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 12 Sep 2024 13:01:44 +0800 Subject: [PATCH 14/33] [Update] adam adamw with amsgrad formula --- python/paddle/optimizer/adam.py | 21 +++++++++++---------- python/paddle/optimizer/adamw.py | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 2865575b7dfdb4..dd652751b03d92 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -66,16 +66,17 @@ class Adam(Optimizer): .. math:: - t & = t + 1 - - moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad - - moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad - - learning\_rate & = learning\_rate * \ - \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t} - - param\_out & = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \begin{aligned} + &\hspace{5mm} t = t + 1 \\ + &\hspace{5mm} moment\_1\_out = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad \\ + &\hspace{5mm} moment\_2\_out = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad \\ + &\hspace{5mm} learning\_rate = learning\_rate * \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t} \\ + &\hspace{5mm}\textbf{if} \: \textit{amsgrad}: \\ + &\hspace{15mm} moment\_2\_max\_out = max(moment\_2\_out, moment\_2\_max) \\ + &\hspace{15mm} param\_out = param - learning\_rate * \frac{moment\_1\_out}{\sqrt{moment\_2\_max\_out} + \epsilon} \\ + &\hspace{5mm}\textbf{else}: \: \\ + &\hspace{15mm} param\_out = param - learning\_rate * \frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon} \\ + \end{aligned} Related paper: `Adam: A Method for Stochastic Optimization `_ diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 61c18e3c318afa..29a9dcbd00cbef 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -54,17 +54,17 @@ class AdamW(Optimizer): .. 
math:: - t & = t + 1 - - moment\_1\_out & = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad - - moment\_2\_out & = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad - - learning\_rate & = learning\_rate * - \frac{\sqrt{1 - {\beta}_2^t}}{1 - {beta}_1^t} - - param\_out & = param - learning\_rate * (\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param) - + \begin{aligned} + &\hspace{5mm} t = t + 1 \\ + &\hspace{5mm} moment\_1\_out = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad \\ + &\hspace{5mm} moment\_2\_out = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad \\ + &\hspace{5mm} learning\_rate = learning\_rate * \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t} \\ + &\hspace{5mm}\textbf{if} \: \textit{amsgrad}: \\ + &\hspace{15mm} moment\_2\_max\_out = max(moment\_2\_out, moment\_2\_max) \\ + &\hspace{15mm} param\_out = param - learning\_rate * (\frac{moment\_1\_out}{\sqrt{moment\_2\_max\_out} + \epsilon} + \lambda * param) \\ + &\hspace{5mm}\textbf{else}: \: \\ + &\hspace{15mm} param\_out = param - learning\_rate * (\frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon} + \lambda * param) \\ + \end{aligned} Args: learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``. From fc6204f19dba6cbe94226c97f628e412547ef9c5 Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 18 Sep 2024 18:08:26 +0800 Subject: [PATCH 15/33] [Update] adam/adamw for test.cc --- paddle/phi/kernels/funcs/jit/test.cc | 58 ++++++++++++++++++++++++---- test/legacy_test/test_adam_op.py | 2 + 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/funcs/jit/test.cc b/paddle/phi/kernels/funcs/jit/test.cc index 6e1b7ee1536b4d..996420d2fdb8ea 100644 --- a/paddle/phi/kernels/funcs/jit/test.cc +++ b/paddle/phi/kernels/funcs/jit/test.cc @@ -39,6 +39,13 @@ void RandomVec(const int n, } } +template +void ZeroVec(const int n, T* a) { + for (int i = 0; i < n; ++i) { + a[i] = static_cast(0); + } +} + template void ExpectEQ(const T* target, const T* refer, size_t n) { if (std::is_floating_point::value) { @@ -708,20 +715,24 @@ void TestKernelAdam() { T learning_rate = lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); T eps = epsilon * sqrt(1 - beta2_pow); + bool amsgrad = false; std::vector param(numel); std::vector grad(numel); std::vector mom1(numel); std::vector mom2(numel); + std::vector mom2_max(numel); std::vector param_out(param.size()); std::vector mom1_out(mom1.size()); std::vector mom2_out(mom2.size()); + std::vector mom2_max_out(mom2_max.size()); RandomVec(numel, param.data(), 0.5f); RandomVec(numel, grad.data(), 0.5f); RandomVec(numel, mom1.data(), 0.5f); RandomVec(numel, mom2.data(), 0.5f); + ZeroVec(numel, mom2_max.data()); auto ref = jit::GetReferFunc(); EXPECT_TRUE(ref != nullptr); @@ -734,10 +745,13 @@ void TestKernelAdam() { grad.data(), mom1.data(), mom2.data(), + mom2_max.data(), param.data(), mom1_out.data(), mom2_out.data(), - param_out.data()); + mom2_max_out.data(), + param_out.data(), + amsgrad); auto verifier = [](const typename KernelTuple::func_type tgt, T beta1, @@ -748,10 +762,13 @@ void TestKernelAdam() { const std::vector& grad, const std::vector& mom1, const std::vector& mom2, + const std::vector& mom2_max, const std::vector& param, const std::vector& ref_mom1_out, const std::vector& ref_mom2_out, - const std::vector& ref_param_out) { + const std::vector& ref_mom2_max_out, + const std::vector& ref_param_out, + bool amsgrad) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(param.size(), static_cast(numel)); EXPECT_EQ(grad.size(), 
static_cast(numel)); @@ -760,6 +777,7 @@ void TestKernelAdam() { std::vector jit_mom1_out(ref_mom1_out.size()); std::vector jit_mom2_out(ref_mom2_out.size()); + std::vector jit_mom2_max_out(ref_mom2_max_out.size()); std::vector jit_param_out(ref_param_out.size()); tgt(beta1, @@ -770,10 +788,13 @@ void TestKernelAdam() { grad.data(), mom1.data(), mom2.data(), + mom2_max.data(), param.data(), jit_mom1_out.data(), jit_mom2_out.data(), - jit_param_out.data()); + jit_mom2_max_out.data(), + jit_param_out.data(), + amsgrad); ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); @@ -789,10 +810,13 @@ void TestKernelAdam() { grad, mom1, mom2, + mom2_max, param, mom1_out, mom2_out, - param_out); + mom2_max_out, + param_out, + amsgrad); } template @@ -812,20 +836,25 @@ void TestKernelAdamW() { T learning_rate = old_lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); T eps = epsilon * sqrt(1 - beta2_pow); + bool amsgrad = false; std::vector param(numel); std::vector grad(numel); std::vector mom1(numel); std::vector mom2(numel); + std::vector mom2_max(numel); std::vector param_out(param.size()); std::vector mom1_out(mom1.size()); std::vector mom2_out(mom2.size()); + std::vector mom2_max_out(mom2_max.size()); RandomVec(numel, param.data(), 0.5f); RandomVec(numel, grad.data(), 0.5f); RandomVec(numel, mom1.data(), 0.5f); RandomVec(numel, mom2.data(), 0.5f); + ZeroVec(numel, mom2_max.data()); + auto ref = jit::GetReferFunc(); EXPECT_TRUE(ref != nullptr); ref(beta1, @@ -839,10 +868,13 @@ void TestKernelAdamW() { grad.data(), mom1.data(), mom2.data(), + mom2_max.data(), param.data(), mom1_out.data(), mom2_out.data(), - param_out.data()); + mom2_max_out.data(), + param_out.data(), + amsgrad); auto verifier = [](const typename KernelTuple::func_type tgt, T beta1, @@ -856,10 +888,13 @@ void TestKernelAdamW() { const std::vector& grad, const std::vector& mom1, const std::vector& mom2, + const std::vector& mom2_max, const std::vector& param, const std::vector& ref_mom1_out, const std::vector& ref_mom2_out, - const std::vector& ref_param_out) { + const std::vector& ref_mom2_max_out, + const std::vector& ref_param_out, + bool amsgrad) { EXPECT_TRUE(tgt != nullptr); EXPECT_EQ(param.size(), static_cast(numel)); EXPECT_EQ(grad.size(), static_cast(numel)); @@ -868,6 +903,7 @@ void TestKernelAdamW() { std::vector jit_mom1_out(ref_mom1_out.size()); std::vector jit_mom2_out(ref_mom2_out.size()); + std::vector jit_mom2_max_out(ref_mom2_max_out.size()); std::vector jit_param_out(ref_param_out.size()); tgt(beta1, @@ -881,10 +917,13 @@ void TestKernelAdamW() { grad.data(), mom1.data(), mom2.data(), + mom2_max.data(), param.data(), jit_mom1_out.data(), jit_mom2_out.data(), - jit_param_out.data()); + jit_mom2_max_out.data(), + jit_param_out.data(), + amsgrad); ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); @@ -904,10 +943,13 @@ void TestKernelAdamW() { grad, mom1, mom2, + mom2_max, param, mom1_out, mom2_out, - param_out); + mom2_max_out, + param_out, + amsgrad); } template diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py index 0107135a73c55d..8090e41178b2fb 100644 --- a/test/legacy_test/test_adam_op.py +++ b/test/legacy_test/test_adam_op.py @@ -1029,6 +1029,8 @@ def test_adam_op_with_sparse_input_and_weight_decay(self): class TestAdamOpV2AMSGrad(TestAdamOpV2): def setUp(self): self.amsgrad = True + + class TestAdamOpV2WeightDecay(unittest.TestCase): def 
test_weight_decay_int(self): paddle.disable_static() From 01448909672f7a0764e257643359ff7b4dbb98b0 Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 18 Sep 2024 18:32:00 +0800 Subject: [PATCH 16/33] [Fix] xpu param name --- paddle/phi/kernels/xpu/adam_kernel.cc | 6 +++--- paddle/phi/kernels/xpu/adamw_kernel.cc | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc index 1a118507d1e553..1d803d11a24ed3 100644 --- a/paddle/phi/kernels/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -32,7 +32,7 @@ void AdamDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const paddle::optional& moment2_max_ UNUSED, + const paddle::optional& moment2_max UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -264,7 +264,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max_ UNUSED, + const std::vector& moment2_max UNUSED, const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -277,7 +277,7 @@ void MergedAdamKernel( std::vector param_out, std::vector moment1_out, std::vector moment2_out, - std::vector moment2_max_out_ UNUSED, + std::vector moment2_max_out UNUSED, std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) { diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index fbc237ec6b749b..18d14d7d6076da 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -435,11 +435,11 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, moment1_out->set_storage_properties(std::move(moment1_out_sp)); // for moment2 - float moment2_max = GetAbsMax(dev_ctx, - moment2_output_for_xdnn, - buffer_for_findmax, - moment2_out->numel()); - float moment2_scale_value = 65504.0f / moment2_max / 2.0f; + float moment2_max_ = GetAbsMax(dev_ctx, + moment2_output_for_xdnn, + buffer_for_findmax, + moment2_out->numel()); + float moment2_scale_value = 65504.0f / moment2_max_ / 2.0f; // int scale(Context* ctx, const T* x, T* y, int64_t len, bool // bias_after_scale, float _scale, float _bias); r = xpu::scale(dev_ctx.x_context(), @@ -483,7 +483,7 @@ void AdamwDenseKernel(const Context& dev_ctx, const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const paddle::optional& moment2_max_ UNUSED, + const paddle::optional& moment2_max UNUSED, const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -502,7 +502,7 @@ void AdamwDenseKernel(const Context& dev_ctx, DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, - DenseTensor* moment2_max_out_ UNUSED, + DenseTensor* moment2_max_out UNUSED, DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { @@ -806,11 +806,11 @@ void AdamwDenseKernel(const Context& dev_ctx, moment1_out->set_storage_properties(std::move(moment1_out_sp)); // for moment2 - float moment2_max = GetAbsMax(dev_ctx, - moment2_output_for_xdnn, - buffer_for_findmax, - moment2_out->numel()); - float moment2_scale_value = 65504.0f / moment2_max / 2.0f; + float moment2_max_ = GetAbsMax(dev_ctx, + moment2_output_for_xdnn, + buffer_for_findmax, + moment2_out->numel()); + float 
moment2_scale_value = 65504.0f / moment2_max_ / 2.0f; // int scale(Context* ctx, const T* x, T* y, int64_t len, bool // bias_after_scale, float _scale, float _bias); r = xpu::scale(dev_ctx.x_context(), From c6942c0183a96fdd18662372da90dc810fca953d Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 18 Sep 2024 22:00:45 +0800 Subject: [PATCH 17/33] [Fix] xpu param name & unittest --- paddle/phi/kernels/xpu/adam_kernel.cc | 59 ++++++++++--------- paddle/phi/kernels/xpu/adamw_kernel.cc | 59 ++++++++++--------- .../test_adam_optimizer_fp32_fp64.py | 51 ++++++++-------- 3 files changed, 87 insertions(+), 82 deletions(-) diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc index 1d803d11a24ed3..ec846bc481b8bb 100644 --- a/paddle/phi/kernels/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -26,32 +26,33 @@ namespace phi { template -void AdamDenseKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& grad, - const DenseTensor& learning_rate, - const DenseTensor& moment1, - const DenseTensor& moment2, - const paddle::optional& moment2_max UNUSED, - const DenseTensor& beta1_pow, - const DenseTensor& beta2_pow, - const paddle::optional& master_param, - const paddle::optional& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow, - bool amsgrad UNUSED, - DenseTensor* param_out, - DenseTensor* moment1_out, - DenseTensor* moment2_out, - DenseTensor* moment2_max_out UNUSED, - DenseTensor* beta1_pow_out, - DenseTensor* beta2_pow_out, - DenseTensor* master_param_outs) { +void AdamDenseKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment1, + const DenseTensor& moment2, + const paddle::optional& moment2_max, // UNUSED + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param, + const paddle::optional& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + bool amsgrad, // UNUSED + DenseTensor* param_out, + DenseTensor* moment1_out, + DenseTensor* moment2_out, + DenseTensor* moment2_max_out, // UNUSED + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_outs) { xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); float* param_ptr = nullptr; funcs::GetDataPointer( @@ -264,7 +265,7 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max UNUSED, + const std::vector& moment2_max, // UNUSED const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, @@ -273,11 +274,11 @@ void MergedAdamKernel( const Scalar& epsilon, bool multi_precision, bool use_global_beta_pow, - bool amsgrad UNUSED, + bool amsgrad, // UNUSED std::vector param_out, std::vector moment1_out, std::vector moment2_out, - std::vector moment2_max_out UNUSED, + std::vector moment2_max_out, // UNUSED std::vector beta1_pow_out, std::vector beta2_pow_out, std::vector master_param_out) { diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index 18d14d7d6076da..d2291fa4d17f37 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ 
b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -477,35 +477,36 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, } template -void AdamwDenseKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& grad, - const DenseTensor& learning_rate, - const DenseTensor& moment1, - const DenseTensor& moment2, - const paddle::optional& moment2_max UNUSED, - const DenseTensor& beta1_pow, - const DenseTensor& beta2_pow, - const paddle::optional& master_param, - const paddle::optional& skip_update, - const Scalar& beta1, - const Scalar& beta2, - const Scalar& epsilon, - float lr_ratio, - float coeff, - bool with_decay, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - bool use_global_beta_pow, - bool amsgrad UNUSED, - DenseTensor* param_out, - DenseTensor* moment1_out, - DenseTensor* moment2_out, - DenseTensor* moment2_max_out UNUSED, - DenseTensor* beta1_pow_out, - DenseTensor* beta2_pow_out, - DenseTensor* master_param_outs) { +void AdamwDenseKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment1, + const DenseTensor& moment2, + const paddle::optional& moment2_max, // UNUSED + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param, + const paddle::optional& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + float lr_ratio, + float coeff, + bool with_decay, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + bool amsgrad, // UNUSED + DenseTensor* param_out, + DenseTensor* moment1_out, + DenseTensor* moment2_out, + DenseTensor* moment2_max_out, // UNUSED + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_outs) { auto dev_version = phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId()); if (dev_version == phi::backends::xpu::XPUVersion::XPU3) { diff --git a/test/legacy_test/test_adam_optimizer_fp32_fp64.py b/test/legacy_test/test_adam_optimizer_fp32_fp64.py index d166dff5b3018f..36a54b9a701461 100644 --- a/test/legacy_test/test_adam_optimizer_fp32_fp64.py +++ b/test/legacy_test/test_adam_optimizer_fp32_fp64.py @@ -15,6 +15,8 @@ import os import unittest +from utils import static_guard + import paddle from paddle import base @@ -33,30 +35,31 @@ def get_places(): def main_test_func(place, dtype): - main = base.Program() - startup = base.Program() - with base.program_guard(main, startup): - with base.scope_guard(base.Scope()): - x = paddle.static.data(name='x', shape=[None, 13], dtype=dtype) - y = paddle.static.data(name='y', shape=[None, 1], dtype=dtype) - y_predict = paddle.static.nn.fc(x, size=1) - cost = paddle.nn.functional.square_error_cost( - input=y_predict, label=y - ) - avg_cost = paddle.mean(cost) - - adam_optimizer = paddle.optimizer.Adam(0.01) - adam_optimizer.minimize(avg_cost) - - fetch_list = [avg_cost] - train_reader = paddle.batch( - paddle.dataset.uci_housing.train(), batch_size=1 - ) - feeder = base.DataFeeder(place=place, feed_list=[x, y]) - exe = base.Executor(place) - exe.run(base.default_startup_program()) - for data in train_reader(): - exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list) + with static_guard(): + main = base.Program() + startup = base.Program() + with base.program_guard(main, startup): + with base.scope_guard(base.Scope()): + x = paddle.static.data(name='x', shape=[None, 13], dtype=dtype) + y = 
paddle.static.data(name='y', shape=[None, 1], dtype=dtype) + y_predict = paddle.static.nn.fc(x, size=1) + cost = paddle.nn.functional.square_error_cost( + input=y_predict, label=y + ) + avg_cost = paddle.mean(cost) + + adam_optimizer = paddle.optimizer.Adam(0.01) + adam_optimizer.minimize(avg_cost) + + fetch_list = [avg_cost] + train_reader = paddle.batch( + paddle.dataset.uci_housing.train(), batch_size=1 + ) + feeder = base.DataFeeder(place=place, feed_list=[x, y]) + exe = base.Executor(place) + exe.run(base.default_startup_program()) + for data in train_reader(): + exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list) class AdamFp32Test(unittest.TestCase): From f9cb32e5058d083793456800051cccb9af066b44 Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 18 Sep 2024 22:49:36 +0800 Subject: [PATCH 18/33] [Fix] xpu param type --- paddle/phi/kernels/xpu/adam_kernel.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc index ec846bc481b8bb..751911530c6727 100644 --- a/paddle/phi/kernels/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -265,7 +265,8 @@ void MergedAdamKernel( const std::vector& learning_rate, const std::vector& moment1, const std::vector& moment2, - const std::vector& moment2_max, // UNUSED + const paddle::optional>& + moment2_max, // UNUSED const std::vector& beta1_pow, const std::vector& beta2_pow, const paddle::optional>& master_param, From 92ad89dc1d394977b75d1fae2be76cea513bcac6 Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 18 Sep 2024 23:53:10 +0800 Subject: [PATCH 19/33] [Fix] xpu unittest --- test/xpu/test_adam_op_xpu.py | 27 +++++++++++++++++---------- test/xpu/test_adamw_op_xpu.py | 10 ++++------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/test/xpu/test_adam_op_xpu.py b/test/xpu/test_adam_op_xpu.py index 8f5c771cdfa6b9..76da470f3bc741 100644 --- a/test/xpu/test_adam_op_xpu.py +++ b/test/xpu/test_adam_op_xpu.py @@ -45,7 +45,7 @@ def setUp(self): self.set_shape() self.set_inputs() self.set_steps() - param_out, moment1_out, moment2_out = adam_step( + param_out, moment1_out, moment2_out, moment2_out_max = adam_step( self.inputs, self.attrs ) @@ -109,7 +109,11 @@ def set_inputs(self): } def test_check_output(self): - self.check_output_with_place(place=paddle.XPUPlace(0), atol=1e-2) + self.check_output_with_place( + no_check_set=['Moment2MaxOut'], + place=paddle.XPUPlace(0), + atol=1e-2, + ) class TestAdamOp2(TestAdamOp): '''Test Adam Op with supplied attributes''' @@ -163,7 +167,7 @@ def setUp(self): self.set_shape() self.set_inputs() self.set_steps() - param_out, moment1_out, moment2_out = adam_step( + param_out, moment1_out, moment2_out, moment2_out_max = adam_step( self.inputs, self.attrs ) @@ -207,8 +211,8 @@ def set_steps(self): def test_check_output(self): for _ in range(self.num_steps): - param_out, moment1_out, moment2_out = adam_step( - self.inputs, self.attrs + param_out, moment1_out, moment2_out, moment2_out_max = ( + adam_step(self.inputs, self.attrs) ) beta1_pow_out = self.inputs['Beta1Pow'] * self.beta1 @@ -223,7 +227,9 @@ def test_check_output(self): # Verify output for this step self.check_output_with_place( - place=paddle.XPUPlace(0), atol=1e-2 + no_check_set=['Moment2MaxOut'], + place=paddle.XPUPlace(0), + atol=1e-2, ) # Output of this step becomes input for next step @@ -374,7 +380,6 @@ def setup(self, scope, place, lazy_mode): "Param": np.full((height, row_numel), 5.0).astype("float32"), "Moment1": 
np.full((height, row_numel), 5.0).astype("float32"), "Moment2": np.full((height, row_numel), 5.0).astype("float32"), - "Moment2Max": np.zeros((height, row_numel)).astype("float32"), 'Beta1Pow': beta1_pow, 'Beta2Pow': beta2_pow, "LearningRate": np.full((1), 2.0).astype("float32"), @@ -413,11 +418,12 @@ def setup(self, scope, place, lazy_mode): "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2, - "Moment2MaxOut": mom2_max, 'Beta1PowOut': beta1_pow * beta1, 'Beta2PowOut': beta2_pow * beta2, } + self.no_check_set = ['Moment2MaxOut'] + def check_with_place(self, place, lazy_mode): scope = core.Scope() self.setup(scope, place, lazy_mode) @@ -442,6 +448,9 @@ def check_with_place(self, place, lazy_mode): adam_op.run(scope, place) for key, np_array in self.outputs.items(): + if key in self.no_check_set: # Currently, xpu NOT support amsgrad. + continue + out_var = scope.var(key).get_tensor() actual = np.array(out_var) actual = actual.reshape([actual.size]) @@ -474,7 +483,6 @@ def setup(self, scope, place, lazy_mode): "Param": np.full((height, row_numel), 5.0).astype("float16"), "Moment1": np.full((height, row_numel), 5.0).astype("float16"), "Moment2": np.full((height, row_numel), 5.0).astype("float16"), - "Moment2Max": np.zeros((height, row_numel)).astype("float16"), 'Beta1Pow': beta1_pow, 'Beta2Pow': beta2_pow, "LearningRate": np.full((1), 2.0).astype("float16"), @@ -513,7 +521,6 @@ def setup(self, scope, place, lazy_mode): "ParamOut": param_out, "Moment1Out": mom1, "Moment2Out": mom2, - "Moment2MaxOut": mom2_max, 'Beta1PowOut': beta1_pow * beta1, 'Beta2PowOut': beta2_pow * beta2, } diff --git a/test/xpu/test_adamw_op_xpu.py b/test/xpu/test_adamw_op_xpu.py index c4723c136f3e27..f60035579400c3 100644 --- a/test/xpu/test_adamw_op_xpu.py +++ b/test/xpu/test_adamw_op_xpu.py @@ -74,6 +74,7 @@ def adamw_step(inputs, attributes): denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow)))) + return param_out, moment1_out, moment2_out, moment2_max_out @@ -104,7 +105,6 @@ def setUp(self): moment1 = np.random.uniform(-1, 1, self.shape).astype("float32") # The second moment is positive moment2 = np.random.random(self.shape).astype("float32") - moment2_max = np.zeros(self.shape).astype("float32") learning_rate = 0.004 beta1 = 0.78 @@ -120,7 +120,6 @@ def setUp(self): 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, - 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1_pow]).astype("float32"), 'Beta2Pow': np.array([beta2_pow]).astype("float32"), @@ -142,7 +141,6 @@ def setUp(self): self.outputs = { 'Moment1Out': moment1_out, 'Moment2Out': moment2_out, - 'Moment2MaxOut': moment2_max_out, 'ParamOut': param_out, 'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1, 'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2, @@ -162,7 +160,9 @@ def init_shape(self): def test_check_output(self): paddle.enable_static() - self.check_output_with_place(place=paddle.XPUPlace(0)) + self.check_output_with_place( + no_check_set=['Moment2MaxOut'], place=paddle.XPUPlace(0) + ) # Currently, xpu NOT support amsgrad. 
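# For orientation, a minimal self-contained NumPy sketch of the AMSGrad update
# rule described in the Adam/AdamW docstrings earlier in this series; the helper
# name and signature are illustrative only and are not part of the Paddle API or
# of this test suite. AdamW would additionally add the decoupled weight-decay
# term `coeff * param` (the docstring's `lambda * param`) inside the update.
import numpy as np

def _amsgrad_reference_step(param, grad, mom1, mom2, mom2_max, lr, beta1, beta2, eps, t):
    # first and second moment estimates
    mom1_out = beta1 * mom1 + (1 - beta1) * grad
    mom2_out = beta2 * mom2 + (1 - beta2) * np.square(grad)
    # AMSGrad: keep the running maximum of the second moment and scale by it
    mom2_max_out = np.maximum(mom2_out, mom2_max)
    # bias-corrected learning rate, as in the docstring formula
    lr_t = lr * np.sqrt(1.0 - beta2**t) / (1.0 - beta1**t)
    param_out = param - lr_t * mom1_out / (np.sqrt(mom2_max_out) + eps)
    return param_out, mom1_out, mom2_out, mom2_max_out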
def infer_dtype_from_inputs_outputs(self, inputs, outputs): self.__class__.dtype = self.dtype @@ -415,7 +415,6 @@ def get_numpy_output( 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, - 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), @@ -607,7 +606,6 @@ def get_numpy_output( 'Grad': grad, 'Moment1': moment1, 'Moment2': moment2, - 'Moment2Max': moment2_max, 'LearningRate': np.array([learning_rate]).astype("float32"), 'Beta1Pow': np.array([beta1**t]).astype("float32"), 'Beta2Pow': np.array([beta2**t]).astype("float32"), From 8e026cd71c1dbcc0179f6ab15c93208ba27eb642 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 19 Sep 2024 12:30:06 +0800 Subject: [PATCH 20/33] [Fix] xpu unittest --- test/xpu/test_adam_op_xpu.py | 12 ++++++------ test/xpu/test_adamw_op_xpu.py | 6 +++--- test/xpu/test_merged_adam_op_xpu.py | 5 +++++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/test/xpu/test_adam_op_xpu.py b/test/xpu/test_adam_op_xpu.py index 76da470f3bc741..2567080315e229 100644 --- a/test/xpu/test_adam_op_xpu.py +++ b/test/xpu/test_adam_op_xpu.py @@ -259,7 +259,7 @@ def adam_step(inputs, attributes): grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] - moment2_max = inputs['Moment2Max'] + moment2_max = inputs.get('Moment2Max', None) lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -275,7 +275,7 @@ def adam_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] - amsgrad = attributes['amsgrad'] + amsgrad = attributes.get('amsgrad', False) moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) @@ -289,7 +289,7 @@ def adam_step(inputs, attributes): / (np.sqrt(moment2_max_out) + epsilon * np.sqrt(1 - beta2_pow)) ) else: - moment2_max_out = np.zeros_like(moment2_out) + moment2_max_out = np.empty_like(moment2_out) param_out = param - lr_t * ( moment1_out / (np.sqrt(moment2_out) + epsilon * np.sqrt(1 - beta2_pow)) @@ -312,7 +312,7 @@ def adam_step_sparse( # grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] - moment2_max = inputs['Moment2Max'] + moment2_max = inputs.get('Moment2Max', None) lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] beta2_pow = inputs['Beta2Pow'] @@ -320,7 +320,7 @@ def adam_step_sparse( beta1 = attributes['beta1'] beta2 = attributes['beta2'] epsilon = attributes['epsilon'] - amsgrad = attributes['amsgrad'] + amsgrad = attributes.get('amsgrad', False) moment1_out = np.zeros(shape=[height, row_numel]) moment2_out = np.zeros(shape=[height, row_numel]) @@ -345,7 +345,7 @@ def update_row(row_id, update_value): / (np.sqrt(moment2_max_out[row_id]) + epsilon) ) else: - moment2_max_out[row_id] = np.zeros_like(moment2_out[row_id]) + moment2_max_out[row_id] = np.empty_like(moment2_out[row_id]) param_out[row_id] = param[row_id] - lr_t * ( moment1_out[row_id] / (np.sqrt(moment2_out[row_id]) + epsilon) ) diff --git a/test/xpu/test_adamw_op_xpu.py b/test/xpu/test_adamw_op_xpu.py index f60035579400c3..8351829a672618 100644 --- a/test/xpu/test_adamw_op_xpu.py +++ b/test/xpu/test_adamw_op_xpu.py @@ -36,7 +36,7 @@ def adamw_step(inputs, attributes): grad = inputs['Grad'] moment1 = inputs['Moment1'] moment2 = inputs['Moment2'] - moment2_max = inputs['Moment2Max'] + moment2_max = inputs.get('Moment2Max', None) lr = inputs['LearningRate'] beta1_pow = inputs['Beta1Pow'] 
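# the *_pow entries hold beta1**t / beta2**t and drive the bias correction applied in the update below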
beta2_pow = inputs['Beta2Pow'] @@ -61,7 +61,7 @@ def adamw_step(inputs, attributes): else: beta2 = inputs['Beta2Tensor'][0] - amsgrad = attributes['amsgrad'] + amsgrad = attributes.get('amsgrad', False) moment1_out = beta1 * moment1 + (1 - beta1) * grad moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad) @@ -70,7 +70,7 @@ def adamw_step(inputs, attributes): moment2_max_out = np.maximum(moment2_out, moment2_max) denom = (np.sqrt(moment2_max_out) / np.sqrt(1.0 - beta2_pow)) + epsilon else: - moment2_max_out = np.zeros_like(moment2_out) + moment2_max_out = np.empty_like(moment2_out) denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow)))) diff --git a/test/xpu/test_merged_adam_op_xpu.py b/test/xpu/test_merged_adam_op_xpu.py index b8bdda757e6b74..b4aae12cb3ac3d 100644 --- a/test/xpu/test_merged_adam_op_xpu.py +++ b/test/xpu/test_merged_adam_op_xpu.py @@ -93,6 +93,7 @@ def run_adam_op( beta2, 'multi_precision', False, + 'amsgrad', amsgrad, ) else: @@ -136,6 +137,7 @@ class XPUTestMergedAdamBase(unittest.TestCase): def setUp(self): self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] self.seed = 10 + self.no_check_set = ['Moment2MaxOut'] def gen_rand_data(self, shapes, dtype): return [np.random.random(s).astype(dtype) for s in shapes] @@ -214,6 +216,9 @@ def run_op(use_merged, place): self.assertEqual(len(outs1), len(outs4)) for key in outs1.keys(): + if key in self.no_check_set: + continue + value1 = outs1[key] value2 = outs2[key] value3 = outs3[key] From 56d26df219129f64dc87b1b3e4613380d440d1e1 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 19 Sep 2024 13:45:41 +0800 Subject: [PATCH 21/33] [Fix] xpu unittest --- test/xpu/test_adam_op_xpu.py | 4 +--- test/xpu/test_merged_adam_op_xpu.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/test/xpu/test_adam_op_xpu.py b/test/xpu/test_adam_op_xpu.py index 2567080315e229..025881ed43be89 100644 --- a/test/xpu/test_adam_op_xpu.py +++ b/test/xpu/test_adam_op_xpu.py @@ -422,8 +422,6 @@ def setup(self, scope, place, lazy_mode): 'Beta2PowOut': beta2_pow * beta2, } - self.no_check_set = ['Moment2MaxOut'] - def check_with_place(self, place, lazy_mode): scope = core.Scope() self.setup(scope, place, lazy_mode) @@ -448,7 +446,7 @@ def check_with_place(self, place, lazy_mode): adam_op.run(scope, place) for key, np_array in self.outputs.items(): - if key in self.no_check_set: # Currently, xpu NOT support amsgrad. + if key in ['Moment2MaxOut']: # Currently, xpu NOT support amsgrad. 
continue out_var = scope.var(key).get_tensor() diff --git a/test/xpu/test_merged_adam_op_xpu.py b/test/xpu/test_merged_adam_op_xpu.py index b4aae12cb3ac3d..20cfdf1fe83332 100644 --- a/test/xpu/test_merged_adam_op_xpu.py +++ b/test/xpu/test_merged_adam_op_xpu.py @@ -137,7 +137,6 @@ class XPUTestMergedAdamBase(unittest.TestCase): def setUp(self): self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] self.seed = 10 - self.no_check_set = ['Moment2MaxOut'] def gen_rand_data(self, shapes, dtype): return [np.random.random(s).astype(dtype) for s in shapes] @@ -216,7 +215,7 @@ def run_op(use_merged, place): self.assertEqual(len(outs1), len(outs4)) for key in outs1.keys(): - if key in self.no_check_set: + if key in ['Moment2MaxOut']: continue value1 = outs1[key] From 26c7e63f6bde7e08f8ed85d5bc61dfbf7d02da99 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 19 Sep 2024 20:41:39 +0800 Subject: [PATCH 22/33] [Fix] merged_adam_ op_compat.yaml --- paddle/phi/ops/yaml/op_compat.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 12326ebf445c77..93ab4275c8e4a3 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -2514,9 +2514,9 @@ - op : merged_adam_ inputs : - {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, moment2: Moment2Max, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam} + {param: Param, grad: Grad, learning_rate: LearningRate, moment1: Moment1, moment2: Moment2, moment2_max: Moment2Max, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, master_param: MasterParam} outputs : - {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, moment2_out: Moment2MaxOut, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} + {param_out: ParamOut, moment1_out: Moment1Out, moment2_out: Moment2Out, moment2_max_out: Moment2MaxOut, beta1_pow_out: Beta1PowOut, beta2_pow_out: Beta2PowOut, master_param_out: MasterParamOut} scalar : beta1 : data_type : float From ddb20357a9a9cb7245f57abc4ceda52fd103be32 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 19 Sep 2024 22:28:28 +0800 Subject: [PATCH 23/33] [Fix] remove UNUSED --- paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc index 31cd0d18d18f8d..b78605e93fce0f 100644 --- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc @@ -34,7 +34,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const paddle::optional& moment2_max UNUSED, + const paddle::optional& moment2_max, // UNUSED const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -46,11 +46,11 @@ void AdamDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, - bool amsgrad UNUSED, + bool amsgrad, // UNUSED DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, - DenseTensor* moment2_max_out UNUSED, + DenseTensor* moment2_max_out, // UNUSED DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { From e41b66b3508e7221ea34c4dda7025b679597d1c8 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 19 Sep 2024 
22:29:36 +0800 Subject: [PATCH 24/33] [Fix] remove UNUSED --- paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc index b78605e93fce0f..a92aca83298667 100644 --- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc +++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc @@ -34,7 +34,7 @@ void AdamDenseParamSparseGradKernel( const DenseTensor& learning_rate, const DenseTensor& moment1, const DenseTensor& moment2, - const paddle::optional& moment2_max, // UNUSED + const paddle::optional& moment2_max, // UNUSED const DenseTensor& beta1_pow, const DenseTensor& beta2_pow, const paddle::optional& master_param, @@ -46,11 +46,11 @@ void AdamDenseParamSparseGradKernel( int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow, - bool amsgrad, // UNUSED + bool amsgrad, // UNUSED DenseTensor* param_out, DenseTensor* moment1_out, DenseTensor* moment2_out, - DenseTensor* moment2_max_out, // UNUSED + DenseTensor* moment2_max_out, // UNUSED DenseTensor* beta1_pow_out, DenseTensor* beta2_pow_out, DenseTensor* master_param_outs) { From 1f2831a1f6644c3fbb94e751b8ef5ca29f3e515a Mon Sep 17 00:00:00 2001 From: megemini Date: Fri, 20 Sep 2024 22:48:53 +0800 Subject: [PATCH 25/33] [Update] unittest adam op --- test/legacy_test/test_adam_op.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py index 8090e41178b2fb..75231fb6a5e4a6 100644 --- a/test/legacy_test/test_adam_op.py +++ b/test/legacy_test/test_adam_op.py @@ -549,6 +549,10 @@ def check_with_place(self, place, lazy_mode): adam_op.run(scope, place) for key, np_array in self.outputs.items(): + # do not check keys in `no_check_set`` + if self.no_check_set is not None and key in self.no_check_set: + continue + out_var = scope.var(key).get_tensor() actual = np.array(out_var) actual = actual.reshape([actual.size]) From cfbd1730e3ee95a438e624fccd3f61fe2ca02782 Mon Sep 17 00:00:00 2001 From: megemini Date: Sat, 21 Sep 2024 14:44:31 +0800 Subject: [PATCH 26/33] [Fix] op_compat.yaml --- paddle/phi/ops/yaml/op_compat.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index b59952acad4017..ceef1f2bfce21a 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -1425,7 +1425,7 @@ - op : fused_adam_(fused_adam) inputs : {params : Params, grads : Grads, learning_rate : LearningRate, moments1 : Moments1, - moments2 : Moments2, moments2 : Moments2Max, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams, + moments2 : Moments2, moments2_max : Moments2Max, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams, skip_update : SkipUpdate} outputs : {params_out : ParamsOut, moments1_out : Moments1Out, moments2_out : Moments2Out, moments2_max_out : Moments2MaxOut, From 9f977aca66592431beaadbc4d09a3edeff02c2e2 Mon Sep 17 00:00:00 2001 From: megemini Date: Sun, 29 Sep 2024 22:39:19 +0800 Subject: [PATCH 27/33] [Update] assembly for adam adamw --- paddle/phi/kernels/cpu/adam_kernel.cc | 2 +- paddle/phi/kernels/cpu/adamw_kernel.cc | 3 +- paddle/phi/kernels/funcs/jit/gen/adam.cc | 52 ++- paddle/phi/kernels/funcs/jit/gen/adam.h | 13 +- paddle/phi/kernels/funcs/jit/gen/adamw.cc | 50 ++- paddle/phi/kernels/funcs/jit/gen/adamw.h | 13 +- 
paddle/phi/kernels/funcs/jit/kernel_base.h | 14 +- paddle/phi/kernels/funcs/jit/kernel_key.cc | 11 +- paddle/phi/kernels/funcs/jit/test.cc | 493 +++++++++++---------- 9 files changed, 381 insertions(+), 270 deletions(-) diff --git a/paddle/phi/kernels/cpu/adam_kernel.cc b/paddle/phi/kernels/cpu/adam_kernel.cc index 7aab5d6c8bab0b..84b3d3c2257075 100644 --- a/paddle/phi/kernels/cpu/adam_kernel.cc +++ b/paddle/phi/kernels/cpu/adam_kernel.cc @@ -129,7 +129,7 @@ void AdamDenseKernel(const Context& dev_ctx, learning_rate.data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); T eps = epsilon_ * sqrt(1 - beta2_p); - phi::jit::adam_attr_t attr(beta1_, beta2_); + phi::jit::adam_attr_t attr(beta1_, beta2_, amsgrad); int64_t numel = param.numel(); const T* param_ptr = param.data(); diff --git a/paddle/phi/kernels/cpu/adamw_kernel.cc b/paddle/phi/kernels/cpu/adamw_kernel.cc index ede1986189473a..868a0dd4cd7983 100644 --- a/paddle/phi/kernels/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/cpu/adamw_kernel.cc @@ -143,6 +143,7 @@ void AdamwDenseKernel(const Context& dev_ctx, learning_rate.data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); T eps = epsilon_ * sqrt(1 - beta2_p); + phi::jit::adamw_attr_t attr(beta1_, beta2_, coeff_, amsgrad); int64_t numel = param.numel(); const T* param_ptr = param.data(); @@ -153,7 +154,7 @@ void AdamwDenseKernel(const Context& dev_ctx, auto adamw = phi::jit::KernelFuncs, phi::CPUPlace>::Cache().At( - 1); + attr); static constexpr int64_t chunk_size = 512; diff --git a/paddle/phi/kernels/funcs/jit/gen/adam.cc b/paddle/phi/kernels/funcs/jit/gen/adam.cc index fd151b75e8fbb8..af766f295381f5 100644 --- a/paddle/phi/kernels/funcs/jit/gen/adam.cc +++ b/paddle/phi/kernels/funcs/jit/gen/adam.cc @@ -28,8 +28,12 @@ void AdamJitCode::loadArgs() { static_cast(0xFFFFFFFFFFFFFFF8); static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8; - mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); - mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); + mov(reg_mom1_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); + mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); + mov(reg_mom2_max_out_ptr, ptr[rsp + (abi_pushes_offset + 24)]); + mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 32)]); + mov(reg_amsgrad, byte[rsp + (abi_pushes_offset + 40)]); + mov(eax, one_as_float); movd(xmm_one, eax); @@ -54,6 +58,9 @@ void AdamJitCode::loadArgs() { } void AdamJitCode::setTailOpmask() { + push(r13); + push(r14); + mov(r13, rcx); mov(rcx, reg_numel); @@ -65,6 +72,9 @@ void AdamJitCode::setTailOpmask() { kmovw(k1, r14d); mov(rcx, r13); + + pop(r14); + pop(r13); } void AdamJitCode::mainCode() { @@ -84,16 +94,32 @@ void AdamJitCode::mainCode() { vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8); vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7); - // sqrt(mom2) + eps - vsqrtps(ymm7 | k1, ymm7); - vaddps(ymm7 | k1, ymm7, ymm_eps); + // make a local label: `.without_amsgrad` + inLocalLabel(); + // if not amsgrad then update params + cmp(reg_amsgrad, 0); + je(".without_amsgrad", T_NEAR); + // load mom2_max + vmovups(ymm9 | k1, ptr[reg_mom2_max_ptr + reg_offset]); + // compare mom2 and mom2_max and save to mom2 + vmaxps(ymm7 | k1, ymm7, ymm9); + // store mom2_max + vmovups(ptr[reg_mom2_max_out_ptr + reg_offset] | k1, ymm7); + + L(".without_amsgrad"); + { + // sqrt(mom2) + eps + vsqrtps(ymm7 | k1, ymm7); + vaddps(ymm7 | k1, ymm7, ymm_eps); - // p + (-lr) * (mom1 / sqrt(mom2) + eps) - vdivps(ymm7 | k1, ymm8, ymm7); - vfmadd213ps(ymm7 | k1, ymm_lr, ptr[reg_param_ptr + reg_offset]); 
+ // p + (-lr) * (mom1 / sqrt(mom2) + eps) + vdivps(ymm7 | k1, ymm8, ymm7); + vfmadd213ps(ymm7 | k1, ymm_lr, ptr[reg_param_ptr + reg_offset]); - // store p - vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7); + // store p + vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7); + } + outLocalLabel(); } void AdamJitCode::genCode() { @@ -104,18 +130,18 @@ void AdamJitCode::genCode() { loadArgs(); cmp(reg_numel, main_loop_elems_size); - jl("process_tail"); + jl("process_tail", T_NEAR); L("main_loop"); { mainCode(); add(reg_offset, offset_increment); cmp(reg_numel_without_tail, reg_offset); - jg("main_loop"); + jg("main_loop", T_NEAR); } cmp(reg_numel, reg_offset); - je("end"); + je("end", T_NEAR); L("process_tail"); { diff --git a/paddle/phi/kernels/funcs/jit/gen/adam.h b/paddle/phi/kernels/funcs/jit/gen/adam.h index 5c432e03ec9214..c4cbce01ccf16b 100644 --- a/paddle/phi/kernels/funcs/jit/gen/adam.h +++ b/paddle/phi/kernels/funcs/jit/gen/adam.h @@ -44,8 +44,8 @@ class AdamJitCode : public JitCode { reg64_t reg_grad_ptr{abi_param2}; reg64_t reg_mom1_ptr{abi_param3}; reg64_t reg_mom2_ptr{abi_param4}; - reg64_t reg_param_ptr{abi_param5}; - reg64_t reg_mom1_out_ptr{abi_param6}; + reg64_t reg_mom2_max_ptr{abi_param5}; + reg64_t reg_param_ptr{abi_param6}; xmm_t xmm_beta1 = xmm_t(0); xmm_t xmm_beta2 = xmm_t(1); @@ -63,9 +63,12 @@ class AdamJitCode : public JitCode { ymm_t ymm_one_sub_beta2 = ymm_t(5); ymm_t ymm_one = ymm_t(6); - reg64_t reg_mom2_out_ptr{r10}; - reg64_t reg_param_out_ptr{r11}; - reg64_t reg_numel_without_tail{r12}; + reg64_t reg_mom1_out_ptr{r10}; + reg64_t reg_mom2_out_ptr{r11}; + reg64_t reg_mom2_max_out_ptr{r12}; + reg64_t reg_param_out_ptr{r13}; + reg64_t reg_amsgrad{r14}; + reg64_t reg_numel_without_tail{r15}; reg64_t reg_offset{rax}; }; diff --git a/paddle/phi/kernels/funcs/jit/gen/adamw.cc b/paddle/phi/kernels/funcs/jit/gen/adamw.cc index 4a8545c24f9649..2fd0e8e75b248f 100644 --- a/paddle/phi/kernels/funcs/jit/gen/adamw.cc +++ b/paddle/phi/kernels/funcs/jit/gen/adamw.cc @@ -28,8 +28,12 @@ void AdamWJitCode::loadArgs() { static_cast(0xFFFFFFFFFFFFFFF8); static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8; - mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); - mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); + mov(reg_mom1_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); + mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); + mov(reg_mom2_max_out_ptr, ptr[rsp + (abi_pushes_offset + 24)]); + mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 32)]); + mov(reg_amsgrad, byte[rsp + (abi_pushes_offset + 40)]); + mov(eax, one_as_float); movd(xmm_one, eax); @@ -57,6 +61,9 @@ void AdamWJitCode::loadArgs() { } void AdamWJitCode::setTailOpmask() { + push(r13); + push(r14); + mov(r13, rcx); mov(rcx, reg_numel); @@ -68,6 +75,9 @@ void AdamWJitCode::setTailOpmask() { kmovw(k1, r14d); mov(rcx, r13); + + pop(r14); + pop(r13); } void AdamWJitCode::mainCode() { @@ -98,16 +108,32 @@ void AdamWJitCode::mainCode() { vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm12); vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm10); - // sqrt(mom2) + eps - vsqrtps(ymm10 | k1, ymm10); - vaddps(ymm10 | k1, ymm10, ymm_eps); + // // make a local label: `.without_amsgrad` + inLocalLabel(); + // if not amsgrad then update params + cmp(reg_amsgrad, 0); + je(".without_amsgrad", T_NEAR); + // load mom2_max + vmovups(ymm13 | k1, ptr[reg_mom2_max_ptr + reg_offset]); + // compare mom2 and mom2_max and save to mom2 + vmaxps(ymm10 | k1, ymm10, ymm13); + // store mom2_max 
+ vmovups(ptr[reg_mom2_max_out_ptr + reg_offset] | k1, ymm10); + + L(".without_amsgrad"); + { + // sqrt(mom2) + eps + vsqrtps(ymm10 | k1, ymm10); + vaddps(ymm10 | k1, ymm10, ymm_eps); - // p + (-lr) * (mom1 / sqrt(mom2) + eps) - vdivps(ymm10 | k1, ymm12, ymm10); - vfmadd213ps(ymm10 | k1, ymm_lr, ymm11); + // p + (-lr) * (mom1 / sqrt(mom2) + eps) + vdivps(ymm10 | k1, ymm12, ymm10); + vfmadd213ps(ymm10 | k1, ymm_lr, ymm11); - // store p - vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm10); + // store p + vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm10); + } + outLocalLabel(); } void AdamWJitCode::genCode() { @@ -118,14 +144,14 @@ void AdamWJitCode::genCode() { loadArgs(); cmp(reg_numel, main_loop_elems_size); - jl("process_tail"); + jl("process_tail", T_NEAR); L("main_loop"); { mainCode(); add(reg_offset, offset_increment); cmp(reg_numel_without_tail, reg_offset); - jg("main_loop"); + jg("main_loop", T_NEAR); } cmp(reg_numel, reg_offset); diff --git a/paddle/phi/kernels/funcs/jit/gen/adamw.h b/paddle/phi/kernels/funcs/jit/gen/adamw.h index dab90e0e0f69e1..c9d465c02f6022 100644 --- a/paddle/phi/kernels/funcs/jit/gen/adamw.h +++ b/paddle/phi/kernels/funcs/jit/gen/adamw.h @@ -44,8 +44,8 @@ class AdamWJitCode : public JitCode { reg64_t reg_grad_ptr{abi_param2}; reg64_t reg_mom1_ptr{abi_param3}; reg64_t reg_mom2_ptr{abi_param4}; - reg64_t reg_param_ptr{abi_param5}; - reg64_t reg_mom1_out_ptr{abi_param6}; + reg64_t reg_mom2_max_ptr{abi_param5}; + reg64_t reg_param_ptr{abi_param6}; xmm_t xmm_beta1 = xmm_t(0); xmm_t xmm_beta2 = xmm_t(1); @@ -69,9 +69,12 @@ class AdamWJitCode : public JitCode { ymm_t ymm_one_sub_beta2 = ymm_t(8); ymm_t ymm_one = ymm_t(9); - reg64_t reg_mom2_out_ptr{r10}; - reg64_t reg_param_out_ptr{r11}; - reg64_t reg_numel_without_tail{r12}; + reg64_t reg_mom1_out_ptr{r10}; + reg64_t reg_mom2_out_ptr{r11}; + reg64_t reg_mom2_max_out_ptr{r12}; + reg64_t reg_param_out_ptr{r13}; + reg64_t reg_amsgrad{r14}; + reg64_t reg_numel_without_tail{r15}; reg64_t reg_offset{rax}; }; diff --git a/paddle/phi/kernels/funcs/jit/kernel_base.h b/paddle/phi/kernels/funcs/jit/kernel_base.h index e0c35a51644eb3..a41c96a7562740 100644 --- a/paddle/phi/kernels/funcs/jit/kernel_base.h +++ b/paddle/phi/kernels/funcs/jit/kernel_base.h @@ -266,8 +266,10 @@ struct SgdTuple { typedef struct adam_attr_s { float beta1, beta2; + bool amsgrad; adam_attr_s() = default; - explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {} + explicit adam_attr_s(float beta1, float beta2, bool amsgrad) + : beta1(beta1), beta2(beta2), amsgrad(amsgrad) {} } adam_attr_t; template @@ -292,11 +294,19 @@ struct AdamTuple { bool); }; +typedef struct adamw_attr_s { + float beta1, beta2, coeff; + bool amsgrad; + adamw_attr_s() = default; + explicit adamw_attr_s(float beta1, float beta2, float coeff, bool amsgrad) + : beta1(beta1), beta2(beta2), coeff(coeff), amsgrad(amsgrad) {} +} adamw_attr_t; + template struct AdamWTuple { static constexpr KernelType kernel_type = kAdamW; typedef T data_type; - typedef int attr_type; + typedef adamw_attr_t attr_type; typedef void (*func_type)(T, T, T, diff --git a/paddle/phi/kernels/funcs/jit/kernel_key.cc b/paddle/phi/kernels/funcs/jit/kernel_key.cc index 818b3c0a9f1610..fddd5bd69ee025 100644 --- a/paddle/phi/kernels/funcs/jit/kernel_key.cc +++ b/paddle/phi/kernels/funcs/jit/kernel_key.cc @@ -67,7 +67,16 @@ int64_t JitCodeKey(const sgd_attr_t& attr) { template <> int64_t JitCodeKey(const adam_attr_t& attr) { - return static_cast(attr.beta1 + attr.beta2); + // if use 
amsgrad, we add `10` for hashcode + return static_cast(attr.beta1 + attr.beta2 + + (attr.amsgrad ? 10 : 0)); +} + +template <> +int64_t JitCodeKey(const adamw_attr_t& attr) { + // if use amsgrad, we add `10` for hashcode + return static_cast(attr.beta1 + attr.beta2 + attr.coeff + + (attr.amsgrad ? 10 : 0)); } } // namespace phi::jit diff --git a/paddle/phi/kernels/funcs/jit/test.cc b/paddle/phi/kernels/funcs/jit/test.cc index 996420d2fdb8ea..fa26bc87f079c3 100644 --- a/paddle/phi/kernels/funcs/jit/test.cc +++ b/paddle/phi/kernels/funcs/jit/test.cc @@ -39,13 +39,6 @@ void RandomVec(const int n, } } -template -void ZeroVec(const int n, T* a) { - for (int i = 0; i < n; ++i) { - a[i] = static_cast(0); - } -} - template void ExpectEQ(const T* target, const T* refer, size_t n) { if (std::is_floating_point::value) { @@ -702,87 +695,47 @@ void TestKernelMatMul() { template void TestKernelAdam() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - const T lr = 0.1; - const T beta1 = 0.99; - const T beta2 = 0.95; - const T beta1_pow = beta1 * beta1; - const T beta2_pow = beta2 * beta2; - - const T epsilon = 0.000001; - const int64_t numel = 123; - - T learning_rate = lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); - T eps = epsilon * sqrt(1 - beta2_pow); - bool amsgrad = false; - - std::vector param(numel); - std::vector grad(numel); - std::vector mom1(numel); - std::vector mom2(numel); - std::vector mom2_max(numel); - - std::vector param_out(param.size()); - std::vector mom1_out(mom1.size()); - std::vector mom2_out(mom2.size()); - std::vector mom2_max_out(mom2_max.size()); - - RandomVec(numel, param.data(), 0.5f); - RandomVec(numel, grad.data(), 0.5f); - RandomVec(numel, mom1.data(), 0.5f); - RandomVec(numel, mom2.data(), 0.5f); - ZeroVec(numel, mom2_max.data()); - - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - jit::adam_attr_t attr(beta1, beta2); - ref(beta1, - beta2, - -learning_rate, - eps, - numel, - grad.data(), - mom1.data(), - mom2.data(), - mom2_max.data(), - param.data(), - mom1_out.data(), - mom2_out.data(), - mom2_max_out.data(), - param_out.data(), - amsgrad); - - auto verifier = [](const typename KernelTuple::func_type tgt, - T beta1, - T beta2, - T lr, - T eps, - int64_t numel, - const std::vector& grad, - const std::vector& mom1, - const std::vector& mom2, - const std::vector& mom2_max, - const std::vector& param, - const std::vector& ref_mom1_out, - const std::vector& ref_mom2_out, - const std::vector& ref_mom2_max_out, - const std::vector& ref_param_out, - bool amsgrad) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(param.size(), static_cast(numel)); - EXPECT_EQ(grad.size(), static_cast(numel)); - EXPECT_EQ(mom1.size(), static_cast(numel)); - EXPECT_EQ(mom2.size(), static_cast(numel)); - - std::vector jit_mom1_out(ref_mom1_out.size()); - std::vector jit_mom2_out(ref_mom2_out.size()); - std::vector jit_mom2_max_out(ref_mom2_max_out.size()); - std::vector jit_param_out(ref_param_out.size()); - - tgt(beta1, + for (bool amsgrad : {false, true}) { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + const T lr = 0.1; + const T beta1 = 0.99; + const T beta2 = 0.95; + const T beta1_pow = beta1 * beta1; + const T beta2_pow = beta2 * beta2; + + const T epsilon = 0.000001; + const int64_t numel = 123; + + T learning_rate = lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); + T eps = epsilon * sqrt(1 - beta2_pow); + + std::vector 
param(numel); + std::vector grad(numel); + std::vector mom1(numel); + std::vector mom2(numel); + std::vector mom2_max(numel); + + std::vector param_out(param.size()); + std::vector mom1_out(mom1.size()); + std::vector mom2_out(mom2.size()); + std::vector mom2_max_out(mom2_max.size()); + + RandomVec(numel, param.data(), 0.5f); + RandomVec(numel, grad.data(), 0.5f); + RandomVec(numel, mom1.data(), 0.5f); + RandomVec(numel, mom2.data(), 0.5f); + if (amsgrad) { + RandomVec(numel, mom2_max.data(), 0.5f); + } + + auto ref = jit::GetReferFunc(); + EXPECT_TRUE(ref != nullptr); + jit::adam_attr_t attr(beta1, beta2, amsgrad); + + ref(beta1, beta2, - -lr, + -learning_rate, eps, numel, grad.data(), @@ -790,125 +743,130 @@ void TestKernelAdam() { mom2.data(), mom2_max.data(), param.data(), - jit_mom1_out.data(), - jit_mom2_out.data(), - jit_mom2_max_out.data(), - jit_param_out.data(), + mom1_out.data(), + mom2_out.data(), + mom2_max_out.data(), + param_out.data(), amsgrad); - ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); - ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); - ExpectEQ(ref_param_out.data(), jit_param_out.data(), numel); - }; - TestAllImpls(attr, - verifier, - beta1, - beta2, - learning_rate, - eps, - numel, - grad, - mom1, - mom2, - mom2_max, - param, - mom1_out, - mom2_out, - mom2_max_out, - param_out, - amsgrad); + auto verifier = [](const typename KernelTuple::func_type tgt, + T beta1, + T beta2, + T lr, + T eps, + int64_t numel, + const std::vector& grad, + const std::vector& mom1, + const std::vector& mom2, + const std::vector& mom2_max, + const std::vector& param, + const std::vector& ref_mom1_out, + const std::vector& ref_mom2_out, + const std::vector& ref_mom2_max_out, + const std::vector& ref_param_out, + bool amsgrad) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(param.size(), static_cast(numel)); + EXPECT_EQ(grad.size(), static_cast(numel)); + EXPECT_EQ(mom1.size(), static_cast(numel)); + EXPECT_EQ(mom2.size(), static_cast(numel)); + if (amsgrad) { + EXPECT_EQ(mom2_max.size(), static_cast(numel)); + } + + std::vector jit_mom1_out(ref_mom1_out.size()); + std::vector jit_mom2_out(ref_mom2_out.size()); + std::vector jit_mom2_max_out(ref_mom2_max_out.size()); + std::vector jit_param_out(ref_param_out.size()); + + tgt(beta1, + beta2, + -lr, + eps, + numel, + grad.data(), + mom1.data(), + mom2.data(), + mom2_max.data(), + param.data(), + jit_mom1_out.data(), + jit_mom2_out.data(), + jit_mom2_max_out.data(), + jit_param_out.data(), + amsgrad); + + ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); + ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); + if (amsgrad) { + ExpectEQ(ref_mom2_max_out.data(), jit_mom2_max_out.data(), numel); + } + ExpectEQ(ref_param_out.data(), jit_param_out.data(), numel); + }; + TestAllImpls(attr, + verifier, + beta1, + beta2, + learning_rate, + eps, + numel, + grad, + mom1, + mom2, + mom2_max, + param, + mom1_out, + mom2_out, + mom2_max_out, + param_out, + amsgrad); + } } template void TestKernelAdamW() { - using T = typename KernelTuple::data_type; - VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); - const T old_lr = 0.1; - const T beta1 = 0.99; - const T beta2 = 0.95; - const T beta1_pow = beta1 * beta1; - const T beta2_pow = beta2 * beta2; - - const T epsilon = 0.000001; - const int64_t numel = 123; - const T lr_ratio = 0.2; - const T coeff = 0.3; - - T learning_rate = old_lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); - T eps = epsilon * sqrt(1 - beta2_pow); - bool amsgrad = false; - - 
std::vector param(numel); - std::vector grad(numel); - std::vector mom1(numel); - std::vector mom2(numel); - std::vector mom2_max(numel); - - std::vector param_out(param.size()); - std::vector mom1_out(mom1.size()); - std::vector mom2_out(mom2.size()); - std::vector mom2_max_out(mom2_max.size()); - - RandomVec(numel, param.data(), 0.5f); - RandomVec(numel, grad.data(), 0.5f); - RandomVec(numel, mom1.data(), 0.5f); - RandomVec(numel, mom2.data(), 0.5f); - ZeroVec(numel, mom2_max.data()); - - auto ref = jit::GetReferFunc(); - EXPECT_TRUE(ref != nullptr); - ref(beta1, - beta2, - -learning_rate, - eps, - old_lr, - lr_ratio, - coeff, - numel, - grad.data(), - mom1.data(), - mom2.data(), - mom2_max.data(), - param.data(), - mom1_out.data(), - mom2_out.data(), - mom2_max_out.data(), - param_out.data(), - amsgrad); - - auto verifier = [](const typename KernelTuple::func_type tgt, - T beta1, - T beta2, - T lr, - T eps, - T old_lr, - T lr_ratio, - T coeff, - int64_t numel, - const std::vector& grad, - const std::vector& mom1, - const std::vector& mom2, - const std::vector& mom2_max, - const std::vector& param, - const std::vector& ref_mom1_out, - const std::vector& ref_mom2_out, - const std::vector& ref_mom2_max_out, - const std::vector& ref_param_out, - bool amsgrad) { - EXPECT_TRUE(tgt != nullptr); - EXPECT_EQ(param.size(), static_cast(numel)); - EXPECT_EQ(grad.size(), static_cast(numel)); - EXPECT_EQ(mom1.size(), static_cast(numel)); - EXPECT_EQ(mom2.size(), static_cast(numel)); - - std::vector jit_mom1_out(ref_mom1_out.size()); - std::vector jit_mom2_out(ref_mom2_out.size()); - std::vector jit_mom2_max_out(ref_mom2_max_out.size()); - std::vector jit_param_out(ref_param_out.size()); - - tgt(beta1, + for (bool amsgrad : {false, true}) { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + const T old_lr = 0.1; + const T beta1 = 0.99; + const T beta2 = 0.95; + const T beta1_pow = beta1 * beta1; + const T beta2_pow = beta2 * beta2; + + const T epsilon = 0.000001; + const int64_t numel = 123; + const T lr_ratio = 0.2; + const T coeff = 0.3; + + T learning_rate = old_lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); + T eps = epsilon * sqrt(1 - beta2_pow); + + std::vector param(numel); + std::vector grad(numel); + std::vector mom1(numel); + std::vector mom2(numel); + std::vector mom2_max(numel); + + std::vector param_out(param.size()); + std::vector mom1_out(mom1.size()); + std::vector mom2_out(mom2.size()); + std::vector mom2_max_out(mom2_max.size()); + + RandomVec(numel, param.data(), 0.5f); + RandomVec(numel, grad.data(), 0.5f); + RandomVec(numel, mom1.data(), 0.5f); + RandomVec(numel, mom2.data(), 0.5f); + if (amsgrad) { + RandomVec(numel, mom2_max.data()); + } + + auto ref = jit::GetReferFunc(); + EXPECT_TRUE(ref != nullptr); + jit::adamw_attr_t attr(beta1, beta2, coeff, amsgrad); + + ref(beta1, beta2, - -lr, + -learning_rate, eps, old_lr, lr_ratio, @@ -919,37 +877,93 @@ void TestKernelAdamW() { mom2.data(), mom2_max.data(), param.data(), - jit_mom1_out.data(), - jit_mom2_out.data(), - jit_mom2_max_out.data(), - jit_param_out.data(), + mom1_out.data(), + mom2_out.data(), + mom2_max_out.data(), + param_out.data(), amsgrad); - ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); - ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); - ExpectEQ(ref_param_out.data(), jit_param_out.data(), numel); - }; + auto verifier = [](const typename KernelTuple::func_type tgt, + T beta1, + T beta2, + T lr, + T eps, + T old_lr, + T 
lr_ratio, + T coeff, + int64_t numel, + const std::vector& grad, + const std::vector& mom1, + const std::vector& mom2, + const std::vector& mom2_max, + const std::vector& param, + const std::vector& ref_mom1_out, + const std::vector& ref_mom2_out, + const std::vector& ref_mom2_max_out, + const std::vector& ref_param_out, + bool amsgrad) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(param.size(), static_cast(numel)); + EXPECT_EQ(grad.size(), static_cast(numel)); + EXPECT_EQ(mom1.size(), static_cast(numel)); + EXPECT_EQ(mom2.size(), static_cast(numel)); + if (amsgrad) { + EXPECT_EQ(mom2_max.size(), static_cast(numel)); + } + + std::vector jit_mom1_out(ref_mom1_out.size()); + std::vector jit_mom2_out(ref_mom2_out.size()); + std::vector jit_mom2_max_out(ref_mom2_max_out.size()); + std::vector jit_param_out(ref_param_out.size()); + + tgt(beta1, + beta2, + -lr, + eps, + old_lr, + lr_ratio, + coeff, + numel, + grad.data(), + mom1.data(), + mom2.data(), + mom2_max.data(), + param.data(), + jit_mom1_out.data(), + jit_mom2_out.data(), + jit_mom2_max_out.data(), + jit_param_out.data(), + amsgrad); + + ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); + ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); + if (amsgrad) { + ExpectEQ(ref_mom2_max_out.data(), jit_mom2_max_out.data(), numel); + } + ExpectEQ(ref_param_out.data(), jit_param_out.data(), numel); + }; - TestAllImpls(1, - verifier, - beta1, - beta2, - learning_rate, - eps, - old_lr, - lr_ratio, - coeff, - numel, - grad, - mom1, - mom2, - mom2_max, - param, - mom1_out, - mom2_out, - mom2_max_out, - param_out, - amsgrad); + TestAllImpls(attr, + verifier, + beta1, + beta2, + learning_rate, + eps, + old_lr, + lr_ratio, + coeff, + numel, + grad, + mom1, + mom2, + mom2_max, + param, + mom1_out, + mom2_out, + mom2_max_out, + param_out, + amsgrad); + } } template @@ -1419,16 +1433,35 @@ TEST(JITKernel_key, emb_seq_pool) { } TEST(JITKernel_key, adam) { - jit::adam_attr_t attr1(0.4f, 0.9f); - jit::adam_attr_t attr2(0.4f, 0.9f); - jit::adam_attr_t attr3(0.1f, 0.3f); + jit::adam_attr_t attr1(0.4f, 0.9f, true); + jit::adam_attr_t attr2(0.4f, 0.9f, true); + jit::adam_attr_t attr3(0.1f, 0.3f, true); + jit::adam_attr_t attr4(0.1f, 0.3f, false); auto key1 = jit::JitCodeKey(attr1); auto key2 = jit::JitCodeKey(attr2); auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); EXPECT_TRUE(key1 == key2); EXPECT_TRUE(key2 != key3); + EXPECT_TRUE(key3 != key4); +} + +TEST(JITKernel_key, adamw) { + jit::adamw_attr_t attr1(0.4f, 0.9f, 0.7f, true); + jit::adamw_attr_t attr2(0.4f, 0.9f, 0.7f, true); + jit::adamw_attr_t attr3(0.1f, 0.3f, 0.2f, true); + jit::adamw_attr_t attr4(0.1f, 0.3f, 0.7f, false); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 != key3); + EXPECT_TRUE(key3 != key4); } TEST(JITKernel_key, sgd) { From d6e26523583b6661d3feee6fc10cf21eccd2488a Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 10 Oct 2024 15:01:18 +0800 Subject: [PATCH 28/33] [Fix] adamw.cc for assembly jit gen --- paddle/phi/kernels/funcs/jit/gen/adamw.cc | 11 +++++++---- paddle/phi/kernels/funcs/jit/gen/adamw.h | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/funcs/jit/gen/adamw.cc b/paddle/phi/kernels/funcs/jit/gen/adamw.cc index 2fd0e8e75b248f..417e71f9658d8e 100644 --- a/paddle/phi/kernels/funcs/jit/gen/adamw.cc +++ b/paddle/phi/kernels/funcs/jit/gen/adamw.cc 
From d6e26523583b6661d3feee6fc10cf21eccd2488a Mon Sep 17 00:00:00 2001
From: megemini
Date: Thu, 10 Oct 2024 15:01:18 +0800
Subject: [PATCH 28/33] [Fix] adamw.cc for assembly jit gen

---
 paddle/phi/kernels/funcs/jit/gen/adamw.cc | 11 +++++++----
 paddle/phi/kernels/funcs/jit/gen/adamw.h  |  4 ++--
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/paddle/phi/kernels/funcs/jit/gen/adamw.cc b/paddle/phi/kernels/funcs/jit/gen/adamw.cc
index 2fd0e8e75b248f..417e71f9658d8e 100644
--- a/paddle/phi/kernels/funcs/jit/gen/adamw.cc
+++ b/paddle/phi/kernels/funcs/jit/gen/adamw.cc
@@ -168,13 +168,16 @@ void AdamWJitCode::genCode() {
   postCode();
 }
 
-class AdamWCreator : public JitCodeCreator<int> {
+class AdamWCreator : public JitCodeCreator<adamw_attr_t> {
  public:
-  bool CanBeUsed(const int& attr) const override {
+  bool CanBeUsed(const adamw_attr_t& attr) const override {
     return phi::backends::cpu::MayIUse(phi::backends::cpu::avx512f);
   }
-  size_t CodeSize(const int& attr) const override { return 96 + 32 * 8; }
-  std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
+  size_t CodeSize(const adamw_attr_t& attr) const override {
+    return 96 + 32 * 8;
+  }
+  std::unique_ptr<GenBase> CreateJitCode(
+      const adamw_attr_t& attr) const override {
     return make_unique<AdamWJitCode>(attr, CodeSize(attr));
   }
 };
diff --git a/paddle/phi/kernels/funcs/jit/gen/adamw.h b/paddle/phi/kernels/funcs/jit/gen/adamw.h
index c9d465c02f6022..4147c5f0e383ea 100644
--- a/paddle/phi/kernels/funcs/jit/gen/adamw.h
+++ b/paddle/phi/kernels/funcs/jit/gen/adamw.h
@@ -26,14 +26,14 @@ namespace gen {
 
 class AdamWJitCode : public JitCode {
  public:
-  explicit AdamWJitCode(const int& attr,
+  explicit AdamWJitCode(const adamw_attr_t& attr,
                         size_t code_size = 256 * 1024,
                         void* code_ptr = nullptr)
       : JitCode(code_size, code_ptr) {
     this->genCode();
   }
 
-  DECLARE_JIT_CODE(AdamJitCode);
+  DECLARE_JIT_CODE(AdamWJitCode);
   void genCode() override;
   void loadArgs();
   void setTailOpmask();
From da2e7430691a984f424229458d8c22db4f3c32c1 Mon Sep 17 00:00:00 2001
From: megemini
Date: Thu, 10 Oct 2024 18:41:10 +0800
Subject: [PATCH 29/33] [Update] adam with old ir test

---
 test/legacy_test/test_adam_op.py | 29 +++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py
index 75231fb6a5e4a6..4f92c9d6a09fe5 100644
--- a/test/legacy_test/test_adam_op.py
+++ b/test/legacy_test/test_adam_op.py
@@ -1029,6 +1029,35 @@ def test_adam_op_with_sparse_input_and_weight_decay(self):
             adam.step()
         paddle.enable_static()
 
+    def test_adam_with_old_ir(self):
+        """TODO(megemini): old ir not used anymore"""
+        with paddle.pir_utils.OldIrGuard():
+            paddle.enable_static()
+            paddle.seed(10)
+            np.random.seed(10)
+            exe = paddle.static.Executor()
+            train_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            optimizer = paddle.optimizer.Adam(amsgrad=self.amsgrad)
+
+            with paddle.static.program_guard(train_program, startup_program):
+                data = paddle.static.data(
+                    shape=[2, 2], name='X', dtype='float32'
+                )
+                hidden_layer = paddle.nn.Linear(2, 10)
+                hidden = hidden_layer(data)
+                loss = paddle.mean(hidden)
+                optimizer.minimize(loss)
+
+            exe.run(startup_program)
+            x = np.random.random(size=(2, 2)).astype('float32')
+            out = []
+            for _ in range(5):
+                (loss_data,) = exe.run(
+                    train_program, feed={"X": x}, fetch_list=[loss]
+                )
+                out.append(loss_data)
+            return out
+
 
 class TestAdamOpV2AMSGrad(TestAdamOpV2):
     def setUp(self):
From af2733797bdd2b1bebf5974ecb3881c8a123c4a5 Mon Sep 17 00:00:00 2001
From: megemini
Date: Fri, 8 Nov 2024 13:39:07 +0800
Subject: [PATCH 30/33] [Update] codestyle

---
 test/legacy_test/test_adamw_op.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py
index 57a86f81a86614..7419d2aa66a1de 100644
--- a/test/legacy_test/test_adamw_op.py
+++ b/test/legacy_test/test_adamw_op.py
@@ -259,6 +259,7 @@ def test_check_output(self):
             check_pir=True,
         )
 
+
 class TestAdamW2AMSGrad(TestAdamW2):
     def set_amsgrad(self):
         # xpu not support `amsgrad`
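The test changes above construct the optimizer with the new flag (for example paddle.optimizer.Adam(amsgrad=self.amsgrad) in the static-graph test). For orientation, a minimal dygraph sketch of enabling it, assuming the Python-side amsgrad keyword is wired through exactly as the tests in this series use it:

# Minimal dygraph sketch; assumes the `amsgrad` keyword accepted by
# paddle.optimizer.Adam as exercised by the tests above.
import paddle

linear = paddle.nn.Linear(2, 2)
opt = paddle.optimizer.Adam(
    learning_rate=0.01,
    parameters=linear.parameters(),
    amsgrad=True,
)
loss = paddle.mean(linear(paddle.rand([4, 2])))
loss.backward()
opt.step()
opt.clear_grad()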
From d7bb19a96bdec5e4494090ad3e7629e5f1664a16 Mon Sep 17 00:00:00 2001
From: megemini
Date: Fri, 8 Nov 2024 20:51:46 +0800
Subject: [PATCH 31/33] [Update] npu test rtol adamw

---
 test/legacy_test/test_adamw_op.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py
index 7419d2aa66a1de..a941df78c59b65 100644
--- a/test/legacy_test/test_adamw_op.py
+++ b/test/legacy_test/test_adamw_op.py
@@ -1242,7 +1242,11 @@ def get_numpy_output(
 
         np.testing.assert_allclose(params_and_gras[0], fc1_w, rtol=1e-6)
         np.testing.assert_allclose(params_and_gras[2], fc1_b, rtol=1e-6)
-        np.testing.assert_allclose(params_and_gras[4], fc2_w, rtol=1e-6)
+        np.testing.assert_allclose(
+            params_and_gras[4],
+            fc2_w,
+            rtol=1e-6 if not core.is_compiled_with_xpu() else 1e-5,
+        )
         np.testing.assert_allclose(params_and_gras[6], fc2_b, rtol=1e-6)
 
         paddle.disable_static()
@@ -1508,7 +1512,11 @@ def get_numpy_output(
 
         np.testing.assert_allclose(params_and_gras[6], fc1_w, rtol=1e-6)
         np.testing.assert_allclose(params_and_gras[4], fc1_b, rtol=1e-6)
-        np.testing.assert_allclose(params_and_gras[2], fc2_w, rtol=1e-6)
+        np.testing.assert_allclose(
+            params_and_gras[2],
+            fc2_w,
+            rtol=1e-6 if not core.is_compiled_with_xpu() else 1e-5,
+        )
         np.testing.assert_allclose(params_and_gras[0], fc2_b, rtol=1e-6)
 
         paddle.disable_static()
From ee8d94f98640e61f07b991da9d442d75b5583fb4 Mon Sep 17 00:00:00 2001
From: megemini
Date: Wed, 27 Nov 2024 15:27:23 +0800
Subject: [PATCH 32/33] [Update] xpu amsgrad raise errors

---
 paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc |  5 +++++
 paddle/phi/kernels/xpu/adam_kernel.cc               | 10 ++++++++++
 paddle/phi/kernels/xpu/adamw_kernel.cc              |  5 +++++
 3 files changed, 20 insertions(+)

diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
index a92aca83298667..8e53a9802c6875 100644
--- a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc
@@ -54,6 +54,11 @@ void AdamDenseParamSparseGradKernel(
     DenseTensor* beta1_pow_out,
     DenseTensor* beta2_pow_out,
     DenseTensor* master_param_outs) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   using XPUType = typename XPUTypeTrait<T>::Type;
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
   float* param_ptr = nullptr;
diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc
index 751911530c6727..828b42654248ee 100644
--- a/paddle/phi/kernels/xpu/adam_kernel.cc
+++ b/paddle/phi/kernels/xpu/adam_kernel.cc
@@ -53,6 +53,11 @@ void AdamDenseKernel(
     DenseTensor* beta1_pow_out,
     DenseTensor* beta2_pow_out,
     DenseTensor* master_param_outs) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
   float* param_ptr = nullptr;
   funcs::GetDataPointer(
@@ -283,6 +288,11 @@ void MergedAdamKernel(
     std::vector<DenseTensor*> beta1_pow_out,
     std::vector<DenseTensor*> beta2_pow_out,
     std::vector<DenseTensor*> master_param_out) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
 
   auto beta1_ = beta1.to();
diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc
index d2291fa4d17f37..1d44be7eaef3dc 100644
--- a/paddle/phi/kernels/xpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -507,6 +507,11 @@ void AdamwDenseKernel(
     DenseTensor* beta1_pow_out,
     DenseTensor* beta2_pow_out,
     DenseTensor* master_param_outs) {
+  PADDLE_ENFORCE_NE(
+      amsgrad,
+      true,
+      phi::errors::Unimplemented("Operation amsgrad is not supported yet."));
+
   auto dev_version =
       phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
   if (dev_version == phi::backends::xpu::XPUVersion::XPU3) {
From 40ca555c8b1b68fb49294a02b2f04156cca1d58e Mon Sep 17 00:00:00 2001
From: megemini
Date: Wed, 27 Nov 2024 20:26:41 +0800
Subject: [PATCH 33/33] [Fix] not test xpu amsgrad

---
 test/legacy_test/test_adam_op.py  | 96 +++++++++++++++++++++++++------
 test/legacy_test/test_adamw_op.py | 66 ++++++++++++++++++---
 2 files changed, 134 insertions(+), 28 deletions(-)

diff --git a/test/legacy_test/test_adam_op.py b/test/legacy_test/test_adam_op.py
index 4f92c9d6a09fe5..9a9bf8e211f1c4 100644
--- a/test/legacy_test/test_adam_op.py
+++ b/test/legacy_test/test_adam_op.py
@@ -126,8 +126,13 @@ def test_check_output(self):
 
 class TestAdamOp1AMSGrad(TestAdamOp1):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOp2(OpTest):
@@ -201,8 +206,13 @@ def set_shape(self):
 
 class TestAdamOp2AMSGrad(TestAdamOp2):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpMultipleSteps(OpTest):
@@ -288,8 +298,13 @@ def test_check_output(self):
 
 class TestAdamOpMultipleStepsAMSGrad(TestAdamOpMultipleSteps):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 def adam_step(inputs, attributes):
@@ -578,8 +593,13 @@ def test_sparse_adam(self):
 
 class TestSparseAdamOpAMSGrad(TestSparseAdamOp):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpBetaVariable(OpTest):
@@ -642,8 +662,13 @@ def test_check_output(self):
 
 class TestAdamOpBetaVariableAMSGrad(TestAdamOpBetaVariable):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpBetaEpsilonVariable(OpTest):
@@ -707,8 +732,13 @@ def test_check_output(self):
 
 class TestAdamOpBetaEpsilonVariableAMSGrad(TestAdamOpBetaEpsilonVariable):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpWithGlobalBetaPow(OpTest):
@@ -777,8 +807,13 @@ def test_check_output(self):
 
 class TestAdamOpWithGlobalBetaPowAMSGrad(TestAdamOpWithGlobalBetaPow):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpWithSkipUpdate(OpTest):
@@ -844,8 +879,13 @@ def test_check_output(self):
 
 class TestAdamOpWithSkipUpdateAMSGrad(TestAdamOpWithSkipUpdate):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpV2(unittest.TestCase):
@@ -1061,7 +1101,13 @@ def test_adam_with_old_ir(self):
 
 class TestAdamOpV2AMSGrad(TestAdamOpV2):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamOpV2WeightDecay(unittest.TestCase):
@@ -1111,7 +1157,13 @@ def test_adam_op(self):
 
 class TestAdamOpV2GroupAMSGrad(TestAdamOpV2Group):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestMultiTensorAdam(unittest.TestCase):
@@ -1343,7 +1395,13 @@ def test_pir_main(self):
 
 class TestMultiTensorAdamAMSGrad(TestMultiTensorAdam):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 if __name__ == "__main__":
diff --git a/test/legacy_test/test_adamw_op.py b/test/legacy_test/test_adamw_op.py
index 6af2be9c49a9cf..e953b9c795e335 100644
--- a/test/legacy_test/test_adamw_op.py
+++ b/test/legacy_test/test_adamw_op.py
@@ -181,8 +181,13 @@ def test_check_output(self):
 
 class TestAdamWAMSGrad(TestAdamW):
     def set_amsgrad(self):
-        self.amsgrad = True
-        self.no_check_set = None
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 @unittest.skipIf(
@@ -448,7 +453,13 @@ def test_adamw_op_invalid_input(self):
 
 class TestAdamWOpAMSGrad(TestAdamWOp):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamWOpGroup(TestAdamWOp):
@@ -504,7 +515,13 @@ def test_adamw_op_dygraph_bypassing_step(self):
 
 class TestAdamWOpGroupAMSGrad(TestAdamWOpGroup):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamWOpMultiPrecisionWithMainGrad(unittest.TestCase):
@@ -680,7 +697,13 @@ class TestAdamWOpMultiPrecisionWithMainGradAMSGrad(
     TestAdamWOpMultiPrecisionWithMainGrad
 ):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamWOpMultiPrecision(unittest.TestCase):
@@ -752,7 +775,13 @@ def test_main(self):
 
 class TestAdamWOpMultiPrecisionAMSGrad(TestAdamWOpMultiPrecision):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamWOpError(unittest.TestCase):
@@ -823,7 +852,13 @@ def test_grad_clip_dtype():
 
 class TestAdamWOpErrorAMSGrad(TestAdamWOpError):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 class TestAdamWOpGroupWithLR(TestAdamWOp):
@@ -862,7 +897,13 @@ def test_adamw_op_dygraph(self):
 
 class TestAdamWOpGroupWithLRAMSGrad(TestAdamWOpGroupWithLR):
     def setUp(self):
-        self.amsgrad = True
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 def simple_lr_setting(param, decay_rate, n_layers):
@@ -1671,16 +1712,35 @@ def setUp(self):
         random.seed(2022)
         np.random.seed(2022)
         paddle.seed(2022)
-        self.amsgrad = True
+
+        # xpu not support `amsgrad`
+        if core.is_compiled_with_xpu():
+            self.amsgrad = False
+            self.no_check_set = ['Moment2MaxOut']
+        else:
+            self.amsgrad = True
+            self.no_check_set = None
 
 
 if __name__ == "__main__":
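The AMSGrad test variants above compare op outputs (including Moment2MaxOut, skipped via no_check_set on XPU) against NumPy references such as adam_step() in test_adam_op.py. As a reading aid, here is a hedged sketch of how such a reference could fold in the AMSGrad branch; the function name and exact argument layout are illustrative and the epsilon scaling follows the convention used by the JIT test earlier in this section, not necessarily the shipped helper.

# Hedged NumPy sketch of an Adam step with an optional AMSGrad branch.
import numpy as np

def adam_step_amsgrad(param, grad, moment1, moment2, moment2_max,
                      lr, beta1, beta2, epsilon, beta1_pow, beta2_pow,
                      amsgrad=False):
    moment1_out = beta1 * moment1 + (1 - beta1) * grad
    moment2_out = beta2 * moment2 + (1 - beta2) * grad * grad
    if amsgrad:
        # keep the elementwise running max of the second moment
        moment2_max_out = np.maximum(moment2_out, moment2_max)
        denom = np.sqrt(moment2_max_out) + epsilon * np.sqrt(1 - beta2_pow)
    else:
        moment2_max_out = moment2_max
        denom = np.sqrt(moment2_out) + epsilon * np.sqrt(1 - beta2_pow)
    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
    param_out = param - lr_t * (moment1_out / denom)
    return param_out, moment1_out, moment2_out, moment2_max_out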