Commit 264e76c

[cherrypick-beta-2.0] Fix kl_div, conv and summary api bug (#27195)
* fix some bug
1 parent ed52b00 commit 264e76c

8 files changed: +115 -32 lines changed
paddle/fluid/operators/kldiv_loss_op.h

Lines changed: 5 additions & 1 deletion
@@ -72,7 +72,11 @@ class KLDivLossKernel : public framework::OpKernel<T> {
       loss_t.device(place) = output;
     } else if ("batchmean" == reduction) {
       auto output_sum = output.sum();
-      loss_t.device(place) = output_sum / output_sum.constant(n);
+      if (n > 0) {
+        loss_t.device(place) = output_sum / output_sum.constant(n);
+      } else {
+        loss_t.device(place) = output_sum;
+      }
     } else if ("mean" == reduction) {
       loss_t.device(place) = output.mean();
     } else if ("sum" == reduction) {
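
The guard above avoids dividing by a zero batch size, which is what happens when the input is 0-D and has no batch dimension. A minimal numpy sketch of the kernel's reduction dispatch (the standalone function is illustrative, not part of the patch):

import numpy as np

def reduce_kl_loss(output, n, reduction):
    # `output` is the elementwise KL loss; `n` is the first input dimension.
    if reduction == "batchmean":
        # The fix: only divide when a batch dimension actually exists.
        return output.sum() / n if n > 0 else output.sum()
    if reduction == "mean":
        return output.mean()
    if reduction == "sum":
        return output.sum()
    return output  # 'none': no reduction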

python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py

Lines changed: 7 additions & 1 deletion
@@ -24,7 +24,10 @@ def kldiv_loss(x, target, reduction):
     loss = np.where(target >= 0, output, np.zeros_like(x))

     if reduction == "batchmean":
-        return loss.sum() / x.shape[0]
+        if len(x.shape) > 0:
+            return loss.sum() / x.shape[0]
+        else:
+            return loss.sum()
     if reduction == "mean":
         return loss.mean()
     if reduction == "sum":
@@ -93,6 +96,9 @@ def run_kl_loss(self, reduction, shape=(5, 20)):
     def test_kl_loss_batchmean(self):
         self.run_kl_loss('batchmean')

+    def test_kl_loss_batchmean_shape(self):
+        self.run_kl_loss('batchmean', ())
+
     def test_kl_loss_mean(self):
         self.run_kl_loss('mean')

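For a 0-D input the reference above has no batch axis to divide by, so 'batchmean' degrades to a plain sum; a quick check of that branch (values are illustrative, elementwise term taken from the KL formula y * (log(y) - x)):

import numpy as np

x = np.float32(0.5)       # 0-D array: x.shape == (), so len(x.shape) == 0
target = np.float32(0.3)

output = target * (np.log(target) - x)
loss = np.where(target >= 0, output, np.zeros_like(x))
print(loss.sum())         # the 'batchmean' result: a bare sum, no division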

python/paddle/hapi/model.py

Lines changed: 7 additions & 2 deletions
@@ -1868,8 +1868,13 @@ def summary(self, input_size=None, batch_size=None, dtype=None):
             print(params_info)

         """
-
-        return summary(self.network, self._inputs, batch_size, dtype)
+        assert (input_size is not None or self._inputs is not None
+                ), "'input_size' or 'self._input' must be set"
+        if input_size is not None:
+            _input_size = input_size
+        else:
+            _input_size = self._inputs
+        return summary(self.network, _input_size, batch_size, dtype)

     def _verify_spec(self, specs, is_input=False):
         out_specs = []
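
With this change, Model.summary can be driven either by an explicit input_size or by the InputSpec list the Model was constructed with. A minimal usage sketch of both paths (network and shapes are illustrative, assuming the 2.0-beta high-level API):

import paddle
from paddle.static import InputSpec

net = paddle.nn.Sequential(paddle.nn.Flatten(), paddle.nn.Linear(784, 10))
model = paddle.Model(net, inputs=[InputSpec([None, 1, 28, 28], 'float32', 'x')])

model.summary()                        # falls back to self._inputs
model.summary(input_size=(1, 28, 28))  # an explicit shape takes precedence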

python/paddle/hapi/model_summary.py

Lines changed: 61 additions & 18 deletions
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 import numpy as np
+import numbers

 import paddle
 import paddle.nn as nn
@@ -86,8 +88,10 @@ def forward(self, inputs):
     elif isinstance(input_size, list):
         _input_size = []
         for item in input_size:
+            if isinstance(item, int):
+                item = (item, )
             assert isinstance(item,
-                              (list, InputSpec)), 'When input_size is list, \
+                              (tuple, InputSpec)), 'When input_size is list, \
 expect item in input_size is a tuple or InputSpec, but got {}'.format(
                 type(item))

@@ -97,12 +101,19 @@ def forward(self, inputs):
                     batch_size = item.shape[0]
             else:
                 _input_size.append(item)
+    elif isinstance(input_size, int):
+        _input_size = (input_size, )
     else:
         _input_size = input_size

     if batch_size is None:
         batch_size = -1

+    if not paddle.in_dynamic_mode():
+        warnings.warn(
+            "Your model was created in static mode, this may not get correct summary information!"
+        )
+
     result, params_info = summary_string(net, _input_size, batch_size, dtypes)
     print(result)

@@ -117,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):

     depth = len(list(model.sublayers()))

-    def register_hook(module):
-        def hook(module, input, output):
-            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+    def register_hook(layer):
+        def hook(layer, input, output):
+            class_name = str(layer.__class__).split(".")[-1].split("'")[0]

             try:
-                module_idx = int(module._full_name.split('_')[-1])
+                layer_idx = int(layer._full_name.split('_')[-1])
             except:
-                module_idx = len(summary)
+                layer_idx = len(summary)

-            m_key = "%s-%i" % (class_name, module_idx + 1)
+            m_key = "%s-%i" % (class_name, layer_idx + 1)
             summary[m_key] = OrderedDict()
             summary[m_key]["input_shape"] = list(input[0].shape)
             summary[m_key]["input_shape"][0] = batch_size
@@ -138,23 +149,50 @@ def hook(module, input, output):
             summary[m_key]["output_shape"][0] = batch_size

             params = 0
-            if hasattr(module, "weight"):
-                params += np.prod(module.weight.shape)
-                summary[m_key]["trainable"] = module.weight.trainable or (
-                    not module.weight.stop_gradient)
-            if hasattr(module, "bias"):
-                params += np.prod(module.bias.shape)
+
+            if paddle.in_dynamic_mode():
+                layer_state_dict = layer._parameters
+            else:
+                layer_state_dict = layer.state_dict()
+
+            for k, v in layer_state_dict.items():
+                params += np.prod(v.shape)
+
+                try:
+                    if (getattr(getattr(layer, k), 'trainable')) and (
+                            not getattr(getattr(layer, k), 'stop_gradient')):
+                        summary[m_key]["trainable"] = True
+                    else:
+                        summary[m_key]["trainable"] = False
+                except:
+                    summary[m_key]["trainable"] = True
+
             summary[m_key]["nb_params"] = params

-        if (not isinstance(module, nn.Sequential) and
-                not isinstance(module, nn.LayerList) and
-                (not (module == model) or depth < 1)):
+        if (not isinstance(layer, nn.Sequential) and
+                not isinstance(layer, nn.LayerList) and
+                (not (layer == model) or depth < 1)):
+
+            hooks.append(layer.register_forward_post_hook(hook))
+
+    def _check_input_size(input_sizes):
+        for input_size in input_sizes:
+            for item in input_size:
+                if not isinstance(item, numbers.Number):
+                    raise TypeError(
+                        "Expected item in input size be a number, but got {}".
+                        format(type(item)))

-            hooks.append(module.register_forward_post_hook(hook))
+                if item <= 0:
+                    raise ValueError(
+                        "Expected item in input size greater than zero, but got {}".
+                        format(item))

     if isinstance(input_size, tuple):
         input_size = [input_size]

+    _check_input_size(input_size)
+
     x = [
         paddle.rand(
             [2] + list(in_size), dtype=dtype)
@@ -193,7 +231,12 @@ def hook(module, input, output):
             "{0:,}".format(summary[layer]["nb_params"]), )
         total_params += summary[layer]["nb_params"]

-        total_output += np.prod(summary[layer]["output_shape"])
+        try:
+            total_output += np.prod(summary[layer]["output_shape"])
+        except:
+            for output_shape in summary[layer]["output_shape"]:
+                total_output += np.prod(output_shape)
+
         if "trainable" in summary[layer]:
             if summary[layer]["trainable"] == True:
                 trainable_params += summary[layer]["nb_params"]
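
Taken together, these changes let summary accept a bare int or ints inside a list (each normalized to a 1-tuple), validate every dimension through _check_input_size, and warn when the model was built in static graph mode. A sketch of the resulting call patterns (the one-layer network is illustrative):

import paddle

net = paddle.nn.Sequential(paddle.nn.Linear(20, 10))

paddle.summary(net, (20,))   # a tuple, as before
paddle.summary(net, 20)      # a bare int, normalized to (20, )
paddle.summary(net, [20])    # ints inside a list are normalized too

# The new validation rejects malformed sizes:
# paddle.summary(net, (20, '3'))  # raises TypeError: item must be a number
# paddle.summary(net, (-1,))      # raises ValueError: item must be > 0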

python/paddle/nn/functional/loss.py

Lines changed: 2 additions & 2 deletions
@@ -780,10 +780,10 @@ def kl_div(input, label, reduction='mean', name=None):
             input = np.random.uniform(-10, 10, shape).astype('float32')
             target = np.random.uniform(-10, 10, shape).astype('float32')

-            # 'batchmean' reduction, loss shape will be [N]
+            # 'batchmean' reduction, loss shape will be [1]
             pred_loss = F.kl_div(paddle.to_tensor(input),
                                  paddle.to_tensor(target), reduction='batchmean')
-            # shape=[5]
+            # shape=[1]

             # 'mean' reduction, loss shape will be [1]
             pred_loss = F.kl_div(paddle.to_tensor(input),
python/paddle/nn/layer/conv.py

Lines changed: 1 addition & 1 deletion
@@ -1084,7 +1084,7 @@ def __init__(self,
             bias_attr=bias_attr,
             data_format=data_format)

-    def forward(self, x, output_size):
+    def forward(self, x, output_size=None):
         if output_size is None:
             output_padding = self.output_padding
         else:
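
This makes output_size genuinely optional, matching the existing `if output_size is None` branch in the body. A small sketch (class name and shapes are illustrative, since the hunk does not show which transpose-convolution layer this forward belongs to; Conv2DTranspose is the 2-D variant in Paddle 2.0):

import paddle

conv = paddle.nn.Conv2DTranspose(in_channels=4, out_channels=6, kernel_size=3)
x = paddle.rand([2, 4, 8, 8])

y1 = conv(x)                        # output_size omitted: uses output_padding
y2 = conv(x, output_size=[10, 10])  # an explicit output_size still works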

python/paddle/nn/layer/loss.py

Lines changed: 10 additions & 7 deletions
@@ -627,10 +627,13 @@ class KLDivLoss(fluid.dygraph.Layer):
     $$l(x, y) = y * (\log(y) - x)$$

     Parameters:
-        reduction (str, optional): Indicate how to average the loss,
-            the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
-            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
-            Default is ``'mean'``.
+        reduction (str, optional): Indicate how to average the loss,
+            the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
+            If `reduction` is ``'mean'``, the reduced mean loss is returned;
+            if `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
+            if `reduction` is ``'sum'``, the reduced sum loss is returned;
+            if `reduction` is ``'none'``, no reduction will be applied.
+            Default is ``'mean'``.

     Shape:

@@ -654,11 +657,11 @@ class KLDivLoss(fluid.dygraph.Layer):
             x = np.random.uniform(-10, 10, shape).astype('float32')
             target = np.random.uniform(-10, 10, shape).astype('float32')

-            # 'batchmean' reduction, loss shape will be [N]
+            # 'batchmean' reduction, loss shape will be [1]
             kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
             pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                         paddle.to_tensor(target))
-            # shape=[5]
+            # shape=[1]

             # 'mean' reduction, loss shape will be [1]
             kldiv_criterion = nn.KLDivLoss(reduction='mean')
@@ -684,7 +687,7 @@ def __init__(self, reduction='mean'):
         self.reduction = reduction

     def forward(self, input, label):
-        out = paddle.nn.functional.kl_div(input, label, self.reduction)
+        out = F.kl_div(input, label, self.reduction)
         return out

python/paddle/tests/test_model.py

Lines changed: 22 additions & 0 deletions
@@ -519,6 +519,28 @@ def _get_param_from_state_dict(state_dict):
         np.testing.assert_allclose(params_info['total_params'], gt_params)
         print(params_info)

+        model.summary(input_size=(20))
+        model.summary(input_size=[(20)])
+        model.summary(input_size=(20), batch_size=2)
+
+    def test_summary_nlp(self):
+        paddle.enable_static()
+        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+        paddle.summary(nlp_net, (1, 2))
+
+    def test_summary_error(self):
+        with self.assertRaises(TypeError):
+            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+            paddle.summary(nlp_net, (1, '2'))
+
+        with self.assertRaises(ValueError):
+            nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+            paddle.summary(nlp_net, (-1, -1))
+
+        paddle.disable_static()
+        nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+        paddle.summary(nlp_net, (1, 2))
+
     def test_export_deploy_model(self):
         for dynamic in [True, False]:
             fluid.enable_dygraph() if dynamic else None
