@@ -1563,21 +1563,16 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
  }
}

- struct LAMBParam : public dmlc::Parameter<LAMBParam> {
-   float lr;
+ struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
  float beta1;
  float beta2;
  float epsilon;
-   float lower_bound;
-   float upper_bound;
  float t;
  bool bias_correction;
  float wd;
  float rescale_grad;
  float clip_gradient;
-   DMLC_DECLARE_PARAMETER(LAMBParam) {
-     DMLC_DECLARE_FIELD(lr)
-     .describe("Learning rate");
+   DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.9f)
    .describe("The decay rate for the 1st moment estimates.");
@@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-6f)
    .describe("A small constant for numerical stability.");
-     DMLC_DECLARE_FIELD(lower_bound)
-     .set_default(1e-3f)
-     .describe("Lower limit of norm of weight.");
-     DMLC_DECLARE_FIELD(upper_bound)
-     .set_default(10.0f)
-     .describe("Upper limit of norm of weight.");
    DMLC_DECLARE_FIELD(t)
    .describe("Index update count.");
    DMLC_DECLARE_FIELD(bias_correction)
    .set_default(false)
    .describe("Whether to use bias correction.");
    DMLC_DECLARE_FIELD(wd)
-     .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
@@ -1614,74 +1602,152 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
  }
};

- struct LAMBUpdateKernel {
+ struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
+   float lr;
+   float lower_bound;
+   float upper_bound;
+   DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
+     DMLC_DECLARE_FIELD(lr)
+     .describe("Learning rate");
+     DMLC_DECLARE_FIELD(lower_bound)
+     .set_default(-1.0f)
+     .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
+     DMLC_DECLARE_FIELD(upper_bound)
+     .set_default(-1.0f)
+     .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
+   }
+ };
+
+ struct LambUpdatePhaseOneKernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
    const DType clip_gradient, const DType rescale_grad,
-     const DType beta1, const DType beta2,
-     DType lr, const DType wd,
-     const DType epsilon, const DType lower_bound,
-     const DType upper_bound, const DType t,
+     const DType beta1, const DType beta2, const DType wd,
+     const DType epsilon, const DType t,
    bool bias_correction, const OpReqType req) {
    using namespace mshadow_op;

-     DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
+     DType grad_rescaled = grad_data[i] * rescale_grad;
    if (clip_gradient >= 0.f) {
      grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
    }

    mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
    var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

-     DType r1 = square_root::Map(square::Map(weight_data[i]));
-
-     r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
-     DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];
+     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

    if (bias_correction) {
      DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
      DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
-       g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
-     }
-     DType r2 = square_root::Map(square::Map(g));
-     if (r1 == 0.0f || r2 == 0.0f) {
-       lr = lr * 1.0f;
-     } else {
-       lr = lr * r1 / r2;
+       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
    }
-
-     KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
+     KERNEL_ASSIGN(out_data[i], req, g);
  }
};
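
For reference, a minimal scalar sketch of what LambUpdatePhaseOneKernel computes for one element (illustrative names only, not part of this patch). Phase one emits only the update direction g; the trust-ratio scaling and the actual weight write-back are deferred to phase two.

#include <algorithm>
#include <cmath>

// Hypothetical per-weight optimizer state mirroring mean_data[i] / var_data[i].
struct PhaseOneState { float mean = 0.f; float var = 0.f; };

// Scalar analogue of LambUpdatePhaseOneKernel::Map, assuming plain floats.
inline float LambPhaseOneScalar(PhaseOneState* s, float weight, float grad,
                                float rescale_grad, float clip_gradient,
                                float beta1, float beta2, float wd,
                                float epsilon, float t, bool bias_correction) {
  float g = grad * rescale_grad;                        // rescale the raw gradient
  if (clip_gradient >= 0.f) {                           // optional clipping to [-c, c]
    g = std::max(-clip_gradient, std::min(g, clip_gradient));
  }
  s->mean = beta1 * s->mean + (1.f - beta1) * g;        // 1st moment (Adam-style)
  s->var  = beta2 * s->var  + (1.f - beta2) * g * g;    // 2nd moment
  float m = s->mean;
  float v = s->var;
  if (bias_correction) {                                // optional Adam bias correction
    m /= (1.f - std::pow(beta1, t));
    v /= (1.f - std::pow(beta2, t));
  }
  // Weight decay is folded into the direction here, matching the kernel above.
  return m / (std::sqrt(v) + epsilon) + wd * weight;
}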

template<typename xpu>
- inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
+ inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
-   using namespace mxnet_op;
-   const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
-   Stream<xpu>* s = ctx.get_stream<xpu>();
+   using namespace mxnet_op;
+   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
+   Stream<xpu>* s = ctx.get_stream<xpu>();
  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
    Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

-     Kernel<LAMBUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
+     Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
      out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
      static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
      static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
-       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
-       static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
-       static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
-       static_cast<bool>(param.bias_correction), req[0]);
-   });
+       static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
+       static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
+   });
+ }
+
+ inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
+                                     mxnet::ShapeVector* in_attrs,
+                                     mxnet::ShapeVector* out_attrs) {
+   CHECK_EQ(in_attrs->size(), 4U);
+   CHECK_EQ(out_attrs->size(), 1U);
+
+   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
+
+   mxnet::TShape& weight_shape = in_attrs->at(0);
+   mxnet::TShape& g_shape = in_attrs->at(1);
+   CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
+     << "total no. of dimensions for weights and g must match";
+   for (int i = 0; i < weight_shape.ndim(); ++i) {
+     CHECK_EQ(weight_shape[i], g_shape[i])
+       << "weight and g dimension size mismatch at " << i << "-th index";
+   }
+   mxnet::TShape& r1_shape = in_attrs->at(2);
+   mxnet::TShape& r2_shape = in_attrs->at(3);
+   CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
+   CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
+   for (int i = 0; i < expected_out.ndim(); ++i) {
+     expected_out[i] = weight_shape[i];
+   }
+
+   SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
+   return shape_is_known(expected_out);
}

+ struct LambUpdatePhaseTwoKernel {
+   template<typename DType>
+   MSHADOW_XINLINE static void Map(int i, DType* out_data,
+     const DType* weight_data, const DType* g,
+     const DType* r1, const DType* r2,
+     DType lr, const DType lower_bound,
+     const DType upper_bound, const OpReqType req) {
+     using namespace mshadow_op;
+
+     DType new_r1 = r1[0];
+     if (lower_bound >= 0) {
+       new_r1 = maximum::Map(new_r1, lower_bound);
+     }
+     if (upper_bound >= 0) {
+       new_r1 = minimum::Map(new_r1, upper_bound);
+     }
+     if (new_r1 == 0.0f || r2[0] == 0.0f) {
+       lr = lr * 1.0f;
+     } else {
+       lr = lr * new_r1 / r2[0];
+     }
+
+     KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
+   }
+ };
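
The same phase-two arithmetic in scalar form (illustrative helper, not part of this patch): a negative bound, which the -1.0f defaults encode as "not set", leaves r1 unclamped, and the step falls back to the plain learning rate when either norm is zero.

#include <algorithm>

// Scalar analogue of LambUpdatePhaseTwoKernel::Map for one element, assuming
// r1 and r2 were already reduced to scalars outside the kernel.
inline float LambPhaseTwoScalar(float weight, float g, float r1, float r2,
                                float lr, float lower_bound, float upper_bound) {
  if (lower_bound >= 0.f) r1 = std::max(r1, lower_bound);  // clamp only when a bound is set
  if (upper_bound >= 0.f) r1 = std::min(r1, upper_bound);
  if (r1 != 0.f && r2 != 0.f) {
    lr = lr * r1 / r2;                                     // layer-wise trust ratio
  }
  return weight - lr * g;                                  // final weight update
}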
+
+ template<typename xpu>
+ inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
+                                const OpContext &ctx,
+                                const std::vector<TBlob> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<TBlob> &outputs) {
+   using namespace mxnet_op;
+   const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
+   Stream<xpu>* s = ctx.get_stream<xpu>();
+   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
+     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+
+     Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
+       out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
+       static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
+       static_cast<DType>(param.upper_bound), req[0]);
+   });
+ }
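
To show how the two phases compose, here is a hypothetical driver over a plain float vector that reuses the two scalar sketches above. It assumes, following the LAMB formulation, that r1 is the L2 norm of the weight and r2 the L2 norm of the phase-one output g, computed by separate reductions between the two phases; those reductions are not part of this file.

#include <cmath>
#include <vector>

// Hypothetical single LAMB step for one layer (illustrative only).
inline void LambStepSketch(std::vector<float>* w, const std::vector<float>& grad,
                           std::vector<PhaseOneState>* state,
                           float lr, float wd, float t) {
  std::vector<float> g(w->size());
  for (size_t i = 0; i < w->size(); ++i) {      // phase one: per-element direction
    g[i] = LambPhaseOneScalar(&(*state)[i], (*w)[i], grad[i],
                              /*rescale_grad=*/1.f, /*clip_gradient=*/-1.f,
                              /*beta1=*/0.9f, /*beta2=*/0.999f, wd,
                              /*epsilon=*/1e-6f, t, /*bias_correction=*/true);
  }
  float r1 = 0.f, r2 = 0.f;                     // L2 norms of weight and of g
  for (size_t i = 0; i < w->size(); ++i) {
    r1 += (*w)[i] * (*w)[i];
    r2 += g[i] * g[i];
  }
  r1 = std::sqrt(r1);
  r2 = std::sqrt(r2);
  for (size_t i = 0; i < w->size(); ++i) {      // phase two: trust-ratio scaled step
    (*w)[i] = LambPhaseTwoScalar((*w)[i], g[i], r1, r2, lr,
                                 /*lower_bound=*/-1.f, /*upper_bound=*/-1.f);
  }
}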

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)