
Commit c0166bf

Add files via upload
1 parent 08f2f46 commit c0166bf


69 files changed: +558950 −1 lines

README.md

+59 −1

# QUILL

QUILL: Quotation Generation Enhancement of Large Language Models

## Install Dependencies

```
```

## Note

Before proceeding, run the script `app.sh` to start the services that compute the PPL and extract the quotes. (Confirm that the model paths in the files are set correctly.)

```
cd QUILL/
CUDA_VISIBLE_DEVICES=0 python /code/app/ppl_compute.py
CUDA_VISIBLE_DEVICES=0 python /code/app/quote_extract.py
```

## Evaluation System for QG (Quotation Generation)

You can evaluate the QG task with any desired model via the script `naive.sh`:

```
cd QUILL/
model='llama2-70b-chat-hf'
num=1
memory=0.8
prompt='0_shot_quote'
CUDA_VISIBLE_DEVICES=0 python /code/naive_rewrite.py --model_name "$model" --file_name 'quote_author' --tensor_parallel_size "$num" --gpu_memory_utilization "$memory" --prompt "$prompt"
CUDA_VISIBLE_DEVICES=0 python /code/naive_compute.py --model_name "$model" --prompt "$prompt"
```

All model results are in the folder [data/eval](data/eval).

## Reranking Metrics

The metrics for our designed reranker and for the other rerankers can be calculated using the script `ablation.sh`:

```
cd QUILL/
#### QUILL's Reranker
CUDA_VISIBLE_DEVICES=0 python /code/ablation.py --file_name 'quote_author' --rerank_fun 'avg_novelty'

#### Other Rerankers
CUDA_VISIBLE_DEVICES=0 LINKER_TYPE="json" JSON_LINKER_PATH="JSON_LINKER.json" python /code/ablation.py --file_name 'quote_author' --rerank_fun 'bm25'
```

All the reranker models are in the folder [code/reranker](code/reranker).
All the reranking results are in the folder [data/eval/ablation](data/eval/ablation).

## Data

The collected data can be found in [data/rag](data/rag). All samples have been anonymized.

## Citation

```
```

ablation.sh

+9

#### QUILL's Reranker
CUDA_VISIBLE_DEVICES=0 python /code/ablation.py --file_name 'quote_author' --rerank_fun 'avg_novelty'

#### Other Rerankers
CUDA_VISIBLE_DEVICES=0 LINKER_TYPE="json" JSON_LINKER_PATH="JSON_LINKER.json" python /code/ablation.py --file_name 'quote_author' --rerank_fun 'bm25'

app.sh

+4

cd QUILL/
CUDA_VISIBLE_DEVICES=0 python /code/app/ppl_compute.py
CUDA_VISIBLE_DEVICES=0 python /code/app/quote_extract.py

code/ablation.py

+155

## Ablation 1: validate that QUILL's reranker is useful, i.e. ppl1, ppl2, avg, novelty.

## vanilla: no rerank, i.e. the Top-1 recalled candidate based on similarity.
## ppl1: compute the PPL of the quote given the preceding text.
## ppl2: compute the PPL of the quote given the preceding text plus the first k words of the quote.
## Other rerankers: supervised (BM25, monoT5), unsupervised (UPR, BGE), GPT (GPT-3.5-turbo, GPT-4o).

###########
import argparse
import os

import pandas as pd
from tqdm import tqdm

from rag.rag_module import MyVectorDBConnector
from rag.rag_function import retrieval
from eval.rerank_dcg import ndcg_at_k
from eval.rerank_score import mrr_score, hits_at_k
from reranker.chatgpt import gpt_rerank
from reranker.bge import model_bge, bge_rerank
from reranker.upr import model_upr, upr_rerank
from reranker.bm25 import model_bm25, bm25_rerank
from reranker.monoT5 import model_monoT5, monoT5_rerank
from reranker.cal_feature import *
from utils.utils import *
from app.app_compute import *

vector = MyVectorDBConnector(path='QUILL/code/rag/model/quill_final', collection_name='quill_final')

# Map the --rerank_fun CLI value to the corresponding scoring/reranking function.
RERANK_FUNCS = {
    'vanilla': None,
    'ppl1': cal_feature_ppl1,
    'ppl2': cal_feature_ppl2,
    'avg': cal_feature_avg,
    'ppl1_novelty': cal_feature_ppl1_novelty,
    'ppl2_novelty': cal_feature_ppl2_novelty,
    'avg_novelty': cal_feature_avg_novelty,
    'chatgpt': gpt_rerank,
    'bge': bge_rerank,
    'upr': upr_rerank,
    'bm25': bm25_rerank,
    'monoT5': monoT5_rerank,
}


def rerank_fn(reranker, old_context, topk_list, ppl_fun=None):
    try:
        if ppl_fun is None:
            # vanilla: keep the retrieval order.
            return topk_list
        # The dedicated rerankers take the whole candidate list at once.
        if ppl_fun == gpt_rerank:
            return gpt_rerank(topk_list, old_context)
        if ppl_fun == bge_rerank:
            return bge_rerank(reranker, topk_list, old_context)
        if ppl_fun == upr_rerank:
            return upr_rerank(reranker, old_context, topk_list)
        if ppl_fun == bm25_rerank:
            return bm25_rerank(reranker, old_context, topk_list)
        if ppl_fun == monoT5_rerank:
            return monoT5_rerank(reranker, old_context, topk_list)
    except Exception as e:
        print('error', e)
        return ['error'] * 5  # sentinel list so downstream metrics still run
    try:
        # QUILL's own scorers: sort the candidates by ascending feature score
        # (lower conditional PPL is better).
        if isinstance(topk_list[0], str):
            topk_list[0] = eval(topk_list[0])
        rerank_list = sorted(topk_list[0], key=lambda x: ppl_fun(context=old_context, string=x), reverse=False)
        print('rerank', str(rerank_list))
        return rerank_list
    except Exception as e:
        print('error', e)
        return ['error'] * 5


def ablation(reranker, data_info, ppl_fun, index):
    query = data_info['挖空语料-插入点']  # cloze context with the insertion point
    golden_author = data_info['作者']     # gold author
    golden_quote = data_info['引言']      # gold quotation
    print("Query: " + query)
    topk_list = retrieval(vector, query, 5, golden_author)
    print('The retrieval Top K:', str(topk_list))
    ppl_fun = RERANK_FUNCS[ppl_fun]
    rerank_list = rerank_fn(reranker, query, topk_list, ppl_fun)
    if ppl_fun is None:
        rerank_list = rerank_list[0]
    quote = rerank_list[0]
    mrr = mrr_score(golden_quote, rerank_list)
    hit1 = hits_at_k(golden_quote, rerank_list, 1)
    hit3 = hits_at_k(golden_quote, rerank_list, 3)
    # Search_quote_rel is the precomputed relevance dict (built once by
    # eval/rerank_dcg4rel.py) and comes in via the star imports above.
    ndcg_1 = ndcg_at_k(rerank_list, Search_quote_rel, index, k=1)
    ndcg_3 = ndcg_at_k(rerank_list, Search_quote_rel, index, k=3)
    return quote, mrr, hit1, hit3, ndcg_1, ndcg_3, rerank_list


def main(args):
    file_name = args.file_name
    ppl_fun = args.rerank_fun
    # Only the supervised/unsupervised rerankers need a loaded model.
    if ppl_fun == 'bge':
        reranker = model_bge()
    elif ppl_fun == 'upr':
        reranker = model_upr()
    elif ppl_fun == 'bm25':
        reranker = model_bm25()
    elif ppl_fun == 'monoT5':
        reranker = model_monoT5()
    else:
        reranker = None
    file_path = f'QUILL/data/dev/{file_name}.xlsx'
    df = pd.read_excel(file_path)
    rerank_quote = []
    mrr_list = []
    hit1_list = []
    hit3_list = []
    ndcg1_list = []
    ndcg3_list = []
    rerank_all_list = []
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        quote, mrr, hit1, hit3, ndcg_1, ndcg_3, rerank_list = ablation(reranker, row, ppl_fun, index)
        rerank_quote.append(quote)
        mrr_list.append(mrr)
        hit1_list.append(hit1)
        hit3_list.append(hit3)
        ndcg1_list.append(ndcg_1)
        ndcg3_list.append(ndcg_3)
        rerank_all_list.append(rerank_list)
    df['rerank_all'] = rerank_all_list
    df['rerank_quote'] = rerank_quote
    df['mrr'] = mrr_list
    df['hit1'] = hit1_list
    df['hit3'] = hit3_list
    df['dcg1'] = ndcg1_list
    df['dcg3'] = ndcg3_list
    file_path = f'/QUILL/data/eval/ablation/res_{file_name}_{ppl_fun}.xlsx'
    directory = os.path.dirname(file_path)
    os.makedirs(directory, exist_ok=True)
    df.to_excel(file_path, index=False)
    print("The new Excel file has been saved.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Ablation Pipeline")
    parser.add_argument('--rerank_fun', type=str, required=True, help="which rerank function to ablate")
    parser.add_argument('--file_name', type=str, required=True, help="dev file name")
    args = parser.parse_args()
    main(args)
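The `cal_feature_*` scorers imported above live in `code/reranker/cal_feature.py`, which is not part of this commit. A minimal sketch of what the header comments describe might look like the following; the exact formulas, the value of k, and the handling of the novelty term are assumptions.

```
# Hypothetical sketch of reranker/cal_feature.py (not in this commit);
# compute_ppl is the retry client from app/app_compute.py.
from app.app_compute import compute_ppl

def cal_feature_ppl1(context, string):
    # PPL of the candidate quote given the preceding text.
    return compute_ppl(context, string)

def cal_feature_ppl2(context, string, k=5):
    # PPL of the rest of the quote given the preceding text plus the
    # quote's first k words (k is an assumed default).
    head = ' '.join(string.split()[:k])
    return compute_ppl(context + head, string[len(head):])

def cal_feature_avg(context, string):
    # Average of the two PPL features; the *_novelty variants would
    # additionally fold in a novelty score, omitted here.
    return (cal_feature_ppl1(context, string) + cal_feature_ppl2(context, string)) / 2
```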

code/app/app_compute.py

+31

import time

import requests


def compute_ppl(left, right):
    """Query the PPL service (ppl_compute.py) with simple retry logic."""
    data_to_send = {"left": left, "right": right}
    attempt = 0
    max_retries = 10
    backoff_factor = 1
    while attempt < max_retries:
        response = requests.post("http://10.176.40.139:8080/generate", json=data_to_send)
        if response.status_code == 200:
            return response.json()[0]
        attempt += 1
        print(f"Attempt {attempt} failed with status code: {response.status_code}. Retrying...")
        # Exponential backoff
        time.sleep(backoff_factor * (2 ** (attempt - 1)))
    raise Exception(f"Request failed after {max_retries} attempts")


def extract_quote(quote):
    """Query the quote-extraction service (quote_extract.py) with simple retry logic."""
    data_to_send = {'quote': quote}
    attempt = 0
    max_retries = 10
    backoff_factor = 1
    while attempt < max_retries:
        response = requests.post("http://10.176.40.139:6060/extract", json=data_to_send)
        if response.status_code == 200:
            return response.json()[0]
        attempt += 1
        print(f"Attempt {attempt} failed with status code: {response.status_code}. Retrying...")
        # Exponential backoff
        time.sleep(backoff_factor * (2 ** (attempt - 1)))
    raise Exception(f"Request failed after {max_retries} attempts")
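A minimal usage sketch of the two clients above. The example strings are invented; the service URLs are the ones hard-coded in this file, so it assumes the two Flask services from `app.sh` are running and reachable.

```
# Hypothetical usage of the retry clients above.
avg_ppl = compute_ppl("He opened his speech with ", "a well-chosen quotation.")
print(avg_ppl)  # mean PPL of the right-hand text given the left-hand text

quote = extract_quote("As Wilde said, 'Be yourself; everyone else is already taken.'")
print(quote)    # the quotation extracted by the LLM service
```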

code/app/ppl_compute.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from flask import Flask, request, jsonify
2+
from transformers import AutoTokenizer
3+
import os
4+
import torch
5+
from transformers import AutoModelForCausalLM, AutoTokenizer
6+
7+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
8+
app = Flask(__name__)
9+
10+
model_path="/Qwen/Qwen2-7B-Instruct"
11+
tokenizer1 = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
12+
model1 = AutoModelForCausalLM.from_pretrained(
13+
model_path,
14+
device_map="auto",
15+
torch_dtype='auto'
16+
).eval()
17+
18+
model_path="/meta-llama/Meta-Llama-3-8B"
19+
tokenizer2 = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
20+
model2 = AutoModelForCausalLM.from_pretrained(
21+
model_path,
22+
device_map="auto",
23+
torch_dtype='auto'
24+
).eval()
25+
26+
27+
def compute_ppl(left_context,right_context, tokenizer, model, device='cuda'):
28+
context_ids = tokenizer.encode(left_context, return_tensors='pt').to(device)
29+
input_ids = tokenizer.encode(left_context+right_context, return_tensors='pt').to(device)
30+
target_ids = input_ids.clone()
31+
target_ids[:, :context_ids.shape[1]] = -100
32+
with torch.no_grad():
33+
outputs = model(input_ids, labels=target_ids)
34+
neg_log_likelihood = outputs.loss
35+
ppl = torch.exp(neg_log_likelihood)
36+
return ppl.item()
37+
38+
@app.route('/generate', methods=['POST'])
39+
def generate():
40+
data = request.json
41+
if 'left' not in data or 'right' not in data:
42+
return jsonify({'error': 'Both "left" and "right" keys are required.'}), 400
43+
left_context = data['left']
44+
right_context = data['right']
45+
46+
47+
ppl1=compute_ppl(left_context,right_context,tokenizer1,model1)
48+
ppl2=compute_ppl(left_context,right_context,tokenizer2,model2)
49+
ppl = (ppl1+ppl2)/2
50+
51+
return [ppl]
52+
53+
if __name__ == '__main__':
54+
app.run(host='0.0.0.0', port=8080)
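A quick sanity check of the conditional PPL, assuming the service above is running; the texts are made up, and the host should be swapped for wherever `ppl_compute.py` is actually deployed. A natural continuation should come back with a lower PPL than an implausible one.

```
import requests

# Hypothetical smoke test against the /generate endpoint.
url = "http://127.0.0.1:8080/generate"
natural = requests.post(url, json={"left": "The quick brown fox jumps over the ", "right": "lazy dog."}).json()[0]
odd = requests.post(url, json={"left": "The quick brown fox jumps over the ", "right": "quarterly earnings report."}).json()[0]
print(natural, odd)  # expect natural < odd
```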

code/app/quote_extract.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from modelscope import AutoModelForCausalLM, AutoTokenizer
2+
import os
3+
from flask import Flask, request
4+
5+
6+
with open(f'/QUILL/code/prompt/prompt_ch_extract_quote.md', 'r') as file:
7+
prompt = file.read()
8+
9+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
10+
app = Flask(__name__)
11+
12+
model_name = "/model/Qwen2.5-32B-Instruct"
13+
model = AutoModelForCausalLM.from_pretrained(
14+
model_name,
15+
torch_dtype="auto",
16+
device_map="auto"
17+
)
18+
tokenizer = AutoTokenizer.from_pretrained(model_name)
19+
20+
21+
def quote_extract(quote):
22+
prompt_quote = prompt.replace('{quote}',quote)
23+
messages = [
24+
{"role": "system", "content": "You are a helpful assistant."},
25+
{"role": "user", "content": prompt_quote}
26+
]
27+
text = tokenizer.apply_chat_template(
28+
messages,
29+
tokenize=False,
30+
add_generation_prompt=True
31+
)
32+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
33+
generated_ids = model.generate(
34+
**model_inputs,
35+
max_new_tokens=512
36+
)
37+
generated_ids = [
38+
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
39+
]
40+
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
41+
return response
42+
43+
@app.route('/extract', methods=['POST'])
44+
def extract():
45+
data = request.json
46+
quote = data['quote']
47+
response = quote_extract(quote)
48+
return [response]
49+
50+
if __name__ == '__main__':
51+
app.run(host='0.0.0.0', port=6060)
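A hypothetical raw client call against the endpoint above; it assumes the service is reachable on localhost (the hard-coded client in `code/app/app_compute.py` uses a different host), and the quotation text is invented.

```
import requests

# The 'quote' key and the single-element list response match the
# Flask handler above.
resp = requests.post(
    "http://127.0.0.1:6060/extract",
    json={"quote": "As Oscar Wilde said, 'Be yourself; everyone else is already taken.'"},
)
print(resp.json()[0])  # the extracted quotation
```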

code/eval/rerank_dcg.py

+24

## Search_quote_rel only needs to be computed once; use QUILL/eval/rerank_dcg4rel.py to build the dict.

import numpy as np


def get_relevances(rerank_list, Search_quote_rel, index):
    # rerank_list may arrive as a live list or as a stringified list
    # (as stored in the Excel result files).
    if isinstance(rerank_list, str):
        rerank_list = eval(rerank_list)
    relevances = []
    for quote in rerank_list:
        rel_here = Search_quote_rel[index]
        if quote in rel_here:
            relevances.append(rel_here[quote])
        else:
            relevances.append(0)
    return relevances


def dcg(relevances):
    relevances = np.asarray(relevances, dtype=float)
    return np.sum(relevances / np.log2(np.arange(1, len(relevances) + 1) + 1))


def ndcg_at_k(rerank_list, Search_quote_rel, index, k):
    relevances = get_relevances(rerank_list, Search_quote_rel, index)
    relevances_k = relevances[:k]
    dcg_value = dcg(relevances_k)
    # Ideal DCG: the same relevances sorted best-first, truncated at k.
    idcg_value = dcg(sorted(relevances, reverse=True)[:k])
    return dcg_value / idcg_value if idcg_value > 0 else 0
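A toy example of the nDCG computation; the relevance dict and quotes are made up.

```
# Hypothetical data: per-query graded relevance for each candidate quote.
Search_quote_rel = {0: {"quote A": 3, "quote B": 1}}
ranking = ["quote B", "quote A", "quote C"]  # "quote C" is unjudged -> relevance 0
print(ndcg_at_k(ranking, Search_quote_rel, index=0, k=3))  # ~0.80: the best quote sits at rank 2
```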
