@@ -68,9 +68,10 @@ def __init__(self, config, layer_id):
         self.layer_id = layer_id
         self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))
 
-        self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
+        hidden_sz = 5 * config.n_embd // 2  # can use smaller hidden_sz because of R
+        self.key = nn.Linear(config.n_embd, hidden_sz)
+        self.value = nn.Linear(config.n_embd, hidden_sz)
+        self.weight = nn.Linear(hidden_sz, config.n_embd)
 
         self.receptance = nn.Linear(config.n_embd, config.n_embd)
 
     def forward(self, x):
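The comment on the new line states the rationale: because the sigmoid receptance gate `R` already adds per-channel expressivity, the K/V hidden width can drop from `3 * n_embd` to `5 * n_embd // 2`, cutting the three projections from 9·n_embd² to 7.5·n_embd² parameters while the unchanged `receptance` layer keeps its n_embd². Below is a minimal self-contained sketch of how these pieces typically fit together in a receptance-gated channel mix; the Mish activation and the half-channel time-shift slicing are illustrative assumptions, not necessarily what this repo's `forward` (truncated in the diff) does:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelMixSketch(nn.Module):
    """Sketch of a receptance-gated channel mix; details are assumptions."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_sz = 5 * n_embd // 2  # smaller than 3 * n_embd: R adds capacity
        self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))  # pad one step at front of T
        self.key = nn.Linear(n_embd, hidden_sz)
        self.value = nn.Linear(n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, n_embd)
        self.receptance = nn.Linear(n_embd, n_embd)

    def forward(self, x):  # x: (B, T, C)
        B, T, C = x.size()
        # mix: first half of channels sees the previous timestep (assumed scheme)
        x = torch.cat([self.time_shift(x)[:, :T, :C // 2], x[:, :, C // 2:]], dim=-1)
        k = self.key(x)
        v = self.value(x)
        wkv = self.weight(F.mish(k) * v)  # activation choice is an assumption
        # sigmoid receptance gates the output per channel
        return torch.sigmoid(self.receptance(x)) * wkv
```

For example, `ChannelMixSketch(64)(torch.randn(2, 10, 64))` returns a `(2, 10, 64)` tensor.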
@@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module):
     def __init__(self, config, layer_id):
         super().__init__()
         self.layer_id = layer_id
-        self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
-        self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
+        hidden_sz = 3 * config.n_embd
+        self.key = nn.Linear(config.n_embd, hidden_sz)
+        self.value = nn.Linear(config.n_embd, hidden_sz)
+        self.weight = nn.Linear(hidden_sz, config.n_embd)
 
     def forward(self, x):
         k = self.key(x)
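The second hunk applies the same `hidden_sz` refactor to `GeGLU` without shrinking the width, since this block has no receptance gate. For reference, here is a self-contained sketch of the standard GeGLU feed-forward (Shazeer, "GLU Variants Improve Transformers", 2020), which is presumably what the `forward` truncated above computes: the GELU-activated key branch gates the value branch elementwise before the output projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLUSketch(nn.Module):
    """Standard GeGLU feed-forward; mirrors the projections in the diff."""

    def __init__(self, n_embd):
        super().__init__()
        hidden_sz = 3 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz)
        self.value = nn.Linear(n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, n_embd)

    def forward(self, x):
        k = self.key(x)
        v = self.value(x)
        return self.weight(F.gelu(k) * v)  # GELU(k) gates v, then project back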