@@ -84,7 +84,7 @@ def __init__(self, config, layer_id):
84
84
self .time_gamma = nn .Parameter (torch .ones (config .ctx_len , 1 ))
85
85
self .register_buffer ("mask" , torch .tril (torch .ones (config .ctx_len , config .ctx_len )))
86
86
87
- self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,0 ))
87
+ self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,- 1 ))
88
88
89
89
self .key = nn .Linear (config .n_embd , config .n_attn )
90
90
self .value = nn .Linear (config .n_embd , config .n_attn )
@@ -110,15 +110,15 @@ def forward(self, x):
110
110
self .mask = self .mask [:T , :T ]
111
111
w = w .masked_fill (self .mask == 0 , 0 )
112
112
113
- x = torch .cat ([self .time_shift (x ) [:, :- 1 , :C // 2 ], x [:, :, C // 2 :]], dim = - 1 )
113
+ x = torch .cat ([self .time_shift (x [:, :, :C // 2 ]) , x [:, :, C // 2 :]], dim = - 1 )
114
114
if hasattr (self , 'tiny_att' ):
115
115
tiny_att = self .tiny_att (x , self .mask )
116
116
117
117
k = self .key (x )
118
118
v = self .value (x )
119
119
r = self .receptance (x )
120
120
121
- k = torch .clamp (k , max = 30 ) # clamp extreme values. e^30 = 10^13
121
+ k = torch .clamp (k , max = 30 , min = - 60 ) # clamp extreme values. e^30 = 10^13
122
122
k = torch .exp (k )
123
123
sum_k = torch .cumsum (k , dim = 1 )
124
124
@@ -138,7 +138,7 @@ class RWKV_ChannelMix(nn.Module):
138
138
def __init__ (self , config , layer_id ):
139
139
super ().__init__ ()
140
140
self .layer_id = layer_id
141
- self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,0 ))
141
+ self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,- 1 ))
142
142
143
143
hidden_sz = 5 * config .n_ffn // 2 # can use smaller hidden_sz because of receptance gating
144
144
self .key = nn .Linear (config .n_embd , hidden_sz )
@@ -152,7 +152,7 @@ def __init__(self, config, layer_id):
152
152
def forward (self , x ):
153
153
B , T , C = x .size ()
154
154
155
- x = torch .cat ([self .time_shift (x ) [:, :- 1 , :C // 2 ], x [:, :, C // 2 :]], dim = - 1 )
155
+ x = torch .cat ([self .time_shift (x [:, :, :C // 2 ]) , x [:, :, C // 2 :]], dim = - 1 )
156
156
k = self .key (x )
157
157
v = self .value (x )
158
158
r = self .receptance (x )
@@ -235,7 +235,7 @@ def __init__(self, config, layer_id, time_shift = False):
235
235
self .head_size = config .n_attn // config .n_head
236
236
237
237
if time_shift :
238
- self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,0 ))
238
+ self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,- 1 ))
239
239
240
240
self .query = nn .Linear (config .n_embd , config .n_attn )
241
241
self .key = nn .Linear (config .n_embd , config .n_attn )
@@ -252,7 +252,7 @@ def forward(self, x):
252
252
B , T , C = x .size ()
253
253
254
254
if hasattr (self , 'time_shift' ):
255
- x = torch .cat ([self .time_shift (x ) [:, :- 1 , :C // 2 ], x [:, :, C // 2 :]], dim = - 1 )
255
+ x = torch .cat ([self .time_shift (x [:, :, :C // 2 ]) , x [:, :, C // 2 :]], dim = - 1 )
256
256
257
257
q = self .query (x ).view (B , T , self .n_head , self .head_size ).transpose (1 , 2 ) # (B, T, C) -> (B, nh, T, hs)
258
258
k = self .key (x ).view (B , T , self .n_head , self .head_size ).transpose (1 , 2 ) # (B, T, C) -> (B, nh, T, hs)
@@ -281,7 +281,7 @@ def __init__(self, config, layer_id, time_shift = False):
281
281
self .layer_id = layer_id
282
282
283
283
if time_shift :
284
- self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,0 ))
284
+ self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,- 1 ))
285
285
286
286
hidden_sz = 3 * config .n_ffn
287
287
self .key = nn .Linear (config .n_embd , hidden_sz )
@@ -291,7 +291,7 @@ def __init__(self, config, layer_id, time_shift = False):
291
291
def forward (self , x ):
292
292
B , T , C = x .size ()
293
293
if hasattr (self , 'time_shift' ):
294
- x = torch .cat ([self .time_shift (x ) [:, :- 1 , :C // 2 ], x [:, :, C // 2 :]], dim = - 1 )
294
+ x = torch .cat ([self .time_shift (x [:, :, :C // 2 ]) , x [:, :, C // 2 :]], dim = - 1 )
295
295
296
296
k = self .key (x )
297
297
v = self .value (x )
@@ -317,7 +317,7 @@ def __init__(self, config, layer_id):
317
317
self .time_gamma = nn .Parameter (torch .ones (config .ctx_len , 1 ))
318
318
self .register_buffer ("mask" , torch .tril (torch .ones (config .ctx_len , config .ctx_len )))
319
319
320
- self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,0 ))
320
+ self .time_shift = nn .ZeroPad2d ((0 ,0 ,1 ,- 1 ))
321
321
self .query = nn .Linear (config .n_embd , config .n_attn )
322
322
self .key = nn .Linear (config .n_embd , config .n_attn )
323
323
self .value = nn .Linear (config .n_embd , config .n_attn )
@@ -338,7 +338,7 @@ def forward(self, x):
338
338
w = w [:, :, TT - 1 :] # w is now a circulant matrix
339
339
w = w [:, :T , :T ] * self .time_alpha [:, :, :T ] * self .time_beta [:, :T , :]
340
340
341
- x = torch .cat ([self .time_shift (x ) [:, :- 1 , :C // 2 ], x [:, :, C // 2 :]], dim = - 1 ) # time-shift mixing
341
+ x = torch .cat ([self .time_shift (x [:, :, :C // 2 ]) , x [:, :, C // 2 :]], dim = - 1 ) # time-shift mixing
342
342
q = self .query (x ).view (B , T , self .n_head , self .head_size ).transpose (1 , 2 ) # (B, T, C) -> (B, nh, T, hs)
343
343
k = self .key (x ).view (B , T , self .n_head , self .head_size ).transpose (1 , 2 ) # (B, T, C) -> (B, nh, T, hs)
344
344
v = self .value (x ).view (B , T , self .n_head , self .head_size ).transpose (1 , 2 ) # (B, T, C) -> (B, nh, T, hs)
0 commit comments