@@ -1563,6 +1563,192 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
}
}
+struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
+  float beta1;
+  float beta2;
+  float epsilon;
+  float t;
+  bool bias_correction;
+  float wd;
+  float rescale_grad;
+  float clip_gradient;
+  DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
+    DMLC_DECLARE_FIELD(beta1)
+    .set_default(0.9f)
+    .describe("The decay rate for the 1st moment estimates.");
+    DMLC_DECLARE_FIELD(beta2)
+    .set_default(0.999f)
+    .describe("The decay rate for the 2nd moment estimates.");
+    DMLC_DECLARE_FIELD(epsilon)
+    .set_default(1e-6f)
+    .describe("A small constant for numerical stability.");
+    DMLC_DECLARE_FIELD(t)
+    .describe("Index update count.");
+    DMLC_DECLARE_FIELD(bias_correction)
+    .set_default(true)
+    .describe("Whether to use bias correction.");
+    DMLC_DECLARE_FIELD(wd)
+    .describe("Weight decay augments the objective function with a "
+              "regularization term that penalizes large weights. "
+              "The penalty scales with the square of the magnitude of each weight.");
+    DMLC_DECLARE_FIELD(rescale_grad)
+    .set_default(1.0f)
+    .describe("Rescale gradient to grad = rescale_grad*grad.");
+    DMLC_DECLARE_FIELD(clip_gradient)
+    .set_default(-1.0f)
+    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
+              "If clip_gradient <= 0, gradient clipping is turned off. "
+              "grad = max(min(grad, clip_gradient), -clip_gradient).");
+  }
+};
+
+struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
+  float lr;
+  float lower_bound;
+  float upper_bound;
+  DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(lower_bound)
+    .set_default(-1.0f)
+    .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
+    DMLC_DECLARE_FIELD(upper_bound)
+    .set_default(-1.0f)
+    .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
+  }
+};
+
+struct LambUpdatePhaseOneKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
+    const DType clip_gradient, const DType rescale_grad,
+    const DType beta1, const DType beta2, const DType wd,
+    const DType epsilon, const DType t,
+    bool bias_correction, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType grad_rescaled = grad_data[i] * rescale_grad;
+    if (clip_gradient >= 0.f) {
+      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
+    }
+
+    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
+    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
+
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
+
+    if (bias_correction) {
+      DType mean_hat = mean_data[i] / (1.f - power::Map(beta1, t));
+      DType var_hat = var_data[i] / (1.f - power::Map(beta2, t));
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
+    }
+    KERNEL_ASSIGN(out_data[i], req, g);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
+      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
+      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
+      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+  });
+}
+
+inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                                    mxnet::ShapeVector* in_attrs,
+                                    mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 4U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+    << "total no. of dimensions for weights and g must match";
+  for (int i = 0; i < weight_shape.ndim(); ++i) {
+    CHECK_EQ(weight_shape[i], g_shape[i])
+      << "weight and g dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i = 0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
+}
+
+struct LambUpdatePhaseTwoKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const DType* g,
+    const DType* r1, const DType* r2,
+    DType lr, const DType lower_bound,
+    const DType upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+  }
+};
+
+template<typename xpu>
+inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                               const OpContext &ctx,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+      static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+      static_cast<DType>(param.upper_bound), req[0]);
+  });
+}
+
// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
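For reference, a sketch of the update the two LAMB phases above implement, in my own notation rather than the commit's: r1 and r2 are scalar inputs to phase two (in the LAMB scheme they would be the weight norm and the norm of the phase-one output, but they are computed outside these kernels).

Phase one, per element, with $g$ the rescaled and optionally clipped gradient:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g^2$$

$$\text{out} = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + wd \cdot w, \qquad \hat m_t = \frac{m_t}{1-\beta_1^{\,t}},\quad \hat v_t = \frac{v_t}{1-\beta_2^{\,t}}$$

where the bias-corrected $\hat m_t, \hat v_t$ are used only when bias_correction is set, and $m_t, v_t$ are used directly otherwise.

Phase two, with $r_1$ clipped to [lower_bound, upper_bound] when those bounds are non-negative:

$$w \leftarrow w - lr \cdot \frac{r_1}{r_2} \cdot \text{out}$$

falling back to the unscaled learning rate when either $r_1$ or $r_2$ is zero.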