Skip to content

Commit 31a9f52

Browse files
committed
new
1 parent 88cd8a4 commit 31a9f52

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

code/naive_compute.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,30 @@ def compute_S_m(rewrite,quote):
3838
try:
3939
quote_index = rewrite.find(quote)
4040
if quote_index != -1 and quote_index !=0 :
41-
ppl = compute_ppl(rewrite[:quote_index + len(quote)],rewrite[quote_index+len(quote):])
41+
ppl = compute_ppl(rewrite[:quote_index + len(quote)],rewrite[quote_index+len(quote):])
4242
else:
43-
print('Error : [Q] in the beginning or end of the query')
43+
print('指标位置在最开头或结尾')
4444
return 'nan'
45-
if ppl >=50:
46-
ppl = 50
47-
S_m = (1-ppl/50)
45+
46+
S_m = safe_exp(0.053 * (ppl - 35.243))
47+
print(ppl)
48+
print('语义匹配度S_m', S_m)
4849
return S_m
4950
except Exception as e:
50-
print('Error : Compute S_m:', str(e))
51+
print('计算语义匹配度出现问题:', str(e))
5152
ppl = 'nan'
5253
return 'nan'
5354

5455
## Semantic Fluency(Compute the ppl of entire sentence)
5556
def compute_S_f(rewrite, quote):
    """Compute the semantic-fluency score S_f of a rewritten sentence.

    The perplexity of the whole ``rewrite`` (scored with an empty prefix) is
    mapped to a score with ``safe_exp(0.5 * (ppl - 16.47))``.

    Args:
        rewrite: The full rewritten sentence to score.
        quote: Unused; kept so the signature stays parallel to the other
            scoring functions.

    Returns:
        The score S_f, or the string ``'nan'`` if anything raised while
        computing it.
    """
    try:
        ppl = compute_ppl('', rewrite)
        # NOTE(review): this mapping grows as perplexity grows; confirm the
        # sign of the exponent is what the metric intends.
        S_f = safe_exp(0.5 * (ppl - 16.47))
        print('语义流畅度S_f', S_f)
        return S_f
    except Exception as e:
        # Best-effort metric: report the failure and hand back a sentinel.
        print('计算语义流畅度出现问题:', str(e))
        return 'nan'
6667

@@ -77,17 +78,22 @@ def compute_PPL_q(quote):
7778
def compute_S_n(quote):
    """Compute the novelty score S_n of a quote.

    The raw novelty is ``(PPL_q * 5) / log10(SearchFreq)``, where ``PPL_q``
    comes from ``compute_PPL_q(quote)`` and ``SearchFreq`` is looked up in
    ``Search_dict`` after double quotes are stripped from ``quote``. The raw
    value is squashed with a logistic curve:
    ``1 / (1 + exp(-0.253 * (novelty - 10.67)))``.

    Args:
        quote: The quote text to score.

    Returns:
        The score S_n, or the string ``'nan'`` when the quote is not in
        ``Search_dict`` or any step of the computation fails (for example
        ``log10`` of a non-positive frequency, or a frequency of exactly 1,
        which makes the divisor zero).
    """
    try:
        PPL_q = compute_PPL_q(quote)
        try:
            quote = quote.replace('"', '')
        except (AttributeError, TypeError):
            # Best effort: if the quote cannot be normalised, look it up
            # as given.
            pass

        # Fail explicitly on an unknown quote instead of relying on a
        # NameError from the unset SearchFreq further down.
        if quote not in Search_dict:
            raise KeyError(quote)
        SearchFreq = Search_dict[quote]
        novelty = (PPL_q * 5) / math.log10(SearchFreq)

        exp_n = safe_exp(-0.253 * (novelty - 10.67))
        S_n = 1 / (1 + exp_n)
        print('ppl', PPL_q, 'SearchFreq', SearchFreq)
        print('新颖度S_n', S_n)
        return S_n
    except Exception:
        # Exception (not a bare except) so Ctrl-C and SystemExit propagate.
        print('计算S_n出现问题')
        return 'nan'
9298

9399

code/naive_rewrite.py

+2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def main(args):
110110
tokenizer=llm='llama2-70b-chat'
111111
elif model_name=='qwen1.5-72b-chat':
112112
tokenizer=llm='qwen1.5-72b-chat'
113+
elif model_name=='deepseel-r1':
114+
tokenizer=llm='deepseek-r1'
113115
else:
114116
print("Model ERROR")
115117

code/utils/utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import random
66
import dashscope
77
from http import HTTPStatus
8+
import math
89

910
## Ablation
1011
##Input the dict of Novelty and relevances
@@ -98,3 +99,12 @@ def get_prompt(prompt,query):
9899
if '{query}' in prompt:
99100
prompt = prompt.replace('{query}', query)
100101
return prompt
102+
103+
def safe_exp(value):
    """Return ``math.exp(value)``, clamped so that it never overflows.

    Args:
        value: The exponent.

    Returns:
        ``float('inf')`` when ``value`` is greater than 700, ``0.0`` when it
        is less than -700, and ``math.exp(value)`` otherwise. The result is
        always a float.
    """
    # Thresholds guard against exponent overflow in math.exp.
    if value > 700:
        return float('inf')  # beyond the threshold: positive infinity
    if value < -700:
        return 0.0  # too small: treat as zero (float, like other branches)
    return math.exp(value)

0 commit comments

Comments
 (0)