
Commit b73bb97

Add results for OASST model
1 parent: ce1d2c1 · commit: b73bb97

3 files changed: +612 −6 lines changed


fact.py  (+1 −1)
@@ -3,7 +3,7 @@
 
 MODEL_PATH = 'oasst.gguf'
 TASKS_PATH = 'fact.json'
-SKIP_TO = 75
+SKIP_TO = 76
 
 PROMPT_TMPL = """\
 Decide which of the following Summary is more consistent with the Article Sentence.
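
The fact.py hunk only bumps SKIP_TO from 75 to 76. Judging by the name, it looks like a resume index over the tasks in fact.json, so re-running the script continues from the next unscored item. The sketch below shows that assumed usage; the loop body and variable names are illustrative, not the repository's actual code.

import json

MODEL_PATH = 'oasst.gguf'
TASKS_PATH = 'fact.json'
SKIP_TO = 76  # assumption: resume from task index 76, earlier tasks were scored in a previous run

with open(TASKS_PATH) as f:
    tasks = json.load(f)

for i, task in enumerate(tasks):
    if i < SKIP_TO:
        continue  # skip tasks already handled before the interruption
    # ... build the prompt from PROMPT_TMPL and query the model here ...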

main.py  (+17 −5)
@@ -19,26 +19,38 @@
 Answer:
 """
 
-OASST_TMPL = """\
+OASST_BASE_TMPL = """\
 <|im_start|>system
 You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
 
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
 <|im_end|>
 <|im_start|>user
+{task}
+<|im_end|>
+<|im_start|>assistant
+"""
+
+OASST_PROOFREAD_TMPL = OASST_BASE_TMPL.format(
+task="""\
 Edit the following essay to ensure the structure reads well.
 Check for grammatical errors and spelling mistakes.
 Preserve markdown formatting.
 
 {article}
-<|im_end|>
-<|im_start|>assistant
 """
+)
+
+
+OASST_TRANSLATE_TMPL = OASST_BASE_TMPL.format(
+task="Translate into {lang}: {text}"
+)
 
 n_ctx = 1000
 llm = Llama(model_path=MODEL_PATH, n_gqa=8, verbose=False, n_ctx=n_ctx)
-prompt = ORCA_TMPL.format(article=sys.stdin.read())
-# print(prompt)
+# prompt = ORCA_TMPL.format(article=sys.stdin.read())
+prompt = OASST_TRANSLATE_TMPL.format(lang="Chinese", text=sys.stdin.read().strip())
+print(prompt)
 n_tokens = len(llm.tokenize(prompt.encode('utf-8')))
 # print(f'{n_tokens=}')
 if n_tokens >= n_ctx:
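
The main.py hunk introduces a two-stage formatting pattern: OASST_BASE_TMPL is specialized once with str.format to fill {task}, while the task-level placeholders such as {article}, {lang} and {text} survive that first pass and are filled by a second .format call when the prompt is built from stdin. A minimal, self-contained sketch of that behavior follows; the shortened base string is purely for illustration, only the identifier names come from the diff.

# Two-stage template expansion, as used by OASST_TRANSLATE_TMPL above (base string trimmed for brevity).
OASST_BASE_TMPL = """\
<|im_start|>user
{task}
<|im_end|>
<|im_start|>assistant
"""

# First pass fills {task}; the literal {lang} and {text} remain for a later .format call.
OASST_TRANSLATE_TMPL = OASST_BASE_TMPL.format(task="Translate into {lang}: {text}")

# Second pass fills the remaining placeholders to produce the final prompt.
prompt = OASST_TRANSLATE_TMPL.format(lang="Chinese", text="Hello, world!")
print(prompt)
# <|im_start|>user
# Translate into Chinese: Hello, world!
# <|im_end|>
# <|im_start|>assistant

This only works because the base template contains no curly braces other than {task}; any other literal braces would have to be doubled to survive the first .format pass.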
