@@ -1293,15 +1293,37 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
   }
 };
 
+struct AdamUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta1, const DType beta2,
+    const DType lr, const DType wd,
+    const DType epsilon, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] +
+                  (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
+                  (square_root::Map(var_data[i]) + epsilon));
+  }
+};
+
 template<typename xpu>
 inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
@@ -1311,22 +1333,12 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-    grad = scalar<DType>(param.rescale_grad) * grad +
-           scalar<DType>(param.wd) * weight;
-
-    if (param.clip_gradient >= 0.0f) {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
-             F<clip>(grad, DType(param.clip_gradient));
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2)*F<square>(
-            F<clip>(grad, DType(param.clip_gradient)));
-    } else {
-      mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
-      var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
-    }
-    Assign(out, req[0],
-           weight -
-           scalar<DType>(param.lr) * mean /
-           (F<square_root>(var) + scalar<DType>(param.epsilon)));
+    Kernel<AdamUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.epsilon), req[0]);
   });
 }
 
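The pattern used throughout this diff: each optimizer's element-wise math now lives in a kernel's Map body and is dispatched through mxnet_op::Kernel<OP, xpu>::Launch, so a single implementation backs both CPU and GPU builds. As a rough mental model, Launch amounts to calling Map once per flattened element; the sketch below is a simplified stand-in (LaunchSketch is a hypothetical helper for illustration only), since the real mxnet_op::Kernel also handles OpenMP threading on CPU and thread-index mapping on GPU.

// Simplified, single-threaded sketch of the dispatch pattern assumed above.
// LaunchSketch is illustrative only; it is not part of mxnet_op.
#include <cstddef>

template <typename OP>
struct LaunchSketch {
  template <typename... Args>
  static void Run(std::size_t n, Args... args) {
    for (std::size_t i = 0; i < n; ++i) {
      OP::Map(static_cast<int>(i), args...);  // one scalar update per element
    }
  }
};

Written out, the per-element rule AdamUpdateKernel::Map applies (gradient clipping only when clip_gradient >= 0) is

\begin{aligned}
g &\leftarrow \operatorname{clip}(\mathrm{rescale\_grad}\cdot g + \mathrm{wd}\cdot w,\ \mathrm{clip\_gradient}) \\
m &\leftarrow \beta_1 m + (1-\beta_1)\, g \\
v &\leftarrow \beta_2 v + (1-\beta_2)\, g^2 \\
w &\leftarrow w - \mathrm{lr}\cdot m \,/\, (\sqrt{v} + \epsilon)
\end{aligned}
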
@@ -1596,57 +1608,64 @@ struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> {
   }
 };
 
+struct RMSPropAlexUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* state_n_data, DType* state_g_data, DType* delta_data,
+    const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType gamma1, const DType gamma2,
+    const DType lr, const DType wd,
+    const DType clip_weights, const DType epsilon,
+    const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    state_n_data[i] = (1.f - gamma1) * grad_rescaled * grad_rescaled +
+                      gamma1 * state_n_data[i];
+    state_g_data[i] = (1.f - gamma1) * grad_rescaled +
+                      gamma1 * state_g_data[i];
+    delta_data[i] = gamma2 * delta_data[i] -
+                    (lr * (grad_rescaled) /
+                     (square_root::Map(state_n_data[i] -
+                                       state_g_data[i] * state_g_data[i] + epsilon)));
+
+    if (clip_weights >= 0.0f) {
+      const DType clipped_weight = clip::Map(weight_data[i] + delta_data[i], clip_weights);
+      KERNEL_ASSIGN(out_data[i], req, clipped_weight);
+    } else {
+      KERNEL_ASSIGN(out_data[i], req, weight_data[i] + delta_data[i]);
+    }
+  }
+};
+
 template<typename xpu>
 inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
   const RMSPropAlexParam &param = nnvm::get<RMSPropAlexParam>(attrs.parsed);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> state_g = inputs[3].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> delta = inputs[4].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
-
-    grad = scalar<DType>(param.rescale_grad) * grad +
-           scalar<DType>(param.wd) * weight;
-
-    if (param.clip_gradient >= 0.0f) {
-      state_n = scalar<DType>(1.f - param.gamma1) *
-                F<clip>(grad, DType(param.clip_gradient)) *
-                F<clip>(grad, DType(param.clip_gradient)) +
-                scalar<DType>(param.gamma1) * state_n;
-      state_g = scalar<DType>(1.f - param.gamma1) *
-                F<clip>(grad, DType(param.clip_gradient)) +
-                scalar<DType>(param.gamma1) * state_g;
-      delta = scalar<DType>(param.gamma2) * delta -
-              scalar<DType>(param.lr) *
-              (F<clip>(grad, DType(param.clip_gradient)) /
-               (F<square_root>(state_n - state_g * state_g +
-                               scalar<DType>(param.epsilon))));
-    } else {
-      state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
-                scalar<DType>(param.gamma1) * state_n;
-      state_g = scalar<DType>(1.f - param.gamma1) * grad +
-                scalar<DType>(param.gamma1) * state_g;
-      delta = scalar<DType>(param.gamma2) * delta -
-              scalar<DType>(param.lr) *
-              (grad / (F<square_root>(state_n - state_g * state_g +
-                                      scalar<DType>(param.epsilon))));
-    }
+    DType* weight_data = inputs[0].dptr<DType>();
+    DType* grad_data = inputs[1].dptr<DType>();
+    DType* state_n_data = inputs[2].dptr<DType>();
+    DType* state_g_data = inputs[3].dptr<DType>();
+    DType* delta_data = inputs[4].dptr<DType>();
+    DType* out_data = outputs[0].dptr<DType>();
 
-    if (param.clip_weights >= 0.0f) {
-      Assign(out, req[0], F<clip>(weight + delta, DType(param.clip_weights)));
-    } else {
-      Assign(out, req[0], weight + delta);
-    }
+    Kernel<RMSPropAlexUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
+      out_data, state_n_data, state_g_data, delta_data, weight_data, grad_data,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.gamma1), static_cast<DType>(param.gamma2),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
   });
 }
 
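RMSPropAlexUpdateKernel::Map is the centered variant: state_n holds a running second moment and state_g a running first moment of the rescaled gradient, and their difference serves as the variance estimate. Per element (each clip applied only when its threshold is >= 0):

\begin{aligned}
g &\leftarrow \operatorname{clip}(\mathrm{rescale\_grad}\cdot g + \mathrm{wd}\cdot w,\ \mathrm{clip\_gradient}) \\
n &\leftarrow (1-\gamma_1)\, g^2 + \gamma_1 n \\
\bar g &\leftarrow (1-\gamma_1)\, g + \gamma_1 \bar g \\
\Delta &\leftarrow \gamma_2 \Delta - \mathrm{lr}\cdot g \,/\, \sqrt{n - \bar g^{\,2} + \epsilon} \\
w &\leftarrow \operatorname{clip}(w + \Delta,\ \mathrm{clip\_weights})
\end{aligned}

Folding the clipped and unclipped expression-template branches into this one body is what lets a single code path cover all clip_gradient/clip_weights combinations.
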
@@ -1688,64 +1707,52 @@ struct RMSPropParam : public dmlc::Parameter<RMSPropParam> {
   }
 };
 
+struct RMSPropUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+    DType* out_data, DType* state_n_data,
+    const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType gamma1, const DType lr, const DType wd,
+    const DType clip_weights, const DType epsilon,
+    const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    state_n_data[i] = (1.f - gamma1) * (grad_rescaled * grad_rescaled) + gamma1 * state_n_data[i];
+
+    DType weight = weight_data[i] -
+                   lr * (grad_rescaled / square_root::Map(state_n_data[i] + epsilon));
+    if (clip_weights >= 0.0f) {
+      weight = clip::Map(weight, clip_weights);
+    }
+    KERNEL_ASSIGN(out_data[i], req, weight);
+  }
+};
+
 template<typename xpu>
 inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                           const std::vector<TBlob> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
   const RMSPropParam &param = nnvm::get<RMSPropParam>(attrs.parsed);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> state_n = inputs[2].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    DType* weight_data = inputs[0].dptr<DType>();
+    DType* grad_data = inputs[1].dptr<DType>();
+    DType* state_n_data = inputs[2].dptr<DType>();
+    DType* out_data = outputs[0].dptr<DType>();
 
-    grad = scalar<DType>(param.rescale_grad) * grad +
-           scalar<DType>(param.wd) * weight;
-
-    if (param.clip_gradient >= 0.0f) {
-      state_n = scalar<DType>(1.f - param.gamma1) *
-                F<clip>(grad, DType(param.clip_gradient)) *
-                F<clip>(grad, DType(param.clip_gradient)) +
-                scalar<DType>(param.gamma1) * state_n;
-      if (param.clip_weights >= 0.0f) {
-        Assign(out, req[0],
-               F<clip>(weight -
-                       scalar<DType>(param.lr) *
-                       (F<clip>(grad, DType(param.clip_gradient)) /
-                        (F<square_root>(state_n +
-                                        scalar<DType>(param.epsilon)))),
-                       DType(param.clip_weights)));
-      } else {
-        Assign(out, req[0], weight -
-               scalar<DType>(param.lr) *
-               (F<clip>(grad, DType(param.clip_gradient)) /
-                (F<square_root>(state_n +
-                                scalar<DType>(param.epsilon)))));
-      }
-    } else {
-      state_n = scalar<DType>(1.f - param.gamma1) * (grad * grad) +
-                scalar<DType>(param.gamma1) * state_n;
-      if (param.clip_weights >= 0.0f) {
-        Assign(out, req[0],
-               F<clip>(weight -
-                       scalar<DType>(param.lr) *
-                       (grad /
-                        (F<square_root>(state_n +
-                                        scalar<DType>(param.epsilon)))),
-                       DType(param.clip_weights)));
-      } else {
-        Assign(out, req[0], weight -
-               scalar<DType>(param.lr) *
-               (grad /
-                (F<square_root>(state_n +
-                                scalar<DType>(param.epsilon)))));
-      }
-    }
+    Kernel<RMSPropUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
+      out_data, state_n_data, weight_data, grad_data,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.gamma1), static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+      static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
   });
 }
 
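The plain RMSProp kernel keeps only the second-moment state. Per element (clips applied only when the corresponding threshold is >= 0):

\begin{aligned}
g &\leftarrow \operatorname{clip}(\mathrm{rescale\_grad}\cdot g + \mathrm{wd}\cdot w,\ \mathrm{clip\_gradient}) \\
n &\leftarrow (1-\gamma_1)\, g^2 + \gamma_1 n \\
w &\leftarrow \operatorname{clip}\bigl(w - \mathrm{lr}\cdot g \,/\, \sqrt{n + \epsilon},\ \mathrm{clip\_weights}\bigr)
\end{aligned}

This single body replaces the four-way clip_gradient-by-clip_weights branching of the removed expression-template code.
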
@@ -1781,15 +1788,41 @@ struct FtrlParam : public dmlc::Parameter<FtrlParam> {
   }
 };
 
+struct FtrlUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta, const DType lamda1,
+    const DType lr, const DType wd,
+    const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.0f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    z_data[i] += grad_rescaled - (square_root::Map(n_data[i] +
+                 square::Map(grad_rescaled)) - square_root::Map(n_data[i])) *
+                 weight_data[i] / lr;
+    n_data[i] += square::Map(grad_rescaled);
+
+    KERNEL_ASSIGN(out_data[i], req,
+                  (sign::Map(z_data[i]) * lamda1 - z_data[i]) /
+                  ((beta + square_root::Map(n_data[i])) / lr + wd) *
+                  gt::Map(abs::Map(z_data[i]), lamda1));
+  }
+};
+
 template<typename xpu>
 inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  using namespace mshadow_op;
+  using namespace mxnet_op;
+
   const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
   Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
@@ -1799,23 +1832,11 @@ inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
 
-    grad = scalar<DType>(param.rescale_grad) * grad;
-
-    if (param.clip_gradient >= 0.0f) {
-      z += F<clip>(grad, DType(param.clip_gradient)) - (F<square_root>(n +
-           F<square>(F<clip>(grad, DType(param.clip_gradient)))) - F<square_root>(n)) *
-           weight / scalar<DType>(param.lr);
-      n += F<square>(F<clip>(grad, DType(param.clip_gradient)));
-    } else {
-      z += grad - (F<square_root>(n + F<square>(grad)) - F<square_root>(n)) *
-           weight / scalar<DType>(param.lr);
-      n += F<square>(grad);
-    }
-    Assign(out, req[0],
-           (F<sign>(z) * scalar<DType>(param.lamda1) - z) /
-           ((scalar<DType>(param.beta) + F<square_root>(n)) /
-           scalar<DType>(param.lr) + scalar<DType>(param.wd)) *
-           F<gt>(F<abs>(z), scalar<DType>(param.lamda1)));
+    Kernel<FtrlUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, n.dptr_, z.dptr_, weight.dptr_, grad.dptr_,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.beta), static_cast<DType>(param.lamda1),
+      static_cast<DType>(param.lr), static_cast<DType>(param.wd), req[0]);
   });
 }
 
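FtrlUpdateKernel::Map performs the FTRL step in one pass: rescale (and optionally clip) the gradient, advance the z and n accumulators, then recompute the weight in closed form. Unlike the other kernels in this diff, wd is not folded into the rescaled gradient; it enters the denominator of the weight expression instead. Per element:

\begin{aligned}
g &\leftarrow \operatorname{clip}(\mathrm{rescale\_grad}\cdot g,\ \mathrm{clip\_gradient}) \\
z &\leftarrow z + g - \frac{\sqrt{n + g^2} - \sqrt{n}}{\mathrm{lr}}\, w \\
n &\leftarrow n + g^2 \\
w &\leftarrow \frac{\operatorname{sign}(z)\,\lambda_1 - z}{(\beta + \sqrt{n})\,/\,\mathrm{lr} + \mathrm{wd}} \cdot \mathbf{1}\!\left[\,|z| > \lambda_1\,\right]
\end{aligned}

Here λ1 is the kernel's lamda1 argument; the indicator term (gt::Map(abs::Map(z), lamda1) in the code) zeroes any weight whose |z| does not exceed it.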