
Commit d774b83

[Hackathon 7th PPSCI No.12] Add amsgrad support to the Adam and AdamW optimizers - part (#68079)
* [init] amsgrad
* [update] refer.h
* [Add] amsgrad gpu
* [Add] amsgrad for adamw and fused
* [Fix] adamw gpu kernel
* [Update] fused adam kernel for gpu
* [Update] xpu adam/adamw param list
* [Update] tests for amsgrad
* [Fix] moment2 max out setting values without amsgrad
* [Update] unittest passed for adam and adamw
* [Update] unittest passed for merged and fused adam
* [Update] make moment2_max optional
* [Update] test_adamw_op.py with new test case
* [Update] adam adamw with amsgrad formula
* [Update] adam/adamw for test.cc
* [Fix] xpu param name
* [Fix] xpu param name & unittest
* [Fix] xpu param type
* [Fix] xpu unittest
* [Fix] xpu unittest
* [Fix] xpu unittest
* [Fix] merged_adam_ op_compat.yaml
* [Fix] remove UNUSED
* [Fix] remove UNUSED
* [Update] unittest adam op
* [Fix] op_compat.yaml
* [Update] assembly for adam adamw
* [Fix] adamw.cc for assembly jit gen
* [Update] adam with old ir test
* [Update] codestyle
* [Update] npu test rtol adamw
* [Update] xpu amsgrad raise errors
* [Fix] not test xpu amsgrad
1 parent 279fa69 commit d774b83
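
For context (not part of the diff): AMSGrad keeps an elementwise running maximum of the second-moment estimate (the role of the new Moment2Max / Moments2Max tensors) and uses that maximum in the denominator of the parameter update. The standard AMSGrad update, with bias correction omitted for brevity (for AdamW, decoupled weight decay is additionally applied to the parameter), is:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat v_t &= \max(\hat v_{t-1},\, v_t) \\
\theta_t &= \theta_{t-1} - \eta \, \frac{m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
$$

When amsgrad is false, the max step is skipped and $v_t$ is used directly, which is why Moment2Max is registered as a dispensable (optional) input and output throughout this commit.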


55 files changed, +2669 −820 lines

paddle/fluid/operators/fused/fused_adam_op.cc

Lines changed: 11 additions & 0 deletions
@@ -57,6 +57,9 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("LearningRate", "(Tensor, default Tensor<float>) Learning rate");
     AddInput("Moments1", "(Tensor) Input first moments").AsDuplicable();
     AddInput("Moments2", "(Tensor) Input second moments").AsDuplicable();
+    AddInput("Moments2Max", "(Tensor) Input second moments max for amsgrad")
+        .AsDispensable()
+        .AsDuplicable();
     AddInput("Beta1Pows",
              "(Tensor, default Tensor<float>) Input beta1 power accumulator")
         .AsDuplicable();
@@ -72,6 +75,10 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("ParamsOut", "(Tensor) Output parameters").AsDuplicable();
     AddOutput("Moments1Out", "(Tensor) Output first moments").AsDuplicable();
     AddOutput("Moments2Out", "(Tensor) Output second moments").AsDuplicable();
+    AddOutput("Moments2MaxOut",
+              "(Tensor) Output second moments max for amsgrad")
+        .AsDispensable()
+        .AsDuplicable();
     AddOutput("Beta1PowsOut", "(Tensor) Output beta1 power accumulator")
         .AsDuplicable();
     AddOutput("Beta2PowsOut", "(Tensor) Output beta2 power accumulator")
@@ -122,6 +129,10 @@ class FusedAdamOpMaker : public framework::OpProtoAndCheckerMaker {
              "Whether to use global beta_pow for whole model instead of "
              "creating beta_pow for each parameter.")
         .SetDefault(false);
+    AddAttr<bool>("amsgrad",
+                  "(bool, default false) "
+                  "Whether to use the AMSGrad of this algorithm.")
+        .SetDefault(false);

     AddComment(R"DOC(
 Adam Optimizer.
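
As a hedged illustration of how a kernel consumes the new dispensable Moments2Max buffers, here is a minimal plain-C++ sketch. It is not the actual Paddle kernel; the function name, parameter list, and the omission of bias correction and weight decay are assumptions made for brevity.

#include <algorithm>
#include <cmath>

// Hypothetical elementwise Adam/AMSGrad step. mom2_max mirrors the
// Moments2Max accumulator registered above; it is only touched when
// amsgrad is enabled, otherwise the raw second moment is used.
void adam_step(float* param, const float* grad, float* mom1, float* mom2,
               float* mom2_max, float lr, float beta1, float beta2,
               float eps, bool amsgrad, int n) {
  for (int i = 0; i < n; ++i) {
    mom1[i] = beta1 * mom1[i] + (1.f - beta1) * grad[i];
    mom2[i] = beta2 * mom2[i] + (1.f - beta2) * grad[i] * grad[i];
    float denom;
    if (amsgrad) {
      // Keep the running maximum of the second moment and use it in the
      // denominator, so the effective step size never grows over time.
      mom2_max[i] = std::max(mom2_max[i], mom2[i]);
      denom = std::sqrt(mom2_max[i]) + eps;
    } else {
      denom = std::sqrt(mom2[i]) + eps;
    }
    param[i] -= lr * mom1[i] / denom;
  }
}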

paddle/fluid/operators/ops_signature/adam_sig.cc

Lines changed: 3 additions & 0 deletions
@@ -24,13 +24,15 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
                                                 "LearningRate",
                                                 "Moment1",
                                                 "Moment2",
+                                                "Moment2Max",
                                                 "Beta1Pow",
                                                 "Beta2Pow",
                                                 "MasterParam",
                                                 "SkipUpdate"};
   paddle::small_vector<const char*> out_names = {"ParamOut",
                                                  "Moment1Out",
                                                  "Moment2Out",
+                                                 "Moment2MaxOut",
                                                  "Beta1PowOut",
                                                  "Beta2PowOut",
                                                  "MasterParamOut"};
@@ -46,6 +48,7 @@ KernelSignature AdamOpArgumentMapping(const ArgumentMappingContext& ctx) {
   attr_names.emplace_back("min_row_size_to_use_multithread");
   attr_names.emplace_back("multi_precision");
   attr_names.emplace_back("use_global_beta_pow");
+  attr_names.emplace_back("amsgrad");

   if (ctx.IsSelectedRowsInput("Grad")) {
     return KernelSignature("adam_dense_param_sparse_grad",
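
Given the extended input/attr/output name lists above, the dense-gradient branch of this mapping would return a signature assembled from the same vectors. A sketch of that return, consistent with the sparse-row branch visible in the hunk but not copied verbatim from the file:

  // Dense path: map the fluid "adam" op onto the phi "adam" kernel with the
  // extended name lists (now including Moment2Max / Moment2MaxOut / amsgrad).
  return KernelSignature(
      "adam", std::move(in_names), std::move(attr_names), std::move(out_names));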

paddle/fluid/operators/ops_signature/fused_adam_sig.cc

Lines changed: 4 additions & 1 deletion
@@ -25,13 +25,15 @@ KernelSignature FusedAdamOpArgumentMapping(
                                                 "LearningRate",
                                                 "Moments1",
                                                 "Moments2",
+                                                "Moments2Max",
                                                 "Beta1Pows",
                                                 "Beta2Pows",
                                                 "MasterParams",
                                                 "SkipUpdate"};
   paddle::small_vector<const char*> out_names = {"ParamsOut",
                                                  "Moments1Out",
                                                  "Moments2Out",
+                                                 "Moments2MaxOut",
                                                  "Beta1PowsOut",
                                                  "Beta2PowsOut",
                                                  "MasterParamsOut"};
@@ -42,7 +44,8 @@ KernelSignature FusedAdamOpArgumentMapping(
                                                  "weight_decay",
                                                  "use_adamw",
                                                  "multi_precision",
-                                                 "use_global_beta_pow"};
+                                                 "use_global_beta_pow",
+                                                 "amsgrad"};

   return KernelSignature("fused_adam",
                          std::move(in_names),

paddle/fluid/pybind/eager_generator.cc

Lines changed: 12 additions & 0 deletions
@@ -3344,27 +3344,31 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      {"ParamOut",
       "Moment1Out",
       "Moment2Out",
+      "Moment2MaxOut",
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},
     {"merged_adam",
      {"ParamOut",
       "Moment1Out",
       "Moment2Out",
+      "Moment2MaxOut",
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},
     {"fused_adam",
      {"ParamsOut",
       "Moments1Out",
       "Moments2Out",
+      "Moments2MaxOut",
       "Beta1PowsOut",
       "Beta2PowsOut",
       "MasterParamsOut"}},
     {"adamw",
      {"ParamOut",
       "Moment1Out",
       "Moment2Out",
+      "Moment2MaxOut",
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},
@@ -3544,6 +3548,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
       "LearningRate",
       "Moment1",
       "Moment2",
+      "Moment2Max",
       "Beta1Pow",
       "Beta2Pow",
       "MasterParam"}},
@@ -3553,6 +3558,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
       "LearningRate",
       "Moment1",
       "Moment2",
+      "Moment2Max",
       "Beta1Pow",
       "Beta2Pow",
       "MasterParam"}},
@@ -3562,6 +3568,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
       "LearningRate",
       "Moments1",
       "Moments2",
+      "Moments2Max",
       "Beta1Pows",
       "Beta2Pows",
       "MasterParams",
@@ -3572,6 +3579,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
       "LearningRate",
       "Moment1",
       "Moment2",
+      "Moment2Max",
       "Beta1Pow",
       "Beta2Pow",
       "MasterParam"}},
@@ -3723,27 +3731,31 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
      {"ParamOut",
       "Moment1Out",
       "Moment2Out",
+      "Moment2MaxOut",
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},
     {"merged_adam",
      {"ParamOut",
       "Moment1Out",
       "Moment2Out",
+      "Moment2MaxOut",
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},
     {"fused_adam",
      {"ParamsOut",
       "Moments1Out",
       "Moments2Out",
+      "Moments2MaxOut",
       "Beta1PowsOut",
       "Beta2PowsOut",
       "MasterParamsOut"}},
     {"adamw",
      {"ParamOut",
       "Moment1Out",
       "Moment2Out",
+      "Moment2MaxOut",
       "Beta1PowOut",
       "Beta2PowOut",
       "MasterParamOut"}},

paddle/phi/infermeta/multiary.cc

Lines changed: 23 additions & 0 deletions
@@ -152,6 +152,7 @@ void AdamInferMeta(const MetaTensor& param,
                    const MetaTensor& learning_rate,
                    const MetaTensor& moment1,
                    const MetaTensor& moment2,
+                   const MetaTensor& moment2_max,
                    const MetaTensor& beta1_pow,
                    const MetaTensor& beta2_pow,
                    const MetaTensor& master_param,
@@ -163,9 +164,11 @@ void AdamInferMeta(const MetaTensor& param,
                    int64_t min_row_size_to_use_multithread,
                    bool multi_precision,
                    bool use_global_beta_pow,
+                   bool amsgrad,
                    MetaTensor* param_out,
                    MetaTensor* moment1_out,
                    MetaTensor* moment2_out,
+                   MetaTensor* moment2_max_out,
                    MetaTensor* beta1_pow_out,
                    MetaTensor* beta2_pow_out,
                    MetaTensor* master_param_outs) {
@@ -232,6 +235,10 @@ void AdamInferMeta(const MetaTensor& param,
   moment1_out->set_dtype(moment1.dtype());
   moment2_out->set_dims(param_dims);
   moment2_out->set_dtype(moment2.dtype());
+  if (amsgrad) {
+    moment2_max_out->set_dims(param_dims);
+    moment2_max_out->set_dtype(moment2.dtype());
+  }

   beta1_pow_out->set_dims(beta1_pow_dims);
   beta1_pow_out->set_dtype(beta1_pow.dtype());
@@ -328,6 +335,7 @@ void AdamwInferMeta(const MetaTensor& param,
                     const MetaTensor& learning_rate,
                     const MetaTensor& moment1,
                     const MetaTensor& moment2,
+                    const MetaTensor& moment2_max,
                     const MetaTensor& beta1_pow,
                     const MetaTensor& beta2_pow,
                     const MetaTensor& master_param,
@@ -342,9 +350,11 @@ void AdamwInferMeta(const MetaTensor& param,
                     int64_t min_row_size_to_use_multithread,
                     bool multi_precision,
                     bool use_global_beta_pow,
+                    bool amsgrad,
                     MetaTensor* param_out,
                     MetaTensor* moment1_out,
                     MetaTensor* moment2_out,
+                    MetaTensor* moment2_max_out,
                     MetaTensor* beta1_pow_out,
                     MetaTensor* beta2_pow_out,
                     MetaTensor* master_param_outs) {
@@ -353,6 +363,7 @@ void AdamwInferMeta(const MetaTensor& param,
                 learning_rate,
                 moment1,
                 moment2,
+                moment2_max,
                 beta1_pow,
                 beta2_pow,
                 master_param,
@@ -364,9 +375,11 @@ void AdamwInferMeta(const MetaTensor& param,
                 min_row_size_to_use_multithread,
                 multi_precision,
                 use_global_beta_pow,
+                amsgrad,
                 param_out,
                 moment1_out,
                 moment2_out,
+                moment2_max_out,
                 beta1_pow_out,
                 beta2_pow_out,
                 master_param_outs);
@@ -3866,6 +3879,7 @@ void MergedAdamInferMeta(
     const std::vector<const MetaTensor*>& learning_rate,
     const std::vector<const MetaTensor*>& moment1,
     const std::vector<const MetaTensor*>& moment2,
+    const paddle::optional<std::vector<const MetaTensor*>>& moment2_max,
     const std::vector<const MetaTensor*>& beta1_pow,
     const std::vector<const MetaTensor*>& beta2_pow,
     const paddle::optional<std::vector<const MetaTensor*>>& master_param,
@@ -3874,9 +3888,11 @@ void MergedAdamInferMeta(
     const Scalar& epsilon,
     bool multi_precision,
     bool use_global_beta_pow,
+    bool amsgrad,
     std::vector<MetaTensor*> param_out,
     std::vector<MetaTensor*> moment1_out,
     std::vector<MetaTensor*> moment2_out,
+    std::vector<MetaTensor*> moment2_max_out,
     std::vector<MetaTensor*> beta1_pow_out,
     std::vector<MetaTensor*> beta2_pow_out,
     std::vector<MetaTensor*> master_param_out) {}
@@ -5796,6 +5812,7 @@ void FusedAdamInferMeta(
     const MetaTensor& learning_rate,
     const std::vector<const MetaTensor*>& moments1,
     const std::vector<const MetaTensor*>& moments2,
+    const paddle::optional<std::vector<const MetaTensor*>>& moments2_max,
     const std::vector<const MetaTensor*>& beta1_pows,
     const std::vector<const MetaTensor*>& beta2_pows,
     const paddle::optional<std::vector<const MetaTensor*>>& master_params,
@@ -5808,9 +5825,11 @@ void FusedAdamInferMeta(
     bool use_adamw,
     bool multi_precision,
     bool use_global_beta_pow,
+    bool amsgrad,
     std::vector<MetaTensor*> params_out,
     std::vector<MetaTensor*> moments1_out,
     std::vector<MetaTensor*> moments2_out,
+    std::vector<MetaTensor*> moments2_max_out,
     std::vector<MetaTensor*> beta1_pows_out,
     std::vector<MetaTensor*> beta2_pows_out,
     std::vector<MetaTensor*> master_params_out) {
@@ -5822,6 +5841,10 @@ void FusedAdamInferMeta(
     moments1_out[i]->set_dtype(moments1[i]->dtype());
     moments2_out[i]->set_dims(moments2[i]->dims());
     moments2_out[i]->set_dtype(moments2[i]->dtype());
+    if (amsgrad) {
+      moments2_max_out[i]->set_dims(moments2_max.get()[i]->dims());
+      moments2_max_out[i]->set_dtype(moments2_max.get()[i]->dtype());
+    }
     beta1_pows_out[i]->set_dims(beta1_pows[i]->dims());
     beta1_pows_out[i]->set_dtype(beta1_pows[i]->dtype());
     beta2_pows_out[i]->set_dims(beta2_pows[i]->dims());

paddle/phi/infermeta/multiary.h

Lines changed: 12 additions & 0 deletions
@@ -86,6 +86,7 @@ void AdamInferMeta(const MetaTensor& param,
                    const MetaTensor& learning_rate,
                    const MetaTensor& moment1,
                    const MetaTensor& moment2,
+                   const MetaTensor& moment2_max,
                    const MetaTensor& beta1_pow,
                    const MetaTensor& beta2_pow,
                    const MetaTensor& master_param,
@@ -97,9 +98,11 @@ void AdamInferMeta(const MetaTensor& param,
                    int64_t min_row_size_to_use_multithread,
                    bool multi_precision,
                    bool use_global_beta_pow,
+                   bool amsgrad,
                    MetaTensor* param_out,
                    MetaTensor* moment1_out,
                    MetaTensor* moment2_out,
+                   MetaTensor* moment2_max_out,
                    MetaTensor* beta1_pow_out,
                    MetaTensor* beta2_pow_out,
                    MetaTensor* master_param_outs);
@@ -109,6 +112,7 @@ void AdamwInferMeta(const MetaTensor& param,
                     const MetaTensor& learning_rate,
                     const MetaTensor& moment1,
                     const MetaTensor& moment2,
+                    const MetaTensor& moment2_max,
                     const MetaTensor& beta1_pow,
                     const MetaTensor& beta2_pow,
                     const MetaTensor& master_param,
@@ -123,9 +127,11 @@ void AdamwInferMeta(const MetaTensor& param,
                     int64_t min_row_size_to_use_multithread,
                     bool multi_precision,
                     bool use_global_beta_pow,
+                    bool amsgrad,
                     MetaTensor* param_out,
                     MetaTensor* moment1_out,
                     MetaTensor* moment2_out,
+                    MetaTensor* moment2_max_out,
                     MetaTensor* beta1_pow_out,
                     MetaTensor* beta2_pow_out,
                     MetaTensor* master_param_outs);
@@ -711,6 +717,7 @@ void MergedAdamInferMeta(
     const std::vector<const MetaTensor*>& learning_rate,
     const std::vector<const MetaTensor*>& moment1,
     const std::vector<const MetaTensor*>& moment2,
+    const paddle::optional<std::vector<const MetaTensor*>>& moment2_max,
     const std::vector<const MetaTensor*>& beta1_pow,
     const std::vector<const MetaTensor*>& beta2_pow,
     const paddle::optional<std::vector<const MetaTensor*>>& master_param,
@@ -719,9 +726,11 @@ void MergedAdamInferMeta(
     const Scalar& epsilon,
     bool multi_precision,
     bool use_global_beta_pow,
+    bool amsgrad,
     std::vector<MetaTensor*> param_out,
     std::vector<MetaTensor*> moment1_out,
     std::vector<MetaTensor*> moment2_out,
+    std::vector<MetaTensor*> moment2_max_out,
     std::vector<MetaTensor*> beta1_pow_out,
     std::vector<MetaTensor*> beta2_pow_out,
     std::vector<MetaTensor*> master_param_out);
@@ -1117,6 +1126,7 @@ void FusedAdamInferMeta(
     const MetaTensor& learning_rate,
     const std::vector<const MetaTensor*>& moments1,
     const std::vector<const MetaTensor*>& moments2,
+    const paddle::optional<std::vector<const MetaTensor*>>& moments2_max,
     const std::vector<const MetaTensor*>& beta1_pows,
     const std::vector<const MetaTensor*>& beta2_pows,
     const paddle::optional<std::vector<const MetaTensor*>>& master_params,
@@ -1129,9 +1139,11 @@ void FusedAdamInferMeta(
     bool use_adamw,
     bool multi_precision,
     bool use_global_beta_pow,
+    bool amsgrad,
     std::vector<MetaTensor*> params_out,
     std::vector<MetaTensor*> moments1_out,
     std::vector<MetaTensor*> moments2_out,
+    std::vector<MetaTensor*> moments2_max_out,
     std::vector<MetaTensor*> beta1_pows_out,
     std::vector<MetaTensor*> beta2_pows_out,
     std::vector<MetaTensor*> master_params_out);
