@@ -218,7 +218,7 @@ def __init__(self, in_channel, out_channel,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=1)
-        self._odim = update_lens_1d(torch.IntTensor([in_channel]), self.conv1)[0].item()
+        self._odim = out_channel
         self.batch_norm1 = nn.BatchNorm1d(out_channel) if batch_norm else lambda x: x
         self.layer_norm1 = nn.LayerNorm(out_channel,
                                         eps=layer_norm_eps) if layer_norm else lambda x: x
@@ -229,7 +229,7 @@ def __init__(self, in_channel, out_channel,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=1)
-        self._odim = update_lens_1d(torch.IntTensor([self._odim]), self.conv2)[0].item()
+        self._odim = out_channel
         self.batch_norm2 = nn.BatchNorm1d(out_channel) if batch_norm else lambda x: x
         self.layer_norm2 = nn.LayerNorm(out_channel,
                                         eps=layer_norm_eps) if layer_norm else lambda x: x
@@ -242,7 +242,7 @@ def __init__(self, in_channel, out_channel,
                                  padding=0,
                                  ceil_mode=True)
         # NOTE: If ceil_mode is False, remove last feature when the dimension of features are odd.
-        self._odim = update_lens_1d(torch.IntTensor([self._odim]), self.pool)[0].item()
+        self._odim = self._odim
         if self._odim % 2 != 0:
             self._odim = (self._odim // 2) * 2
             # TODO(hirofumi0810): more efficient way?
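
For context, a minimal sketch (not part of the diff) of why the output feature dimension after each block is simply out_channel: nn.Conv1d always emits out_channel channels, and MaxPool1d only shortens the time axis, so no recomputation via update_lens_1d is needed for the channel dimension. The concrete sizes below (in_channel=80, out_channel=32, kernel_size=3, stride=1, and a pooling window of 2) are illustrative assumptions; padding=1 for the convolutions and padding=0 / ceil_mode=True for the pooling follow the diff.

    import torch
    import torch.nn as nn

    # Illustrative sizes (assumptions, not taken from the repo config).
    in_channel, out_channel, kernel_size, stride = 80, 32, 3, 1

    conv = nn.Conv1d(in_channel, out_channel,
                     kernel_size=kernel_size, stride=stride, padding=1)
    pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0, ceil_mode=True)

    x = torch.randn(4, in_channel, 100)  # [batch, feature, time]
    y = pool(conv(x))

    # The feature (channel) dimension is fixed by the conv alone;
    # pooling only changes the time axis.
    assert y.shape[1] == out_channel
    print(y.shape)  # torch.Size([4, 32, 50])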