Commit 72a5c32

kshitij12345 authored and gyshi committed
Fix gradient tensor mutation in {adam/ftrl/rmsprop/rmspropalex}_update. (apache#15768)
* update code to fix apache#15759
* add relevant test
* re-add the removed conditional dispatch
* fix grad mutate for ftrl_update
* add test for ftrl_update
* fix grad mutate for rmspropalex_update
* add test for rmspropalex_update
* use KERNEL_ASSIGN in RMSPropAlexUpdateKernel.
* fix grad mutate for rmsprop_update
* add test for rmsprop_update
* add more optimizers for mutation test
* retrigger CI
* retrigger CI
* retrigger CI
* retrigger CI
* address comments.
* refactor code.
* retrigger CI
* retrigger CI
* retrigger CI
1 parent da196b9 commit 72a5c32
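
The underlying problem (apache#15759) is that the old expression-template implementations wrote the rescaled gradient back into the gradient tensor itself (grad = scalar<DType>(param.rescale_grad) * grad + ...), so calling the optimizer silently modified the user-visible gradient buffer. The new per-element kernels in this diff take grad_data as a const pointer and keep the rescaled value in a local variable instead. A minimal standalone sketch of the two patterns, using plain std::vector and an SGD-style step rather than MXNet's mshadow/Kernel machinery (all names below are illustrative, not MXNet APIs):

#include <cstdio>
#include <vector>

// Buggy pattern: the rescaled gradient is written back into the caller's buffer.
void sgd_step_mutating(std::vector<float>* weight, std::vector<float>* grad,
                       float rescale_grad, float wd, float lr) {
  for (size_t i = 0; i < weight->size(); ++i) {
    (*grad)[i] = (*grad)[i] * rescale_grad + wd * (*weight)[i];  // clobbers grad
    (*weight)[i] -= lr * (*grad)[i];
  }
}

// Fixed pattern: grad is read-only; the rescaled value lives in a per-element local.
void sgd_step_const_grad(std::vector<float>* weight, const std::vector<float>& grad,
                         float rescale_grad, float wd, float lr) {
  for (size_t i = 0; i < weight->size(); ++i) {
    const float grad_rescaled = grad[i] * rescale_grad + wd * (*weight)[i];
    (*weight)[i] -= lr * grad_rescaled;
  }
}

int main() {
  std::vector<float> w{1.f, 2.f}, g{0.5f, 0.25f};
  const std::vector<float> g_before = g;
  sgd_step_mutating(&w, &g, 2.f, 0.1f, 0.01f);
  std::printf("mutating version changed grad: %s\n", g == g_before ? "no" : "yes");   // yes
  w = {1.f, 2.f};
  g = g_before;
  sgd_step_const_grad(&w, g, 2.f, 0.1f, 0.01f);
  std::printf("const-grad version changed grad: %s\n", g == g_before ? "no" : "yes"); // no
  return 0;
}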

File tree

2 files changed: +216 −130 lines


src/operator/optimizer_op-inl.h

Lines changed: 150 additions & 129 deletions
@@ -1293,15 +1293,37 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
  }
};

+struct AdamUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta1, const DType beta2,
+    const DType lr, const DType wd,
+    const DType epsilon, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] +
+                  (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
+                  (square_root::Map(var_data[i]) + epsilon));
+  }
+};
+
template<typename xpu>
inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
  const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
@@ -1311,22 +1333,12 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

-    grad = scalar<DType>(param.rescale_grad) * grad +
-           scalar<DType>(param.wd) * weight;
-
-    if (param.clip_gradient >= 0.0f) {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
-             F<clip>(grad, DType(param.clip_gradient));
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
-            F<clip>(grad, DType(param.clip_gradient)));
-    } else {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
-    }
-    Assign(out, req[0],
-           weight -
-           scalar<DType>(param.lr) * mean /
-           (F<square_root>(var) + scalar<DType>(param.epsilon)));
+    Kernel<AdamUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.epsilon), req[0]);
  });
}
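
Note that every kernel writes its result through KERNEL_ASSIGN rather than a plain assignment; that is how the caller's OpReqType (req[0]) is honored: a write request overwrites the output element, an add-to request accumulates into it, and a null request leaves it alone. A rough standalone sketch of that dispatch, with a stand-in enum instead of MXNet's actual macro (illustrative only):

#include <cassert>

// Stand-in for mxnet::OpReqType; the names mirror the real enum, but this is a sketch.
enum OpReq { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Rough equivalent of what KERNEL_ASSIGN(out, req, val) does for one element.
template <typename DType>
inline void kernel_assign(DType* out, OpReq req, DType val) {
  switch (req) {
    case kNullOp:                     break;   // nothing requested
    case kWriteTo:
    case kWriteInplace: *out  = val;  break;   // overwrite the output
    case kAddTo:        *out += val;  break;   // accumulate into the output
  }
}

int main() {
  float out = 1.f;
  kernel_assign(&out, kWriteTo, 2.f);  assert(out == 2.f);
  kernel_assign(&out, kAddTo,   3.f);  assert(out == 5.f);
  kernel_assign(&out, kNullOp,  9.f);  assert(out == 5.f);
  return 0;
}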

@@ -1596,57 +1608,64 @@ struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> {
  }
};

+struct RMSPropAlexUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* state_n_data, DType* state_g_data, DType* delta_data,
+    const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType gamma1, const DType gamma2,
+    const DType lr, const DType wd,
+    const DType clip_weights, const DType epsilon,
+    const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    state_n_data[i] = (1.f - gamma1) * grad_rescaled * grad_rescaled +
+                      gamma1 * state_n_data[i];
+    state_g_data[i] = (1.f - gamma1) * grad_rescaled +
+                      gamma1 * state_g_data[i];
+    delta_data[i] = gamma2 * delta_data[i] -
+                    (lr * (grad_rescaled) /
+                     (square_root::Map(state_n_data[i] -
+                                       state_g_data[i] * state_g_data[i] + epsilon)));
+
+    if (clip_weights >= 0.0f) {
+      const DType clipped_weight = clip::Map(weight_data[i] + delta_data[i], clip_weights);
+      KERNEL_ASSIGN(out_data[i], req, clipped_weight);
+    } else {
+      KERNEL_ASSIGN(out_data[i], req, weight_data[i] + delta_data[i]);
+    }
+  }
+};
+
template <typename xpu>
inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
                              const OpContext &ctx,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
                              const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
  const RMSPropAlexParam &param = nnvm::get<RMSPropAlexParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> state_g = inputs[3].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> delta = inputs[4].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
-
-    grad = scalar<DType>(param.rescale_grad) * grad +
-           scalar<DType>(param.wd) * weight;
-
-    if (param.clip_gradient >= 0.0f) {
-      state_n = scalar<DType>(1.f - param.gamma1) *
-                F<clip>(grad, DType(param.clip_gradient)) *
-                F<clip>(grad, DType(param.clip_gradient)) +
-                scalar<DType>(param.gamma1) * state_n;
-      state_g = scalar<DType>(1.f - param.gamma1) *
-                F<clip>(grad, DType(param.clip_gradient)) +
-                scalar<DType>(param.gamma1) * state_g;
-      delta = scalar<DType>(param.gamma2) * delta -
-              scalar<DType>(param.lr) *
-              (F<clip>(grad, DType(param.clip_gradient)) /
-               (F<square_root>(state_n - state_g * state_g +
-                               scalar<DType>(param.epsilon))));
-    } else {
-      state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
-                scalar<DType>(param.gamma1) * state_n;
-      state_g = scalar<DType>(1.f - param.gamma1) * grad +
-                scalar<DType>(param.gamma1) * state_g;
-      delta = scalar<DType>(param.gamma2) * delta -
-              scalar<DType>(param.lr) *
-              (grad / (F<square_root>(state_n - state_g * state_g +
-                                      scalar<DType>(param.epsilon))));
-    }
+    DType* weight_data = inputs[0].dptr<DType>();
+    DType* grad_data = inputs[1].dptr<DType>();
+    DType* state_n_data = inputs[2].dptr<DType>();
+    DType* state_g_data = inputs[3].dptr<DType>();
+    DType* delta_data = inputs[4].dptr<DType>();
+    DType* out_data = outputs[0].dptr<DType>();

-    if (param.clip_weights >= 0.0f) {
-      Assign(out, req[0], F<clip>(weight + delta, DType(param.clip_weights)));
-    } else {
-      Assign(out, req[0], weight + delta);
-    }
+    Kernel<RMSPropAlexUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
+      out_data, state_n_data, state_g_data, delta_data, weight_data, grad_data,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.gamma1), static_cast<DType>(param.gamma2),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
  });
}

@@ -1688,64 +1707,52 @@ struct RMSPropParam : public dmlc::Parameter<RMSPropParam> {
  }
};

+struct RMSPropUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+    DType* out_data, DType* state_n_data,
+    const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType gamma1, const DType lr, const DType wd,
+    const DType clip_weights, const DType epsilon,
+    const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    state_n_data[i] = (1.f - gamma1) * (grad_rescaled * grad_rescaled) + gamma1 * state_n_data[i];
+
+    DType weight = weight_data[i] -
+                   lr * (grad_rescaled / square_root::Map(state_n_data[i] + epsilon));
+    if (clip_weights >= 0.0f) {
+      weight = clip::Map(weight, clip_weights);
+    }
+    KERNEL_ASSIGN(out_data[i], req, weight);
+  }
+};
+
template <typename xpu>
inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                          const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
  const RMSPropParam &param = nnvm::get<RMSPropParam>(attrs.parsed);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    DType* weight_data = inputs[0].dptr<DType>();
+    DType* grad_data = inputs[1].dptr<DType>();
+    DType* state_n_data = inputs[2].dptr<DType>();
+    DType* out_data = outputs[0].dptr<DType>();

-    grad = scalar<DType>(param.rescale_grad) * grad +
-           scalar<DType>(param.wd) * weight;
-
-    if (param.clip_gradient >= 0.0f) {
-      state_n = scalar<DType>(1.f - param.gamma1) *
-                F<clip>(grad, DType(param.clip_gradient)) *
-                F<clip>(grad, DType(param.clip_gradient)) +
-                scalar<DType>(param.gamma1) * state_n;
-      if (param.clip_weights >= 0.0f) {
-        Assign(out, req[0],
-               F<clip>(weight -
-                       scalar<DType>(param.lr) *
-                       (F<clip>(grad, DType(param.clip_gradient)) /
-                        (F<square_root>(state_n +
-                                        scalar<DType>(param.epsilon)))),
-               DType(param.clip_weights)));
-      } else {
-        Assign(out, req[0], weight -
-               scalar<DType>(param.lr) *
-               (F<clip>(grad, DType(param.clip_gradient)) /
-                (F<square_root>(state_n +
-                                scalar<DType>(param.epsilon)))));
-      }
-    } else {
-      state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
-                scalar<DType>(param.gamma1) * state_n;
-      if (param.clip_weights >= 0.0f) {
-        Assign(out, req[0],
-               F<clip>(weight -
-                       scalar<DType>(param.lr) *
-                       (grad /
-                        (F<square_root>(state_n +
-                                        scalar<DType>(param.epsilon)))),
-               DType(param.clip_weights)));
-      } else {
-        Assign(out, req[0], weight -
-               scalar<DType>(param.lr) *
-               (grad /
-                (F<square_root>(state_n +
-                                scalar<DType>(param.epsilon)))));
-      }
-    }
+    Kernel<RMSPropUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
+      out_data, state_n_data, weight_data, grad_data,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.gamma1), static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
  });
}

@@ -1781,15 +1788,41 @@ struct FtrlParam : public dmlc::Parameter<FtrlParam> {
  }
};

+struct FtrlUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta, const DType lamda1,
+    const DType lr, const DType wd,
+    const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    z_data[i] += grad_rescaled - (square_root::Map(n_data[i] +
+                 square::Map(grad_rescaled)) - square_root::Map(n_data[i])) *
+                 weight_data[i] / lr;
+    n_data[i] += square::Map(grad_rescaled);
+
+    KERNEL_ASSIGN(out_data[i], req,
+                  (sign::Map(z_data[i]) * lamda1 - z_data[i]) /
+                  ((beta + square_root::Map(n_data[i])) / lr + wd) *
+                  gt::Map(abs::Map(z_data[i]), lamda1));
+  }
+};
+
template<typename xpu>
inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
+
  const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
  Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
@@ -1799,23 +1832,11 @@ inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
    Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

-    grad = scalar<DType>(param.rescale_grad) * grad;
-
-    if (param.clip_gradient >= 0.0f) {
-      z += F<clip>(grad, DType(param.clip_gradient)) - (F<square_root>(n +
-           F<square>(F<clip>(grad, DType(param.clip_gradient)))) - F<square_root>(n)) *
-           weight / scalar<DType>(param.lr);
-      n += F<square>(F<clip>(grad, DType(param.clip_gradient)));
-    } else {
-      z += grad - (F<square_root>(n + F<square>(grad)) - F<square_root>(n)) *
-           weight / scalar<DType>(param.lr);
-      n += F<square>(grad);
-    }
-    Assign(out, req[0],
-           (F<sign>(z) * scalar<DType>(param.lamda1) - z) /
-           ((scalar<DType>(param.beta) + F<square_root>(n)) /
-           scalar<DType>(param.lr) + scalar<DType>(param.wd)) *
-           F<gt>(F<abs>(z), scalar<DType>(param.lamda1)));
+    Kernel<FtrlUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, n.dptr_, z.dptr_, weight.dptr_, grad.dptr_,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.beta), static_cast<DType>(param.lamda1),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd), req[0]);
  });
}
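
The second changed file (the tests added by this commit) is not shown in this view; per the commit message it checks that the optimizer updates no longer modify their gradient input. A standalone sketch of that kind of mutation check, with plain std::vector standing in for NDArray and a hypothetical adam_step that follows the AdamUpdateKernel above (not MXNet's actual API):

#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical single-tensor Adam step mirroring the kernel above; grad is const.
void adam_step(std::vector<float>* w, std::vector<float>* mean, std::vector<float>* var,
               const std::vector<float>& grad, float lr, float beta1, float beta2,
               float eps, float wd, float rescale_grad) {
  for (size_t i = 0; i < w->size(); ++i) {
    const float g = grad[i] * rescale_grad + wd * (*w)[i];
    (*mean)[i] = beta1 * (*mean)[i] + (1.f - beta1) * g;
    (*var)[i]  = beta2 * (*var)[i]  + (1.f - beta2) * g * g;
    (*w)[i]   -= lr * (*mean)[i] / (std::sqrt((*var)[i]) + eps);
  }
}

int main() {
  std::vector<float> w{0.5f, -1.5f}, mean(2, 0.f), var(2, 0.f);
  const std::vector<float> grad{0.3f, -0.2f};
  const std::vector<float> grad_before = grad;   // snapshot before the update
  adam_step(&w, &mean, &var, grad, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.f, 1.f);
  assert(grad == grad_before);                   // gradient must be untouched
  return 0;
}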

0 commit comments