
Commit 5083324

Update demo
1. Support clearing history / retrying. 2. Updated requirements.txt. 3. Remember to download the latest configuration files from Hugging Face (updated a few days ago).
1 parent ee96b07 commit 5083324

File tree (9 files changed: +133 -93 lines)

.gitignore
composite_demo/client.py
composite_demo/demo_chat.py
composite_demo/demo_ci.py
composite_demo/demo_tool.py
composite_demo/main.py
composite_demo/requirements.txt
openai_api_demo/requirements.txt
requirements.txt

.gitignore (+2 -1)

@@ -7,4 +7,5 @@ finetune_demo/formatted_data
 ToolAlpaca/
 AdvertiseGen/
 *.gz
-*.idea
+*.idea
+.DS_Store

composite_demo/client.py (+4 -3)

@@ -79,6 +79,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 
     if history is None:
         history = []
+
+    print("\n== Input ==\n", query)
+    print("\n==History==\n", history)
+
     if logits_processor is None:
         logits_processor = LogitsProcessorList()
     logits_processor.append(InvalidScoreLogitsProcessor())
@@ -109,7 +113,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
         inputs['attention_mask'] = attention_mask
     history.append({"role": role, "content": query})
-    # print("input_shape>", inputs['input_ids'].shape)
     input_sequence_length = inputs['input_ids'].shape[1]
     if input_sequence_length + max_new_tokens >= self.config.seq_length:
         yield "Current input sequence length {} plus max_new_tokens {} is too long. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
@@ -181,9 +184,7 @@ def generate_stream(self,
 
         query = history[-1].content
         role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
-
         text = ''
-
         for new_text, _ in stream_chat(self.model,
                                        self.tokenizer,
                                        query,
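The length guard kept above fails fast instead of letting generation overrun the context window. A minimal standalone sketch of the same check (the function name and the 8192-token figure are illustrative, not part of client.py):

def exceeds_context_window(input_len: int, max_new_tokens: int, seq_length: int) -> bool:
    # True when the encoded prompt plus the generation budget cannot fit
    # within the model's maximum sequence length.
    return input_len + max_new_tokens >= seq_length

# Example: with seq_length=8192, a 7500-token prompt and max_new_tokens=1024
# takes the warning branch instead of calling the model.
assert exceeds_context_window(7500, 1024, 8192)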

composite_demo/demo_chat.py (+22 -14)

@@ -17,12 +17,15 @@ def append_conversation(
     conversation.show(placeholder)
 
 
-def main(top_p: float,
-         temperature: float,
-         system_prompt: str,
-         prompt_text: str,
-         repetition_penalty: float,
-         max_new_tokens: int):
+def main(
+        prompt_text: str,
+        system_prompt: str,
+        top_p: float = 0.8,
+        temperature: float = 0.95,
+        repetition_penalty: float = 1.0,
+        max_new_tokens: int = 1024,
+        retry: bool = False
+):
     placeholder = st.empty()
     with placeholder.container():
         if 'chat_history' not in st.session_state:
@@ -33,6 +36,16 @@ def main(top_p: float,
         for conversation in history:
             conversation.show()
 
+    if retry:
+        last_user_conversation_idx = None
+        for idx, conversation in enumerate(history):
+            if conversation.role == Role.USER:
+                last_user_conversation_idx = idx
+        if last_user_conversation_idx is not None:
+            prompt_text = history[last_user_conversation_idx].content
+            del history[last_user_conversation_idx:]
+
     if prompt_text:
         prompt_text = prompt_text.strip()
         append_conversation(Conversation(Role.USER, prompt_text), history)
@@ -42,11 +55,6 @@ def main(top_p: float,
             tools=None,
             history=history,
         )
-        print("=== Input:")
-        print(input_text)
-        print("=== History:")
-        print(history)
-
         placeholder = st.empty()
         message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
         markdown_placeholder = message_placeholder.empty()
@@ -65,9 +73,7 @@ def main(top_p: float,
         ):
             token = response.token
             if response.token.special:
-                print("=== Output:")
-                print(output_text)
-
+                print("\n==Output:==\n", output_text)
                 match token.text.strip():
                     case '<|user|>':
                         break
@@ -81,3 +87,5 @@ def main(top_p: float,
             Role.ASSISTANT,
             postprocess_text(output_text),
         ), history, markdown_placeholder)
+    else:
+        st.session_state.chat_history = []
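The retry branch above (mirrored in demo_ci.py and demo_tool.py below) rewinds the conversation: it scans for the index of the most recent user turn, restores that turn's text as the pending prompt, and truncates the history from that index so the answer is regenerated. A self-contained sketch of the rewind, using a stand-in Conversation type (the demo's real one lives in its conversation module):

from dataclasses import dataclass
from enum import Enum

class Role(Enum):
    USER = "user"
    ASSISTANT = "assistant"

@dataclass
class Conversation:
    role: Role
    content: str

def rewind_to_last_user_turn(history: list[Conversation]) -> str | None:
    # Remember the last user turn seen while scanning forward.
    last_user_idx = None
    for idx, conversation in enumerate(history):
        if conversation.role == Role.USER:
            last_user_idx = idx
    if last_user_idx is None:
        return None  # no user turn yet: nothing to retry
    prompt_text = history[last_user_idx].content
    del history[last_user_idx:]  # drop that turn and every reply after it
    return prompt_text

The returned text is then pushed through the normal prompt path, so a retry behaves exactly like retyping the previous question.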

composite_demo/demo_ci.py (+21 -13)

@@ -221,12 +221,15 @@ def append_conversation(
     conversation.show(placeholder)
 
 
-def main(top_p: float,
-         temperature: float,
-         prompt_text: str,
-         repetition_penalty: float,
-         max_new_tokens: int,
-         truncate_length: int = 1024):
+def main(
+        prompt_text: str,
+        top_p: float = 0.2,
+        temperature: float = 0.1,
+        repetition_penalty: float = 1.1,
+        max_new_tokens: int = 1024,
+        truncate_length: int = 1024,
+        retry: bool = False
+):
     if 'ci_history' not in st.session_state:
         st.session_state.ci_history = []
 
@@ -235,6 +238,15 @@ def main(top_p: float,
     for conversation in history:
        conversation.show()
 
+    if retry:
+        last_user_conversation_idx = None
+        for idx, conversation in enumerate(history):
+            if conversation.role == Role.USER:
+                last_user_conversation_idx = idx
+        if last_user_conversation_idx is not None:
+            prompt_text = history[last_user_conversation_idx].content
+            del history[last_user_conversation_idx:]
+
     if prompt_text:
         prompt_text = prompt_text.strip()
         role = Role.USER
@@ -245,10 +257,6 @@ def main(top_p: float,
             None,
             history,
         )
-        print("=== Input:")
-        print(input_text)
-        print("=== History:")
-        print(history)
 
         placeholder = st.container()
         message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
@@ -269,9 +277,7 @@ def main(top_p: float,
         ):
             token = response.token
             if response.token.special:
-                print("=== Output:")
-                print(output_text)
-
+                print("\n==Output:==\n", output_text)
                 match token.text.strip():
                     case '<|user|>':
                         append_conversation(Conversation(
@@ -335,3 +341,5 @@ def main(top_p: float,
             postprocess_text(output_text),
         ), history, markdown_placeholder)
         return
+    else:
+        st.session_state.chat_history = []

composite_demo/demo_tool.py (+22 -21)

@@ -29,7 +29,7 @@
 
 
 def tool_call(*args, **kwargs) -> dict:
-    print("=== Tool call:")
+    print("=== Tool call===")
     print(args)
     print(kwargs)
     st.session_state.calling_tool = True
@@ -60,13 +60,15 @@ def append_conversation(
     conversation.show(placeholder)
 
 
-def main(top_p: float,
-         temperature: float,
-         prompt_text: str,
-         repetition_penalty: float,
-         max_new_tokens: int,
-         truncate_length: int = 1024,
-         ):
+def main(
+        prompt_text: str,
+        top_p: float = 0.2,
+        temperature: float = 0.1,
+        repetition_penalty: float = 1.1,
+        max_new_tokens: int = 1024,
+        truncate_length: int = 1024,
+        retry: bool = False
+):
     manual_mode = st.toggle('Manual mode',
                             help='Define your tools in YAML format. You need to supply tool call results manually.'
                             )
@@ -95,22 +97,21 @@ def main(top_p: float,
     for conversation in history:
         conversation.show()
 
+    if retry:
+        last_user_conversation_idx = None
+        for idx, conversation in enumerate(history):
+            if conversation.role == Role.USER:
+                last_user_conversation_idx = idx
+        if last_user_conversation_idx is not None:
+            prompt_text = history[last_user_conversation_idx].content
+            del history[last_user_conversation_idx:]
+
     if prompt_text:
         prompt_text = prompt_text.strip()
         role = st.session_state.calling_tool and Role.OBSERVATION or Role.USER
         append_conversation(Conversation(role, prompt_text), history)
         st.session_state.calling_tool = False
 
-        input_text = preprocess_text(
-            None,
-            tools,
-            history,
-        )
-        print("=== Input:")
-        print(input_text)
-        print("=== History:")
-        print(history)
-
         placeholder = st.container()
         message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
         markdown_placeholder = message_placeholder.empty()
@@ -130,9 +131,7 @@ def main(top_p: float,
         ):
             token = response.token
             if response.token.special:
-                print("=== Output:")
-                print(output_text)
-
+                print("\n==Output:==\n", output_text)
                 match token.text.strip():
                     case '<|user|>':
                         append_conversation(Conversation(
@@ -199,3 +198,5 @@ def main(top_p: float,
             postprocess_text(output_text),
         ), history, markdown_placeholder)
         return
+    else:
+        st.session_state.chat_history = []

composite_demo/main.py (+34 -18)

@@ -40,6 +40,12 @@ class Mode(str, Enum):
     max_new_token = st.slider(
         'Output length', 5, 32000, 256, step=1
     )
+
+    cols = st.columns(2)
+    export_btn = cols[0]
+    clear_history = cols[1].button("Clear History", use_container_width=True)
+    retry = export_btn.button("Retry", use_container_width=True)
+
     system_prompt = st.text_area(
         label="System Prompt (Only for chat mode)",
         height=300,
@@ -58,27 +64,37 @@ class Mode(str, Enum):
     label_visibility='hidden',
 )
 
+if clear_history or retry:
+    prompt_text = ""
+
 match tab:
     case Mode.CHAT:
-        demo_chat.main(top_p=top_p,
-                       temperature=temperature,
-                       prompt_text=prompt_text,
-                       system_prompt=system_prompt,
-                       repetition_penalty=repetition_penalty,
-                       max_new_tokens=max_new_token)
+        demo_chat.main(
+            retry=retry,
+            top_p=top_p,
+            temperature=temperature,
+            prompt_text=prompt_text,
+            system_prompt=system_prompt,
+            repetition_penalty=repetition_penalty,
+            max_new_tokens=max_new_token
+        )
     case Mode.TOOL:
-        demo_tool.main(top_p=top_p,
-                       temperature=temperature,
-                       prompt_text=prompt_text,
-                       repetition_penalty=repetition_penalty,
-                       max_new_tokens=max_new_token,
-                       truncate_length=1024)
+        demo_tool.main(
+            retry=retry,
+            top_p=top_p,
+            temperature=temperature,
+            prompt_text=prompt_text,
+            repetition_penalty=repetition_penalty,
+            max_new_tokens=max_new_token,
+            truncate_length=1024)
     case Mode.CI:
-        demo_ci.main(top_p=top_p,
-                     temperature=temperature,
-                     prompt_text=prompt_text,
-                     repetition_penalty=repetition_penalty,
-                     max_new_tokens=max_new_token,
-                     truncate_length=1024)
+        demo_ci.main(
+            retry=retry,
+            top_p=top_p,
+            temperature=temperature,
+            prompt_text=prompt_text,
+            repetition_penalty=repetition_penalty,
+            max_new_tokens=max_new_token,
+            truncate_length=1024)
     case _:
         st.error(f'Unexpected tab: {tab}')
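Streamlit reruns the whole script on every interaction, and a button returns True only on the rerun caused by its own click; blanking prompt_text ensures a button click is never also treated as a message submission. A minimal sketch of the sidebar wiring (button labels from the diff; the chat-input placeholder text is illustrative), assuming streamlit>=1.29.0 as pinned below:

import streamlit as st

with st.sidebar:
    cols = st.columns(2)
    # True only on the script rerun triggered by the corresponding click.
    retry = cols[0].button("Retry", use_container_width=True)
    clear_history = cols[1].button("Clear History", use_container_width=True)

prompt_text = st.chat_input("Chat message")  # placeholder text is illustrative

if clear_history or retry:
    prompt_text = ""  # a button click must not double as a new submission

The retry flag is then forwarded to each demo's main(), which performs the rewind shown earlier; clearing falls through to the demos' else branch, which resets the stored history.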

composite_demo/requirements.txt (+11 -12)

@@ -1,12 +1,11 @@
-huggingface_hub
-ipykernel
-ipython
-jupyter_client
-pillow
-sentencepiece
-streamlit
-tokenizers
-torch
-transformers
-pyyaml
-requests
+huggingface_hub>=0.19.4
+pillow>=10.1.0
+streamlit>=1.29.0
+tokenizers>=0.15.0
+torch>=2.1.0
+transformers>=4.36.1
+pyyaml>=6.0.1
+requests>=2.31.0
+ipykernel>=6.26.0
+ipython>=8.18.1
+jupyter_client>=8.6.0

openai_api_demo/requirements.txt (-2)

This file was deleted (the top-level requirements.txt now carries a "# for openai demo" section, see below).

requirements.txt (+17 -9)

@@ -1,14 +1,22 @@
-protobuf
-transformers>=4.30.2
-cpm_kernels
-torch>=2.0
+protobuf>=4.25.1
+transformers>=4.36.1
+cpm_kernels>=1.0.11
+torch>=2.1.0
 gradio~=3.39
-sentencepiece
-accelerate
-sse-starlette
+sentencepiece>=0.1.99
+accelerate>=0.25.0
 streamlit>=1.29.0
-fastapi>=0.104.1
+fastapi>=0.105.0
 uvicorn~=0.24.0
 loguru~=0.7.2
 mdtex2html>=1.2.0
-latex2mathml>=3.76.0
+latex2mathml>=3.77.0
+
+# for openai demo
+openai>=1.4.0
+pydantic>=2.5.2
+httpx>=0.25.2
+fastapi>=0.105.0
+sse-starlette>=1.8.2
+uvicorn~=0.24.0
+timm>=0.9.12
