-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpartialconv1d.py
70 lines (61 loc) · 2.97 KB
/
partialconv1d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Modified partialconv source code based on implementation from
# https://github.com/NVIDIA/partialconv/blob/master/models/partialconv2d.py
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Author & Contact: Guilin Liu ([email protected])
###############################################################################
# Original Author & Contact: Guilin Liu ([email protected])
# Modified by Kevin Shih ([email protected])
import torch
import torch.nn.functional as F
from torch import nn
class PartialConv1d(nn.Conv1d):
    """1D convolution that re-normalizes its output by the valid window size.

    For every output position the layer counts how many input positions
    under the kernel are marked valid by a mask, scales the (bias-free)
    convolution result by ``kernel_size / valid_count`` and zeroes positions
    whose window contains no valid input at all.  When no mask is supplied an
    all-ones mask is used, so only the zero-padded borders are re-weighted.

    Constructor arguments are forwarded unchanged to ``nn.Conv1d``.

    Attributes:
        multi_channel: always ``False``; the mask kernel below is
            single-channel.
        return_mask: when set to ``True`` on the instance, ``forward`` also
            returns the updated mask.
        weight_maskUpdater: all-ones ``(1, 1, kernel_size)`` kernel used to
            count valid positions per window.
        slide_winsize: number of elements in one window of that kernel.
        last_size: shape of the input for which the mask statistics were
            last computed.
        update_mask: cached updated mask (values clamped to ``[0, 1]``).
        mask_ratio: cached per-position re-normalization factor.
    """

    def __init__(self, *args, **kwargs):
        # Plain Python values may be assigned before nn.Module.__init__ runs.
        self.multi_channel = False
        self.return_mask = False
        super().__init__(*args, **kwargs)

        self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0])
        self.slide_winsize = (
            self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]
        )

        self.last_size = (None, None, None)
        self.update_mask = None
        self.mask_ratio = None
        # (shape, dtype, device) of the input whose *default* all-ones mask
        # produced the cached update_mask / mask_ratio; None when the cache
        # is empty or was computed from an explicit mask.
        self._default_mask_signature = None

    @torch.jit.ignore
    def forward(self, input: torch.Tensor, mask_in : torch.Tensor = None):
        """Apply the partial convolution.

        Args:
            input: 3-D tensor, the standard input to a 1D conv
                ``(batch, channels, length)``.
            mask_in: optional validity mask (non-zero = valid).  It is
                convolved with a ``(1, 1, kernel_size)`` kernel, so it must
                have exactly one channel, i.e. ``(batch, 1, length)`` (or
                batch size 1, which broadcasts).  It is cast to the dtype and
                device of ``input`` if they differ.

        Returns:
            The re-normalized convolution output, or the tuple
            ``(output, update_mask)`` when ``self.return_mask`` is true.
        """
        assert len(input.shape) == 3

        signature = (tuple(input.shape), input.dtype, input.device)
        # The cached statistics may only be reused for a mask-less call whose
        # input matches the one the cache was built for *with the default
        # mask*.  Keying on the shape alone would reuse the statistics of a
        # previous explicit mask, or tensors living on another device.
        cache_is_valid = (
            mask_in is None
            and getattr(self, "_default_mask_signature", None) == signature
        )

        mask = None
        if not cache_is_valid:
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                # No-op when dtype and device already match.
                self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    mask = torch.ones(
                        1, 1, input.shape[2],
                        dtype=input.dtype, device=input.device,
                    )
                else:
                    mask = mask_in.to(input)

                # Number of valid input positions under each kernel window.
                self.update_mask = F.conv1d(
                    mask, self.weight_maskUpdater,
                    bias=None, stride=self.stride,
                    padding=self.padding,
                    dilation=self.dilation, groups=1,
                )
                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-6)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio where the window saw no valid input.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

            self._default_mask_signature = signature if mask_in is None else None

        conv_input = torch.mul(input, mask) if mask_in is not None else input
        raw_out = super().forward(conv_input)

        if self.bias is not None:
            # Re-normalize only the data term, then add the bias back.
            bias_view = self.bias.view(1, self.out_channels, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        return output