@@ -1751,6 +1751,164 @@ inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
  });
}

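+ // Type inference for the mixed-precision LAMB phase-one op: the first n_in
+ // inputs are expected in float16, the remaining (total_in - n_in) state
+ // inputs and all n_out outputs are float32.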
+ template <int n_in, int n_out, int total_in>
+ inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
+                                std::vector<int> *in_attrs,
+                                std::vector<int> *out_attrs) {
+   CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
+   CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
+   for (int i = 0; i < n_in; ++i) {
+     TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16);
+   }
+   for (int i = n_in; i < total_in; ++i) {
+     TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
+   }
+   for (int i = 0; i < n_out; ++i) {
+     TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32);
+   }
+   return true;
+ }
+
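+ // Phase one of the mixed-precision LAMB update: maintains the float32 Adam
+ // moments (mean, var) from the rescaled and optionally clipped gradient, then
+ // writes the per-element update direction g (optionally bias-corrected, plus
+ // weight decay on the float32 master weights) into out_data.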
+ struct MPLambUpdatePhaseOneKernel {
+   template <typename DType>
+   MSHADOW_XINLINE static void Map(int i, float* out_data,
+     float* mean_data, float* var_data, const DType* weight_data,
+     const DType* grad_data, const float* weight32_data,
+     const float clip_gradient, const float rescale_grad,
+     const float beta1_t, const float beta1,
+     const float beta2_t, const float beta2,
+     const float wd, const float epsilon, const int t,
+     bool bias_correction, const OpReqType req) {
+     using namespace mshadow_op;
+
+     float grad_rescaled = grad_data[i] * rescale_grad;
+     if (clip_gradient >= 0.f) {
+       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+     }
+
+     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+     var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+     float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight32_data[i];
+
+     if (bias_correction) {
+       float mean_hat = mean_data[i] / (1. - beta1_t);
+       float var_hat = var_data[i] / (1 - beta2_t);
+       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight32_data[i];
+     }
+     KERNEL_ASSIGN(out_data[i], req, g);
+   }
+ };
+
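+ // Launcher for phase one. Inputs: weight and grad in the low-precision DType,
+ // followed by the float32 mean, var and master copy of the weights; the
+ // single output g is computed in float32.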
+ template <typename xpu>
+ inline void MPLambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
+                                  const OpContext &ctx,
+                                  const std::vector<TBlob> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<TBlob> &outputs) {
+   using namespace mxnet_op;
+   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+   Stream<xpu>* s = ctx.get_stream<xpu>();
+   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+     float beta1_t = std::pow(param.beta1, param.t);
+     float beta2_t = std::pow(param.beta2, param.t);
+     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+
+     Kernel<MPLambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
+       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
+       param.clip_gradient, param.rescale_grad, beta1_t, param.beta1, beta2_t, param.beta2,
+       param.wd, param.epsilon, param.t, param.bias_correction, req[0]);
+   });
+ }
+
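+ // Shape inference for phase two: weight, g and the float32 master weights must
+ // all have the same shape, r1 and r2 are single-element tensors, and the
+ // output takes the weight shape.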
+ inline bool MPLambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                                       mxnet::ShapeVector* in_attrs,
+                                       mxnet::ShapeVector* out_attrs) {
+   CHECK_EQ(in_attrs->size(), 5U);
+   CHECK_EQ(out_attrs->size(), 1U);
+
+   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+   mxnet::TShape& weight_shape = in_attrs->at(0);
+   mxnet::TShape& g_shape = in_attrs->at(1);
+   mxnet::TShape& weight32_shape = in_attrs->at(4);
+   CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+     << "total no. of dimensions for weight and g must match";
+   CHECK_EQ(weight_shape.ndim(), weight32_shape.ndim())
+     << "total no. of dimensions for weight and weight32 must match";
+   for (int i = 0; i < weight_shape.ndim(); ++i) {
+     CHECK_EQ(weight_shape[i], g_shape[i])
+       << "weight and g dimension size mismatch at " << i << "-th index";
+     CHECK_EQ(weight_shape[i], weight32_shape[i])
+       << "weight and weight32 dimension size mismatch at " << i << "-th index";
+   }
+   mxnet::TShape& r1_shape = in_attrs->at(2);
+   mxnet::TShape& r2_shape = in_attrs->at(3);
+   CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+   CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+   for (int i = 0; i < expected_out.ndim(); ++i) {
+     expected_out[i] = weight_shape[i];
+   }
+
+   SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+   return shape_is_known(expected_out);
+ }
+
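+ // Phase two of the mixed-precision LAMB update: scales the learning rate by
+ // the trust ratio r1 / r2 (with r1 optionally clamped to
+ // [lower_bound, upper_bound], and the ratio skipped when either norm is zero)
+ // and applies the step to the float32 master weights.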
+ struct MPLambUpdatePhaseTwoKernel {
+   template <typename DType>
+   MSHADOW_XINLINE static void Map(int i, DType* out_data,
+     const DType* weight_data, const float* g,
+     const float* r1, const float* r2, const float* weight32_data,
+     float lr, const float lower_bound,
+     const float upper_bound, const OpReqType req) {
+     using namespace mshadow_op;
+
+     float new_r1 = r1[0];
+     if (lower_bound >= 0) {
+       new_r1 = maximum::Map(new_r1, lower_bound);
+     }
+     if (upper_bound >= 0) {
+       new_r1 = minimum::Map(new_r1, upper_bound);
+     }
+     if (new_r1 == 0.0f || r2[0] == 0.0f) {
+       lr = lr * 1.0f;
+     } else {
+       lr = lr * new_r1 / r2[0];
+     }
+
+     KERNEL_ASSIGN(out_data[i], req, weight32_data[i] - lr * g[i]);
+   }
+ };
+
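+ // Launcher for phase two. g, r1, r2 and the master weights come in as float32;
+ // the updated weights are written back in the low-precision DType of the
+ // weight input.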
+ template <typename xpu>
+ inline void MPLambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                                  const OpContext &ctx,
+                                  const std::vector<TBlob> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<TBlob> &outputs) {
+   using namespace mxnet_op;
+   const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+   Stream<xpu>* s = ctx.get_stream<xpu>();
+   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, float> g = inputs[1].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, float> r1 = inputs[2].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, float> r2 = inputs[3].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
+     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+     Kernel<MPLambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+       out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, weight32.dptr_,
+       param.lr, param.lower_bound,
+       param.upper_bound, req[0]);
+   });
+ }
+
// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
@@ -2493,5 +2651,4 @@ inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
} // namespace op
} // namespace mxnet

-
#endif // MXNET_OPERATOR_OPTIMIZER_OP_INL_H_