
Commit 430bc47

feat: integrate the new ChatRWKV API
1 parent ad89ec9 commit 430bc47

File tree

7 files changed: +206 -243 lines


.gitmodules (-3)

```diff
@@ -2,6 +2,3 @@
 	path = src/plugins/sing/so_vits_svc
 	url = https://github.com/MistEO/so-vits-svc.git
 	branch = 4.0_pallas
-[submodule "src/plugins/chat/ChatRWKV"]
-	path = src/plugins/chat/ChatRWKV
-	url = https://github.com/MistEO/ChatRWKV.git
```
docs/AIDeployment.md (+5 -10)

````diff
@@ -40,26 +40,21 @@
 
 ## Chat
 
-1. Download a model per the [upstream instructions](https://github.com/BlinkDL/ChatRWKV#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B) and put the file in the `resource/chat/models` folder (any `.pth` works; pick one for your VRAM and needs)
-2. Update the git submodules
-
-   ```
-   git submodule update --init --recursive
-   ```
+1. Download a model per the [upstream instructions](https://github.com/BlinkDL/ChatRWKV#%E4%B8%AD%E6%96%87%E6%A8%A1%E5%9E%8B), and download the [token file](https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json). Put both in the `resource/chat/models` folder (any `.pth` model works; pick one for your VRAM and needs)
 
-3. Install dependencies
+2. Install dependencies
 
   - CPU
 
    ```bash
-   python -m pip install torch torchvision torchaudio tokenizers
+   python -m pip install torch tokenizers rwkv
    ```
 
  - GPU
 
    ```bash
-   python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
-   python -m pip install tokenizers
+   python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
+   python -m pip install tokenizers rwkv
    ```
 
 4. If you are interested, try tweaking the opening prompt `init_prompt` in `src/plugins/chat/model.py`
````
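To sanity-check the new setup after step 2, the stock pipeline from the `rwkv` pip package can load the model directly. This is a minimal sketch, not part of this commit: the model file name is a placeholder, and it uses `rwkv.utils.PIPELINE` rather than the plugin's vendored `pipeline` module (which additionally returns the RNN state).

```python
# Minimal smoke test for the new dependency stack (illustrative only).
# Set the JIT/CUDA flags before importing rwkv, as the plugin does.
import os
os.environ['RWKV_JIT_ON'] = '1'
os.environ['RWKV_CUDA_ON'] = '0'

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# Replace with your downloaded model; path is given without the .pth suffix.
model = RWKV(model='resource/chat/models/YOUR-RWKV-4-MODEL',
             strategy='cpu fp32')  # use 'cuda fp16' on a GPU
pipeline = PIPELINE(model, 'resource/chat/models/20B_tokenizer.json')
args = PIPELINE_ARGS(temperature=1.0, top_p=0.7)

print(pipeline.generate('你是谁?', token_count=20, args=args))
```

Note that the plugin itself hard-codes `STRATEGY = 'cuda fp16'` in `src/plugins/chat/model.py`, so CPU users will need to adjust that constant even after following the CPU install steps.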

src/plugins/chat/__init__.py (+6 -6)

```diff
@@ -6,7 +6,7 @@
 from nonebot import on_message, get_driver, logger
 import random
 
-from .model import answer, del_all_stat
+from .model import chat, del_session
 from src.common.config import BotConfig, GroupConfig
 try:
     from src.common.utils.speech.text_to_speech import text_2_speech
@@ -15,14 +15,14 @@
     print('TTS not available, error:', error)
     TTS_AVAIABLE = False
 
-TTS_PROBABILITY = 0.3
+TTS_MIN_LENGTH = 10
 
 
 def on_sober_up(bot_id, group_id, drunkenness) -> bool:
     session = f'{bot_id}_{group_id}'
     logger.info(
         f'bot [{bot_id}] sober up in group [{group_id}], clear session [{session}]')
-    del_all_stat(session)
+    del_session(session)
 
 
 BotConfig.register_sober_up(on_sober_up)
@@ -56,16 +56,16 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
     session = f'{event.self_id}_{event.group_id}'
     if text.startswith('牛牛'):
         text = text[2:].strip()
-    ans = await asyncify(answer)(session, text)
+    ans = await asyncify(chat)(session, text)
 
     config.reset_cooldown(cd_key)
     if not ans:
         return
 
     logger.info(f'session [{session}]: {text} -> {ans}')
 
-    if TTS_AVAIABLE and random.random() < TTS_PROBABILITY:
-        bs = await asyncify(text_2_speech)(ans[:50], 1.0)
+    if TTS_AVAIABLE and len(ans) >= TTS_MIN_LENGTH:
+        bs = await asyncify(text_2_speech)(ans)
         msg = MessageSegment.record(bs)
     else:
         msg = ans
```
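The handler offloads both the RWKV forward pass and TTS synthesis through `asyncify` so the blocking work does not stall nonebot's event loop. The repo's actual helper is not shown in this diff; a minimal stand-in with the assumed shape (a thread-pool wrapper, as in `asyncer.asyncify`) would look like this:

```python
# Hypothetical stand-in for the `asyncify` used above (illustrative only):
# wraps a blocking callable so `await asyncify(func)(*args)` runs it in a
# worker thread instead of blocking the event loop.
import asyncio
import functools
from typing import Callable, TypeVar

T = TypeVar('T')


def asyncify(func: Callable[..., T]):
    async def wrapper(*args, **kwargs) -> T:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(func, *args, **kwargs))
    return wrapper
```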

src/plugins/chat/model.py (+50 -222)

```diff
@@ -1,250 +1,78 @@
-import torch
-from .ChatRWKV.src.model_run import RWKV_RNN
-from .ChatRWKV.src.utils import TOKENIZER
-import os
-import copy
-import types
-import gc
-import numpy as np
 from pathlib import Path
 from threading import Lock
+from copy import deepcopy
+import os
+import torch
 
-np.set_printoptions(precision=4, suppress=True, linewidth=200)
-args = types.SimpleNamespace()
-
-print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV')
-
-########################################################################################################
+os.environ['RWKV_JIT_ON'] = '1'
+os.environ["RWKV_CUDA_ON"] = '0'
 
-args.RUN_DEVICE = "cuda"  # cuda // cpu
-# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
-args.FLOAT_MODE = "fp16"
+from .prompt import INIT_PROMPT, CHAT_FORMAT
+from .pipeline import PIPELINE, PIPELINE_ARGS
+from rwkv.model import RWKV  # pip install rwkv
 
-QA_PROMPT = False  # True: Q & A prompt // False: User & Bot prompt
-# For Chinese Q&A set QA_PROMPT=True (Q&A only, better answers, no small talk); for Chinese chat set QA_PROMPT=False (small talk works, but only large models chat well)
 
-# Download RWKV-4 models from https://huggingface.co/BlinkDL (don't use Instruct-test models unless you use their prompt templates)
+MODEL_DIR = Path('resource/chat/models')
+TOKEN_PATH = MODEL_DIR / '20B_tokenizer.json'
+STRATEGY = 'cuda fp16'
 
-MODELS = 'resource/chat/models'
-MODEL_FORMAT = '.pth'
-for f in os.listdir(MODELS):
-    if not f.endswith(MODEL_FORMAT):
+MODEL_EXT = '.pth'
+MODEL_PATH = None
+for f in MODEL_DIR.glob('*'):
+    if f.suffix != MODEL_EXT:
         continue
-    f = f[:-len(MODEL_FORMAT)]
-    args.MODEL_NAME = f'{MODELS}/{f}'
-    if 'ctx2048' in f:
-        args.ctx_len = 2048
-    else:
-        args.ctx_len = 1024
+    MODEL_PATH = f.with_suffix('')
     break
 
-if not args.MODEL_NAME:
-    print('!!!Chat model not found!!!')
-    raise Exception('Chat model not found')
-
-CHAT_LEN_SHORT = 40
-CHAT_LEN_LONG = 150
-FREE_GEN_LEN = 200
-
-GEN_TEMP = 1.0
-GEN_TOP_P = 0.85
-
-AVOID_REPEAT = ',。:?!'
-
-########################################################################################################
-
-os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
-
-
-print(
-    f'\nLoading ChatRWKV - "Chinese" - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
-
-
-# please tune these (test True/False for all of them). can significantly improve speed.
-# torch._C._jit_set_profiling_executor(True)
-# torch._C._jit_set_profiling_mode(True)
-# torch._C._jit_override_can_fuse_on_cpu(True)
-# torch._C._jit_override_can_fuse_on_gpu(True)
-# torch._C._jit_set_texpr_fuser_enabled(False)
-# torch._C._jit_set_nvfuser_enabled(False)
-
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
-tokenizer = TOKENIZER(
-    str(Path(__file__).parent / "ChatRWKV" / "20B_tokenizer.json"))
-
-args.vocab_size = 50277
-args.head_qk = 0
-args.pre_ffn = 0
-args.grad_cp = 0
-args.my_pos_emb = 0
-MODEL_NAME = args.MODEL_NAME
-
-
-# Load Model
-
-print(f'Loading model - {MODEL_NAME}')
-model = RWKV_RNN(args)
-
-model_tokens = []
-model_state = None
-
-AVOID_REPEAT_TOKENS = []
-for i in AVOID_REPEAT:
-    dd = tokenizer.encode(i)
-    assert len(dd) == 1
-    AVOID_REPEAT_TOKENS += dd
-
-########################################################################################################
-
-user = "博士"
-bot = "牛牛"
-interface = ": "
-
-init_prompt = f'''
-{bot}真名叫帕拉斯,年轻女性,曾经生活在米诺斯,现在在罗德岛工作,{bot}的头上有一朵花、一对角,喜欢喝酒、蹦迪、逛庆典。
-{user}是罗德岛的领袖阿米娅的顾问,而罗德岛是一个医疗组织,他们的目标是拯救矿石病感染者
-以下是{bot}和{user}的对话:
-
-{user}{interface}你是谁?
-
-{bot}{interface}我是帕拉斯,也可以叫我牛牛
+print('Chat model:', MODEL_PATH)
 
-{user}{interface}我是谁?
-
-{bot}{interface}你是博士呀
-
-{user}{interface}你喜欢喝酒吗?
-
-{bot}{interface}喜欢,要不要来一杯?
-
-{user}{interface}你好笨
-
-{bot}{interface}这对角可能会不小心撞倒些家具,我会尽量小心。
-
-'''
-
-
-def run_rnn(tokens, newline_adj=0):
-    global model_tokens, model_state
-
-    tokens = [int(x) for x in tokens]
-    model_tokens += tokens
-    out, model_state = model.forward(tokens, model_state)
+if not MODEL_PATH:
+    print(f'!!!!!!Chat model not found, please put it in {MODEL_DIR}!!!!!!')
+    print(f'!!!!!!Chat 模型不存在,请放到 {MODEL_DIR} 文件夹下!!!!!!')
+    raise Exception('Chat model not found')
 
-    # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')
+if not TOKEN_PATH.exists():
+    print(f'AI Chat updated, please put token file to {TOKEN_PATH}, download: https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json')
+    print(f'牛牛的 AI Chat 版本更新了,把 token 文件放到 {TOKEN_PATH} 里再启动, 下载地址:https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json')
+    raise Exception('Chat token not found')
 
-    out[0] = -999999999  # disable <|endoftext|>
-    out[187] += newline_adj  # adjust \n probability
-    # if newline_adj > 0:
-    #     out[15] += newline_adj / 2  # '.'
-    if model_tokens[-1] in AVOID_REPEAT_TOKENS:
-        out[model_tokens[-1]] = -999999999
-    return out
+torch.cuda.empty_cache()
+model = RWKV(model=str(MODEL_PATH), strategy=STRATEGY)
+pipeline = PIPELINE(model, str(TOKEN_PATH))
+args = PIPELINE_ARGS(temperature=1.0, top_p=0.7,
+                     alpha_frequency=0.25,
+                     alpha_presence=0.25,
+                     token_ban=[0],     # ban the generation of some tokens
+                     token_stop=[187])  # stop generation whenever you see any token here
 
 
+CHAT_INIT = "CHAT_INIT"
 all_state = {}
-INIT_SESSION = 'chat_init'
-
-
-def save_all_stat(session: str, last_out):
-    all_state[session] = {}
-    all_state[session]['out'] = last_out
-    all_state[session]['rnn'] = copy.deepcopy(model_state)
-    all_state[session]['token'] = copy.deepcopy(model_tokens)
-
+all_state[CHAT_INIT] = deepcopy(pipeline.generate(
+    INIT_PROMPT, token_count=200, args=args)[1])
 
-def load_all_stat(session: str):
-    global model_tokens, model_state
+chat_locker = Lock()
 
-    if session not in all_state:
-        out = load_all_stat(INIT_SESSION)
-        save_all_stat(session, out)
 
-    model_state = copy.deepcopy(all_state[session]['rnn'])
-    model_tokens = copy.deepcopy(all_state[session]['token'])
-    return all_state[session]['out']
+def chat(session: str, text: str, token_count: int = 50) -> str:
+    with chat_locker:
+        state = all_state[session] if session in all_state else deepcopy(
+            all_state[CHAT_INIT])
+        ctx = CHAT_FORMAT.format(text)
+        out, state = pipeline.generate(
+            ctx, token_count=token_count, args=args, state=state)
+        all_state[session] = deepcopy(state)
+        return out
 
 
-def del_all_stat(session: str):
+def del_session(session: str):
     if session in all_state:
         del all_state[session]
 
-########################################################################################################
-
-
-# Run inference
-print(f'\nRun prompt...')
-
-out = run_rnn(tokenizer.encode(init_prompt))
-save_all_stat(INIT_SESSION, out)
-gc.collect()
-torch.cuda.empty_cache()
-
-print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')
-
-chat_locker = Lock()
-
-
-def answer(session: str, text: str):
-    with chat_locker:
-        global model_tokens, model_state
-
-        out = load_all_stat(session)
-        new = f"{user}{interface}{text}\n\n{bot}{interface}"
-        out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
-        save_all_stat(session, out)
-
-        ans = ''
-        begin = len(model_tokens)
-        out_last = begin
-        for i in range(999):
-            if i <= 0:
-                newline_adj = -999999999
-            elif i <= CHAT_LEN_SHORT:
-                newline_adj = (i - CHAT_LEN_SHORT) / 10
-            elif i <= CHAT_LEN_LONG:
-                newline_adj = 0
-            else:
-                newline_adj = (i - CHAT_LEN_LONG) * \
-                    0.25  # MUST END THE GENERATION
-            token = tokenizer.sample_logits(
-                out,
-                model_tokens,
-                args.ctx_len,
-                temperature=GEN_TEMP,
-                top_p=GEN_TOP_P,
-            )
-            out = run_rnn([token], newline_adj=newline_adj)
-            xxx = tokenizer.decode(model_tokens[out_last:])
-            if '\ufffd' not in xxx:  # avoid utf-8 display issues
-                ans += xxx
-                # print(xxx, end='', flush=True)
-                out_last = begin + i + 1
-
-            send_msg = tokenizer.decode(model_tokens[begin:])
-            if '\n\n' in send_msg:
-                send_msg = send_msg.strip()
-                break
-
-        # send_msg = tokenizer.decode(model_tokens[begin:]).strip()
-        # if send_msg.endswith(f'{user}{interface}'):  # warning: needs to fix state too !!!
-        #     send_msg = send_msg[:-len(f'{user}{interface}')].strip()
-        #     break
-        # if send_msg.endswith(f'{bot}{interface}'):
-        #     send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
-        #     break
-
-        # print(f'{model_tokens}')
-        # print(f'[{tokenizer.decode(model_tokens)}]')
-
-        save_all_stat(session, out)
-        return ans.strip()
-
 
 if __name__ == "__main__":
     while True:
-        session = 1
+        session = "main"
         text = input('text:')
-        answer(session, text)
+        result = chat(session, text)
+        print(result)
```
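The new `chat`/`del_session` pair keeps one deep-copied RWKV state per session id, so each group continues from its own context while `CHAT_INIT` preserves the freshly primed prompt state; the old hand-rolled sampling loop is replaced by `token_stop=[187]` (the `\n` token, per the old code's comment), which ends generation at the first newline. Illustrative usage of the names introduced by this commit, run from the repo root with the model in place:

```python
# Each session id owns an independent copy of the RNN state.
from src.plugins.chat.model import chat, del_session

print(chat('demo', '你是谁?'))      # first turn: forks the CHAT_INIT state
print(chat('demo', '来喝一杯吗?'))  # second turn: continues the same state
del_session('demo')                  # drop the state; the next turn starts fresh
```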
