Skip to content

Commit 3b9005e

Browse files
committed
RWKV: now faster and less params
1 parent 546114c commit 3b9005e

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ def __init__(self, config, layer_id):
6868
self.layer_id = layer_id
6969
self.time_shift = nn.ZeroPad2d((0,0,1,0))
7070

71-
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
72-
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
73-
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
71+
hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
72+
self.key = nn.Linear(config.n_embd, hidden_sz)
73+
self.value = nn.Linear(config.n_embd, hidden_sz)
74+
self.weight = nn.Linear(hidden_sz, config.n_embd)
7475
self.receptance = nn.Linear(config.n_embd, config.n_embd)
7576

7677
def forward(self, x):
@@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module):
166167
def __init__(self, config, layer_id):
167168
super().__init__()
168169
self.layer_id = layer_id
169-
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
170-
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
171-
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
170+
hidden_sz = 3 * config.n_embd
171+
self.key = nn.Linear(config.n_embd, hidden_sz)
172+
self.value = nn.Linear(config.n_embd, hidden_sz)
173+
self.weight = nn.Linear(hidden_sz, config.n_embd)
172174

173175
def forward(self, x):
174176
k = self.key(x)

0 commit comments

Comments
 (0)