Commit 88e921b

+eval code for 27M ppl 1.65 BPC 0.72 enwik8 model
1 parent 71538e4 commit 88e921b

File tree (4 files changed: 90 additions, 59 deletions)

RWKV-v2-RNN/run.py
RWKV-v2-RNN/src/trainer.py
RWKV-v2-RNN/src/utils.py
RWKV-v2-RNN/train.py

RWKV-v2-RNN/run.py (39 additions, 7 deletions)

@@ -4,16 +4,18 @@
 ########################################################################################################
 
 import numpy as np
+import math
 import time
 import types
 import copy
 import torch
 from torch.nn import functional as F
-from src.utils import TOKENIZER
+from src.utils import TOKENIZER, Dataset
 from src.model_run import RWKV_RNN
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
+np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
 ### Step 1: set model ##################################################################################
 
@@ -26,9 +28,11 @@
 MODEL_NAME = 'trained-31'
 WORD_NAME = 'vocab' # the .json vocab (generated by train.py
 
-# ### uncompress enwik8-model.zip to test my enwik8 model
+# ########## Uncomment these to test my 27M params enwik8 model ##########
 # MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
 # WORD_NAME = 'enwik8-vocab'
+# EVAL_DATA = 'enwik8' # uncomment this for EVAL MODE (no text generation)
+# ########################################################################
 
 # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
 # --> all unknown tokens in your context will be denoted by it <--
@@ -50,16 +54,44 @@
 
 ########################################################################################################
 
-np.set_printoptions(precision=4, suppress=True, linewidth=200)
-
+print(f'Loading {MODEL_NAME}...')
+model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
 tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
+
+########################################################################################################
+
+if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals():
+    print('Evaluating on ' + EVAL_DATA + ' ...')
+
+    data = open(EVAL_DATA, "r", encoding='utf-8').read()
+
+    loss_table = np.zeros(ctx_len)
+
+    N_SAMPLE = 1000
+
+    for iii in range(N_SAMPLE):
+        pos = np.random.randint(0, len(data) - ctx_len-1)
+        context = data[pos:pos+ctx_len+1]
+        ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context]
+
+        model.clear()
+        for i in range(1, ctx_len+1):
+            x = ctx[:i]
+            out = model.run(x)
+            prob = F.softmax(torch.tensor(out), dim=-1)
+            loss_table[i-1] += -math.log(prob[ctx[i]])
+
+        print(f'Tested {iii+1} samples: avg_loss over ctx_len =',
+              np.mean(loss_table) / (iii+1))
+
+    exit(0)
+
+########################################################################################################
+
 context = tokenizer.refine_context(context)
 print('\nYour prompt has ' + str(len(context)) + ' tokens.')
 print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. <--\n')
 
-print(f'Loading {MODEL_NAME}...')
-model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len)
-
 for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
     t_begin = time.time_ns()
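Note on the numbers in the commit message: the eval loop above accumulates -log p(next char) in nats and prints its average over positions and samples. A minimal sketch (not part of the commit; the avg_loss value below is illustrative) of how such a value maps to the quoted "ppl 1.65 BPC 0.72":

    import math

    avg_loss = 0.50                   # example: average cross-entropy per character, in nats
    ppl = math.exp(avg_loss)          # per-character perplexity, ~1.65
    bpc = avg_loss / math.log(2)      # bits per character, ~0.72
    print(f'ppl {ppl:.2f}  BPC {bpc:.2f}')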
6597

RWKV-v2-RNN/src/trainer.py (2 additions, 2 deletions)

@@ -22,7 +22,7 @@
 logger = logging.getLogger(__name__)
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
 
 log_file = open("mylog.txt", "a")
 
@@ -151,7 +151,7 @@ def run_epoch(split):
                self.avg_loss = self.avg_loss * \
                    (1.0 - factor) + now_loss * factor
                pbar.set_description(
-                   f"epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
+                   f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")
 
        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

RWKV-v2-RNN/src/utils.py (42 additions, 0 deletions)

@@ -10,6 +10,48 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+
+class Dataset(Dataset):
+    def __init__(self, data, ctx_len, epoch_length_fixed):
+        print('building token list...', end=' ')
+        unique = sorted(list(set(data)))
+        # print()
+        # for u in unique:
+        #     print(u, end=' ')
+        # print('\n\n')
+
+        xx = 0
+        xxObj = {}
+        for u in unique:
+            xxObj[xx] = u
+            xx += 1
+        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
+            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
+
+        data_size, vocab_size = len(data), len(unique)
+        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
+        self.stoi = {ch: i for i, ch in enumerate(unique)}
+        self.itos = {i: ch for i, ch in enumerate(unique)}
+        self.ctx_len = ctx_len
+        self.epoch_length_fixed = epoch_length_fixed
+        self.vocab_size = vocab_size
+        self.data = data
+
+    def __len__(self):
+        return self.epoch_length_fixed
+
+    def __getitem__(self, idx):
+        # cheat: pick a random spot in dataset
+        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
+        chunk = self.data[i:i+self.ctx_len+1]
+        dix = [self.stoi[s] for s in chunk]
+        x = torch.tensor(dix[:-1], dtype=torch.long,
+                         device=torch.device('cuda'))
+        y = torch.tensor(dix[1:], dtype=torch.long,
+                         device=torch.device('cuda'))
+        return x, y
 
 
 class TOKENIZER():
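For context: the Dataset added above subclasses torch.utils.data.Dataset (and then reuses the name), writes vocab.json as a by-product of construction, and samples a random ctx_len+1 window on every __getitem__, so __len__ (= epoch_length_fixed) only fixes the length of a mini-epoch. A hypothetical standalone usage sketch, not the trainer's actual wiring; the corpus path, ctx_len and batch_size values are assumptions:

    from torch.utils.data import DataLoader
    from src.utils import Dataset

    data = open('enwik8', 'r', encoding='utf-8').read()   # assumed corpus path
    train_dataset = Dataset(data, ctx_len=1024, epoch_length_fixed=10000)
    # keep num_workers at the default 0: __getitem__ builds CUDA tensors directly
    loader = DataLoader(train_dataset, shuffle=True, batch_size=12)
    for x, y in loader:
        # x, y: (batch, ctx_len) LongTensors; y is x shifted one character ahead
        break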

RWKV-v2-RNN/train.py (7 additions, 50 deletions)

@@ -7,12 +7,12 @@
 import json
 from src.model import GPT, GPTConfig
 from src.trainer import Trainer, TrainerConfig
-from torch.utils.data import Dataset
+from src.utils import Dataset
 import torch
 import numpy as np
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
 
 ### Step 1: set training data ##########################################################################
 
@@ -36,21 +36,20 @@
 # If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM.
 batch_size = 12
 
-### Step 4: set learning rate, training 'epochs' #######################################################
+### Step 4: set learning rate, training mini-epochs #######################################################
 
 lr_init = 6e-4
 lr_final = 1e-5
-# the 'epoch' here is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
+# the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens)
 n_epoch = 500
-# 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
+# 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc.
 epoch_save_frequency = 30
 epoch_save_path = 'trained-'
 
 epoch_length_fixed = 10000
 
 ########################################################################################################
 
-
 # import src.utils
 # src.utils.set_seed(42) # remember to change seed if you load a model
 
@@ -71,50 +70,8 @@
 ########################################################################################################
 
 print('loading data... ' + datafile)
-
-
-class Dataset(Dataset):
-    def __init__(self, data, ctx_len):
-        print('building token list...', end=' ')
-        unique = sorted(list(set(data)))
-        # print()
-        # for u in unique:
-        #     print(u, end=' ')
-        # print('\n\n')
-
-        xx = 0
-        xxObj = {}
-        for u in unique:
-            xxObj[xx] = u
-            xx += 1
-        with open('vocab.json', "w", encoding="utf-16") as vocab_file:
-            vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
-
-        data_size, vocab_size = len(data), len(unique)
-        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
-        self.stoi = {ch: i for i, ch in enumerate(unique)}
-        self.itos = {i: ch for i, ch in enumerate(unique)}
-        self.ctx_len = ctx_len
-        self.vocab_size = vocab_size
-        self.data = data
-
-    def __len__(self):
-        return epoch_length_fixed
-
-    def __getitem__(self, idx):
-        # cheat: pick a random spot in dataset
-        i = np.random.randint(0, len(self.data) - (self.ctx_len + 1))
-        chunk = self.data[i:i+self.ctx_len+1]
-        dix = [self.stoi[s] for s in chunk]
-        x = torch.tensor(dix[:-1], dtype=torch.long,
-                         device=torch.device('cuda'))
-        y = torch.tensor(dix[1:], dtype=torch.long,
-                         device=torch.device('cuda'))
-        return x, y
-
-
-train_dataset = Dataset(
-    open(datafile, "r", encoding=datafile_encoding).read(), ctx_len)
+train_dataset = Dataset(open(
+    datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed)
 
 ########################################################################################################
 # Train model
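For scale: with epoch_length_fixed = 10000 samples of ctx_len characters each, one mini-epoch covers ctx_len * 10000 characters, matching the "(ctx_len * epoch_length_fixed tokens)" comment above. Assuming a ctx_len of 1024 (taken from the released enwik8 model name, not stated in this diff), that is roughly 10.2M characters per mini-epoch, and epoch_save_frequency = 30 then corresponds to a checkpoint about every 307M characters.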
