-import torch
-from .ChatRWKV.src.model_run import RWKV_RNN
-from .ChatRWKV.src.utils import TOKENIZER
-import os
-import copy
-import types
-import gc
-import numpy as np
from pathlib import Path
from threading import Lock
+from copy import deepcopy
+import os
+import torch

-np.set_printoptions(precision=4, suppress=True, linewidth=200)
-args = types.SimpleNamespace()
-
-print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV')
-
-########################################################################################################
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ["RWKV_CUDA_ON"] = '0'

-args.RUN_DEVICE = "cuda"  # cuda // cpu
-# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
-args.FLOAT_MODE = "fp16"
+from .prompt import INIT_PROMPT, CHAT_FORMAT
+from .pipeline import PIPELINE, PIPELINE_ARGS
+from rwkv.model import RWKV  # pip install rwkv

-QA_PROMPT = False  # True: Q & A prompt // False: User & Bot prompt
-# 中文问答设置QA_PROMPT=True(只能问答,问答效果更好,但不能闲聊) 中文聊天设置QA_PROMPT=False(可以闲聊,但需要大模型才适合闲聊)

-# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates)
+MODEL_DIR = Path('resource/chat/models')
+TOKEN_PATH = MODEL_DIR / '20B_tokenizer.json'
+STRATEGY = 'cuda fp16'

-MODELS = 'resource/chat/models'
-MODEL_FORMAT = '.pth'
-for f in os.listdir(MODELS):
-    if not f.endswith(MODEL_FORMAT):
+MODEL_EXT = '.pth'
+MODEL_PATH = None
+for f in MODEL_DIR.glob('*'):
+    if f.suffix != MODEL_EXT:
        continue
-    f = f[:-len(MODEL_FORMAT)]
-    args.MODEL_NAME = f'{MODELS}/{f}'
-    if 'ctx2048' in f:
-        args.ctx_len = 2048
-    else:
-        args.ctx_len = 1024
+    MODEL_PATH = f.with_suffix('')
    break

-if not args.MODEL_NAME:
-    print('!!!Chat model not found!!!')
-    raise Exception('Chat model not found')
-
-CHAT_LEN_SHORT = 40
-CHAT_LEN_LONG = 150
-FREE_GEN_LEN = 200
-
-GEN_TEMP = 1.0
-GEN_TOP_P = 0.85
-
-AVOID_REPEAT = ',。:?!'
-
-########################################################################################################
-
-os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
-
-
-print(
-    f'\nLoading ChatRWKV - "Chinese" - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
-
-
-# please tune these (test True/False for all of them). can significantly improve speed.
-# torch._C._jit_set_profiling_executor(True)
-# torch._C._jit_set_profiling_mode(True)
-# torch._C._jit_override_can_fuse_on_cpu(True)
-# torch._C._jit_override_can_fuse_on_gpu(True)
-# torch._C._jit_set_texpr_fuser_enabled(False)
-# torch._C._jit_set_nvfuser_enabled(False)
-
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
-tokenizer = TOKENIZER(
-    str(Path(__file__).parent / "ChatRWKV" / "20B_tokenizer.json"))
-
-args.vocab_size = 50277
-args.head_qk = 0
-args.pre_ffn = 0
-args.grad_cp = 0
-args.my_pos_emb = 0
-MODEL_NAME = args.MODEL_NAME
-
-
-# Load Model
-
-print(f'Loading model - {MODEL_NAME}')
-model = RWKV_RNN(args)
-
-model_tokens = []
-model_state = None
-
-AVOID_REPEAT_TOKENS = []
-for i in AVOID_REPEAT:
-    dd = tokenizer.encode(i)
-    assert len(dd) == 1
-    AVOID_REPEAT_TOKENS += dd
-
-########################################################################################################
-
-user = "博士"
-bot = "牛牛"
-interface = ": "
-
-init_prompt = f'''
-{bot}真名叫帕拉斯,年轻女性,曾经生活在米诺斯,现在在罗德岛工作,{bot}的头上有一朵花、一对角,喜欢喝酒、蹦迪、逛庆典。
-{user}是罗德岛的领袖阿米娅的顾问,而罗德岛是一个医疗组织,他们的目标是拯救矿石病感染者
-以下是{bot}与{user}的对话:
-
-{user}{interface}你是谁?
-
-{bot}{interface}我是帕拉斯,也可以叫我牛牛
+print('Chat model:', MODEL_PATH)

-{user}{interface}我是谁?
-
-{bot}{interface}你是博士呀
-
-{user}{interface}你喜欢喝酒吗?
-
-{bot}{interface}喜欢,要不要来一杯?
-
-{user}{interface}你好笨
-
-{bot}{interface}这对角可能会不小心撞倒些家具,我会尽量小心。
-
-'''
-
-
-def run_rnn(tokens, newline_adj=0):
-    global model_tokens, model_state
-
-    tokens = [int(x) for x in tokens]
-    model_tokens += tokens
-    out, model_state = model.forward(tokens, model_state)
+if not MODEL_PATH:
+    print(f'!!!!!!Chat model not found, please put it in {MODEL_DIR}!!!!!!')
+    print(f'!!!!!!Chat 模型不存在,请放到 {MODEL_DIR} 文件夹下!!!!!!')
+    raise Exception('Chat model not found')

-    # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
+if not TOKEN_PATH.exists():
+    print(f'AI Chat updated, please put token file to {TOKEN_PATH}, download: https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json')
+    print(f'牛牛的 AI Chat 版本更新了,把 token 文件放到 {TOKEN_PATH} 里再启动, 下载地址:https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json')
+    raise Exception('Chat token not found')

-    out[0] = -999999999  # disable <|endoftext|>
-    out[187] += newline_adj  # adjust \n probability
-    # if newline_adj > 0:
-    #     out[15] += newline_adj / 2  # '.'
-    if model_tokens[-1] in AVOID_REPEAT_TOKENS:
-        out[model_tokens[-1]] = -999999999
-    return out
+torch.cuda.empty_cache()
+model = RWKV(model=str(MODEL_PATH), strategy=STRATEGY)
+pipeline = PIPELINE(model, str(TOKEN_PATH))
+args = PIPELINE_ARGS(temperature=1.0, top_p=0.7,
+                     alpha_frequency=0.25,
+                     alpha_presence=0.25,
+                     token_ban=[0],  # ban the generation of some tokens
+                     token_stop=[187])  # stop generation whenever you see any token here


+CHAT_INIT = "CHAT_INIT"
all_state = {}
-INIT_SESSION = 'chat_init'
-
-
-def save_all_stat(session: str, last_out):
-    all_state[session] = {}
-    all_state[session]['out'] = last_out
-    all_state[session]['rnn'] = copy.deepcopy(model_state)
-    all_state[session]['token'] = copy.deepcopy(model_tokens)
-
+all_state[CHAT_INIT] = deepcopy(pipeline.generate(
+    INIT_PROMPT, token_count=200, args=args)[1])

-def load_all_stat(session: str):
-    global model_tokens, model_state
+chat_locker = Lock()

-    if session not in all_state:
-        out = load_all_stat(INIT_SESSION)
-        save_all_stat(session, out)

-    model_state = copy.deepcopy(all_state[session]['rnn'])
-    model_tokens = copy.deepcopy(all_state[session]['token'])
-    return all_state[session]['out']
+def chat(session: str, text: str, token_count: int = 50) -> str:
+    with chat_locker:
+        state = all_state[session] if session in all_state else deepcopy(
+            all_state[CHAT_INIT])
+        ctx = CHAT_FORMAT.format(text)
+        out, state = pipeline.generate(
+            ctx, token_count=token_count, args=args, state=state)
+        all_state[session] = deepcopy(state)
+        return out


-def del_all_stat(session: str):
+def del_session(session: str):
    if session in all_state:
        del all_state[session]

-########################################################################################################
-
-
-# Run inference
-print(f'\nRun prompt...')
-
-out = run_rnn(tokenizer.encode(init_prompt))
-save_all_stat(INIT_SESSION, out)
-gc.collect()
-torch.cuda.empty_cache()
-
-print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')
-
-chat_locker = Lock()
-
-
-def answer(session: str, text: str):
-    with chat_locker:
-        global model_tokens, model_state
-
-        out = load_all_stat(session)
-        new = f"{user}{interface}{text}\n\n{bot}{interface}"
-        out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
-        save_all_stat(session, out)
-
-        ans = ''
-        begin = len(model_tokens)
-        out_last = begin
-        for i in range(999):
-            if i <= 0:
-                newline_adj = -999999999
-            elif i <= CHAT_LEN_SHORT:
-                newline_adj = (i - CHAT_LEN_SHORT) / 10
-            elif i <= CHAT_LEN_LONG:
-                newline_adj = 0
-            else:
-                newline_adj = (i - CHAT_LEN_LONG) * \
-                    0.25  # MUST END THE GENERATION
-            token = tokenizer.sample_logits(
-                out,
-                model_tokens,
-                args.ctx_len,
-                temperature=GEN_TEMP,
-                top_p=GEN_TOP_P,
-            )
-            out = run_rnn([token], newline_adj=newline_adj)
-            xxx = tokenizer.decode(model_tokens[out_last:])
-            if '\ufffd' not in xxx:  # avoid utf-8 display issues
-                ans += xxx
-                # print(xxx, end='', flush=True)
-                out_last = begin + i + 1
-
-            send_msg = tokenizer.decode(model_tokens[begin:])
-            if '\n\n' in send_msg:
-                send_msg = send_msg.strip()
-                break
-
-            # send_msg = tokenizer.decode(model_tokens[begin:]).strip()
-            # if send_msg.endswith(f'{user}{interface}'):  # warning: needs to fix state too !!!
-            #     send_msg = send_msg[:-len(f'{user}{interface}')].strip()
-            #     break
-            # if send_msg.endswith(f'{bot}{interface}'):
-            #     send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
-            #     break
-
-        # print(f'{model_tokens}')
-        # print(f'[{tokenizer.decode(model_tokens)}]')
-
-        save_all_stat(session, out)
-        return ans.strip()
-

if __name__ == "__main__":
    while True:
-        session = 1
+        session = "main"
        text = input('text:')
-        answer(session, text)
+        result = chat(session, text)
+        print(result)
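
Note: INIT_PROMPT and CHAT_FORMAT come from the new .prompt module, which is not part of this diff. Judging from the init_prompt that the diff removes, they presumably look roughly like the sketch below; the exact names and wording in .prompt are assumptions.

# Hypothetical sketch of .prompt -- reconstructed from the removed init_prompt, not taken from this diff.
user = "博士"
bot = "牛牛"
interface = ": "

INIT_PROMPT = f'''
{bot}真名叫帕拉斯,年轻女性,曾经生活在米诺斯,现在在罗德岛工作,{bot}的头上有一朵花、一对角,喜欢喝酒、蹦迪、逛庆典。
{user}是罗德岛的领袖阿米娅的顾问,而罗德岛是一个医疗组织,他们的目标是拯救矿石病感染者
以下是{bot}与{user}的对话:

{user}{interface}你是谁?

{bot}{interface}我是帕拉斯,也可以叫我牛牛
'''

# chat() calls CHAT_FORMAT.format(text), so the template needs a single {} placeholder
# for the user's message and leaves the bot's turn open for pipeline.generate() to continue.
CHAT_FORMAT = f'{user}{interface}{{}}\n\n{bot}{interface}'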
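For reference, a minimal usage sketch of the new API, assuming this module has been imported elsewhere in the bot; the session key below is only illustrative.

# Each chat or group can use its own session key, so conversations keep separate RWKV state.
reply = chat(session='group:12345', text='你是谁?', token_count=50)
print(reply)

reply = chat(session='group:12345', text='我是谁?')  # same session continues the same context
print(reply)

del_session('group:12345')  # drop the cached state to reset this conversation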