@@ -1564,20 +1564,15 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
 }
 
 struct LAMBParam : public dmlc::Parameter<LAMBParam> {
-  float lr;
   float beta1;
   float beta2;
   float epsilon;
-  float lower_bound;
-  float upper_bound;
   float t;
   bool bias_correction;
   float wd;
   float rescale_grad;
   float clip_gradient;
   DMLC_DECLARE_PARAMETER(LAMBParam) {
-    DMLC_DECLARE_FIELD(lr)
-    .describe("Learning rate");
     DMLC_DECLARE_FIELD(beta1)
     .set_default(0.9f)
     .describe("The decay rate for the 1st moment estimates.");
@@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
     DMLC_DECLARE_FIELD(epsilon)
     .set_default(1e-6f)
     .describe("A small constant for numerical stability.");
-    DMLC_DECLARE_FIELD(lower_bound)
-    .set_default(1e-3f)
-    .describe("Lower limit of norm of weight.");
-    DMLC_DECLARE_FIELD(upper_bound)
-    .set_default(10.0f)
-    .describe("Upper limit of norm of weight.");
     DMLC_DECLARE_FIELD(t)
     .describe("Index update count.");
     DMLC_DECLARE_FIELD(bias_correction)
     .set_default(false)
     .describe("Whether to use bias correction.");
     DMLC_DECLARE_FIELD(wd)
-    .set_default(0.0f)
     .describe("Weight decay augments the objective function with a "
               "regularization term that penalizes large weights. "
               "The penalty scales with the square of the magnitude of each weight.");
@@ -1614,44 +1602,48 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
   }
 };
 
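+// Hyperparameters consumed only by the separate weight-update step of LAMB
+// (see LAMBWeightUpdate below): the learning rate and the bounds on the weight norm r1.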
+struct LAMBWeightParam : public dmlc::Parameter<LAMBWeightParam> {
+  float lr;
+  float lower_bound;
+  float upper_bound;
+  DMLC_DECLARE_PARAMETER(LAMBWeightParam) {
+    DMLC_DECLARE_FIELD(lr)
+    .describe("Learning rate");
+    DMLC_DECLARE_FIELD(lower_bound)
+    .set_default(1e-3f)
+    .describe("Lower limit of norm of weight.");
+    DMLC_DECLARE_FIELD(upper_bound)
+    .set_default(10.0f)
+    .describe("Upper limit of norm of weight.");
+  }
+};
+
 struct LAMBUpdateKernel {
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out_data,
     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
     const DType clip_gradient, const DType rescale_grad,
-    const DType beta1, const DType beta2,
-    DType lr, const DType wd,
-    const DType epsilon, const DType lower_bound,
-    const DType upper_bound, const DType t,
+    const DType beta1, const DType beta2, const DType wd,
+    const DType epsilon, const DType t,
     bool bias_correction, const OpReqType req) {
     using namespace mshadow_op;
 
-    DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
+    DType grad_rescaled = grad_data[i] * rescale_grad;
     if (clip_gradient >= 0.f) {
       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
     }
 
     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
     var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
 
-    DType r1 = square_root::Map(square::Map(weight_data[i]));
-
-    r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
-    DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];
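+    // out_data now holds the raw update direction g (Adam-style step plus decoupled
+    // weight decay); the trust-ratio scaling by r1/r2 is applied later in LAMBWeightUpdateKernel.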
+    DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
 
     if (bias_correction) {
       DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
       DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
-      g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
-    }
-    DType r2 = square_root::Map(square::Map(g));
-    if (r1 == 0.0f || r2 == 0.0f) {
-      lr = lr * 1.0f;
-    } else {
-      lr = lr * r1 / r2;
+      g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
     }
-
-    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
+    KERNEL_ASSIGN(out_data[i], req, g);
   }
 };
 
@@ -1661,9 +1653,9 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
-  using namespace mxnet_op;
-  const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
-  Stream<xpu>* s = ctx.get_stream<xpu>();
+  using namespace mxnet_op;
+  const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
@@ -1675,13 +1667,85 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
       out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
       static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
-      static_cast<DType>(param.lr), static_cast<DType>(param.wd),
-      static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
-      static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
-      static_cast<bool>(param.bias_correction), req[0]);
-  });
+      static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+      static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+  });
+}
+
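+// Shape check for the LAMB weight-update op: weight and g must have identical shapes,
+// r1 and r2 are single-element tensors, and the output takes the shape of the weight.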
+inline bool LambWeightShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_attrs,
+                            mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 4U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+  mxnet::TShape& weight_shape = in_attrs->at(0);
+  mxnet::TShape& g_shape = in_attrs->at(1);
+  CHECK_EQ(weight_shape.ndim(), g_shape.ndim()) << "total no. of dimensions for weights and g must match";
+  for (int i = 0; i < weight_shape.ndim(); ++i) {
+    CHECK_EQ(weight_shape[i], g_shape[i]) << "weight and g dimension size mismatch at " << i << "-th index";
+  }
+  mxnet::TShape& r1_shape = in_attrs->at(2);
+  mxnet::TShape& r2_shape = in_attrs->at(3);
+  CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+  CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+  for (int i = 0; i < expected_out.ndim(); ++i) {
+    expected_out[i] = weight_shape[i];
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+  return shape_is_known(expected_out);
 }
 
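+// Final LAMB step: clamp the weight norm r1 to [lower_bound, upper_bound], scale the
+// learning rate by the trust ratio r1/r2 (left unchanged if either norm is zero), and
+// apply the precomputed update direction g to the weight.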
+struct LAMBWeightUpdateKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+    const DType* weight_data, const DType* g,
+    const DType* r1, const DType* r2,
+    DType lr, const DType lower_bound,
+    const DType upper_bound, const OpReqType req) {
+    using namespace mshadow_op;
+
+    DType new_r1 = r1[0];
+    if (lower_bound >= 0) {
+      new_r1 = maximum::Map(new_r1, lower_bound);
+    }
+    if (upper_bound >= 0) {
+      new_r1 = minimum::Map(new_r1, upper_bound);
+    }
+    if (new_r1 == 0.0f || r2[0] == 0.0f) {
+      lr = lr * 1.0f;
+    } else {
+      lr = lr * new_r1 / r2[0];
+    }
+
+    KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+  }
+};
+
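+// Forward implementation of the weight-update op: inputs are (weight, g, r1, r2), each
+// flattened to 2D; launches LAMBWeightUpdateKernel over every weight element.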
+template<typename xpu>
+inline void LAMBWeightUpdate(const nnvm::NodeAttrs& attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  const LAMBWeightParam& param = nnvm::get<LAMBWeightParam>(attrs.parsed);
+  Stream<xpu>* s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+    Kernel<LAMBWeightUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+      out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+      static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+      static_cast<DType>(param.upper_bound), req[0]);
+  });
+}
 
 // This RMSProp code follows the version in
 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)