Skip to content

Commit e30c612

committed: test
1 parent: 15f394d


NMT/code/model.py

Lines changed: 35 additions & 26 deletions
@@ -84,34 +84,44 @@ def __init__(self, hidden_dim, hidden_dim_s, hidden_dim_t, strategy, bidirection
 
     def forward(self, target_h, source_hs):
 
-        if self.strategy in ['dot','general']:
-            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
-
-        if self.strategy == 'dot':
-            # with this strategy, no trainable parameters are involved
-            # here, feat = hidden_dim_t = hidden_dim_s
-            target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
-            dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
-            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        # if self.strategy in ['dot','general']:
+        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+
+        # if self.strategy == 'dot':
+        #     # with this strategy, no trainable parameters are involved
+        #     # here, feat = hidden_dim_t = hidden_dim_s
+        #     target_h = target_h.permute(1,2,0) # (1,batch,feat) -> (batch,feat,1)
+        #     dot_product = torch.matmul(source_hs, target_h) # (batch,seq,feat) * (batch,feat,1) -> (batch,seq,1)
+        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
 
-        elif self.strategy == 'general':
-            target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
-            output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
-            output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
-            dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
-            scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
+        # elif self.strategy == 'general':
+        #     target_h = target_h.permute(1,0,2) # (1,batch,hidden_dim_t) -> (batch,1,hidden_dim_t)
+        #     output = self.ff_general(target_h) # -> (batch,1,hidden_dim_s)
+        #     output = output.permute(0,2,1) # -> (batch,hidden_dim_s,1)
+        #     dot_product = torch.matmul(source_hs, output) # (batch,seq,hidden_dim_s) * (batch,hidden_dim_s,1) -> (batch,seq,1)
+        #     scores = dot_product.permute(1,0,2) # -> (seq,batch,1)
 
-        elif self.strategy == 'concat':
-            target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
-            concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
-            scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
-            source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
+        # elif self.strategy == 'concat':
+        #     target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,hidden_dim_s) -> (seq,batch,hidden_dim_s)
+        #     concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # (seq,batch,hidden_dim_s+hidden_dim_t) -> (seq,batch,hidden_dim)
+        #     scores = self.ff_score(torch.tanh(concat_output)) # -> (seq,batch,1)
+        #     source_hs = source_hs.permute(1,0,2) # (seq,batch,hidden_dim_s) -> (batch,seq,hidden_dim_s)
 
-        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
-        norm_scores = torch.softmax(scores,0) # sequence-wise normalization
-        source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
-        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
-        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
+        # scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). We specify a dimension, because we don't want to squeeze the batch dim in case batch size is equal to 1
+        # norm_scores = torch.softmax(scores,0) # sequence-wise normalization
+        # source_hs_p = source_hs.permute((2,1,0)) # (batch,seq,hidden_dim_s) -> (hidden_dim_s,seq,batch)
+        # weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (hidden_dim_s,seq,batch) -> (hidden_dim_s,seq,batch) (we use broadcasting here - the * operator checks from right to left that the dimensions match)
+        # ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (hidden_dim_s,seq,batch) -> (seq,batch,hidden_dim_s) -> (1,batch,hidden_dim_s); we need keepdim as sum squeezes by default
+
+        target_h_rep = target_h.repeat(source_hs.size(0),1,1) # (1,batch,feat) -> (seq,batch,feat)
+        concat_output = self.ff_concat(torch.cat((target_h_rep,source_hs),-1)) # source_hs is (seq,batch,feat)
+        scores = self.ff_score(torch.tanh(concat_output)) # (seq,batch,feat) -> (seq,batch,1)
+        scores = scores.squeeze(dim=2) # (seq,batch,1) -> (seq,batch). dim=2 because we don't want to squeeze the batch dim if batch size = 1
+        norm_scores = torch.softmax(scores,0)
+        source_hs_p = source_hs.permute((2,0,1)) # (seq,batch,feat) -> (feat,seq,batch)
+        weighted_source_hs = (norm_scores * source_hs_p) # (seq,batch) * (feat,seq,batch) (* checks from right to left that the dimensions match)
+        ct = torch.sum(weighted_source_hs.permute((1,2,0)),0,keepdim=True) # (feat,seq,batch) -> (seq,batch,feat) -> (1,batch,feat); keepdim otherwise sum squeezes
+
 
         return ct

@@ -186,7 +196,6 @@ def __init__(self, vocab_s, source_language, vocab_t_inv, embedding_dim_s, embed
         self.decoder = Decoder(self.max_target_idx+1, self.embedding_dim_t, self.hidden_dim_t, self.hidden_dim_s, self.num_layers, self.bidirectional, self.padding_token, self.dropout).to(self.device)
 
         if not self.att_strategy == 'none':
-            print(self.att_strategy)
             self.att_mech = seq2seqAtt(self.hidden_dim_att, self.hidden_dim_s, self.hidden_dim_t, self.att_strategy, self.bidirectional).to(self.device)
 
     def my_pad(self, my_list):
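
Note on the first hunk: the only scoring path left active in forward() after this change is the concat (additive) one. Below is a minimal, self-contained sketch of that path, for reference only. It assumes ff_concat is a Linear(hidden_dim_s + hidden_dim_t, hidden_dim) and ff_score a Linear(hidden_dim, 1), as the shape comments in the diff suggest, and it expresses the weighting step with unsqueeze instead of the two permutes used in the diff (the result is the same). The class name ConcatAttentionSketch and the shape check under __main__ are illustrative and not part of the repository.

import torch
import torch.nn as nn

class ConcatAttentionSketch(nn.Module):
    """Additive (concat) attention over (seq, batch, feat) source states, mirroring the new forward()."""

    def __init__(self, hidden_dim, hidden_dim_s, hidden_dim_t):
        super().__init__()
        # ff_concat maps the concatenated [target; source] features to hidden_dim (assumed layout)
        self.ff_concat = nn.Linear(hidden_dim_s + hidden_dim_t, hidden_dim)
        # ff_score maps each position to a single unnormalized score
        self.ff_score = nn.Linear(hidden_dim, 1)

    def forward(self, target_h, source_hs):
        # target_h: (1, batch, hidden_dim_t); source_hs: (seq, batch, hidden_dim_s)
        target_h_rep = target_h.repeat(source_hs.size(0), 1, 1)                       # (seq, batch, hidden_dim_t)
        concat_output = self.ff_concat(torch.cat((target_h_rep, source_hs), -1))      # (seq, batch, hidden_dim)
        scores = self.ff_score(torch.tanh(concat_output)).squeeze(dim=2)              # (seq, batch)
        norm_scores = torch.softmax(scores, 0)                                        # normalize over the seq dim
        # broadcast the (seq, batch) weights over the feature dim and sum over seq
        weighted = norm_scores.unsqueeze(-1) * source_hs                              # (seq, batch, hidden_dim_s)
        ct = weighted.sum(0, keepdim=True)                                            # (1, batch, hidden_dim_s)
        return ct

if __name__ == "__main__":
    seq, batch = 7, 3
    att = ConcatAttentionSketch(hidden_dim=32, hidden_dim_s=16, hidden_dim_t=16)
    ct = att(torch.randn(1, batch, 16), torch.randn(seq, batch, 16))
    print(ct.shape)  # torch.Size([1, 3, 16])

Because the softmax runs over dim 0, the scores are normalized across the source sequence, so ct is a convex combination of the source hidden states for each batch element.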
