Skip to content

Commit 619ed00

Browse files
committed
misc improvement
1 parent a36fc09 commit 619ed00

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/model.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, config, layer_id):
8484
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
8585
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
8686

87-
self.time_shift = nn.ZeroPad2d((0,0,1,0))
87+
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
8888

8989
self.key = nn.Linear(config.n_embd, config.n_attn)
9090
self.value = nn.Linear(config.n_embd, config.n_attn)
@@ -110,15 +110,15 @@ def forward(self, x):
110110
self.mask = self.mask[:T, :T]
111111
w = w.masked_fill(self.mask == 0, 0)
112112

113-
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
113+
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
114114
if hasattr(self, 'tiny_att'):
115115
tiny_att = self.tiny_att(x, self.mask)
116116

117117
k = self.key(x)
118118
v = self.value(x)
119119
r = self.receptance(x)
120120

121-
k = torch.clamp(k, max=30) # clamp extreme values. e^30 = 10^13
121+
k = torch.clamp(k, max=30, min=-60) # clamp extreme values. e^30 = 10^13
122122
k = torch.exp(k)
123123
sum_k = torch.cumsum(k, dim=1)
124124

@@ -138,7 +138,7 @@ class RWKV_ChannelMix(nn.Module):
138138
def __init__(self, config, layer_id):
139139
super().__init__()
140140
self.layer_id = layer_id
141-
self.time_shift = nn.ZeroPad2d((0,0,1,0))
141+
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
142142

143143
hidden_sz = 5 * config.n_ffn // 2 # can use smaller hidden_sz because of receptance gating
144144
self.key = nn.Linear(config.n_embd, hidden_sz)
@@ -152,7 +152,7 @@ def __init__(self, config, layer_id):
152152
def forward(self, x):
153153
B, T, C = x.size()
154154

155-
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
155+
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
156156
k = self.key(x)
157157
v = self.value(x)
158158
r = self.receptance(x)
@@ -235,7 +235,7 @@ def __init__(self, config, layer_id, time_shift = False):
235235
self.head_size = config.n_attn // config.n_head
236236

237237
if time_shift:
238-
self.time_shift = nn.ZeroPad2d((0,0,1,0))
238+
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
239239

240240
self.query = nn.Linear(config.n_embd, config.n_attn)
241241
self.key = nn.Linear(config.n_embd, config.n_attn)
@@ -252,7 +252,7 @@ def forward(self, x):
252252
B, T, C = x.size()
253253

254254
if hasattr(self, 'time_shift'):
255-
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
255+
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
256256

257257
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
258258
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
@@ -281,7 +281,7 @@ def __init__(self, config, layer_id, time_shift = False):
281281
self.layer_id = layer_id
282282

283283
if time_shift:
284-
self.time_shift = nn.ZeroPad2d((0,0,1,0))
284+
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
285285

286286
hidden_sz = 3 * config.n_ffn
287287
self.key = nn.Linear(config.n_embd, hidden_sz)
@@ -291,7 +291,7 @@ def __init__(self, config, layer_id, time_shift = False):
291291
def forward(self, x):
292292
B, T, C = x.size()
293293
if hasattr(self, 'time_shift'):
294-
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
294+
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
295295

296296
k = self.key(x)
297297
v = self.value(x)
@@ -317,7 +317,7 @@ def __init__(self, config, layer_id):
317317
self.time_gamma = nn.Parameter(torch.ones(config.ctx_len, 1))
318318
self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))
319319

320-
self.time_shift = nn.ZeroPad2d((0,0,1,0))
320+
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
321321
self.query = nn.Linear(config.n_embd, config.n_attn)
322322
self.key = nn.Linear(config.n_embd, config.n_attn)
323323
self.value = nn.Linear(config.n_embd, config.n_attn)
@@ -338,7 +338,7 @@ def forward(self, x):
338338
w = w[:, :, TT-1:] # w is now a circulant matrix
339339
w = w[:, :T, :T] * self.time_alpha[:, :, :T] * self.time_beta[:, :T, :]
340340

341-
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1) # time-shift mixing
341+
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1) # time-shift mixing
342342
q = self.query(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
343343
k = self.key(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)
344344
v = self.value(x).view(B, T, self.n_head, self.head_size).transpose(1, 2) # (B, T, C) -> (B, nh, T, hs)

0 commit comments

Comments
 (0)