
Commit dc0bc5b

Evaluate information retrieval quality using eval script
- Encode article URLs in the filenames indexed in the Khoj KB. This makes it easier for humans to compare and trace retrieval performance from the logs than the previously explored content-hash naming did (see the sketch below).
1 parent daeba66 commit dc0bc5b
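For illustration, here is a minimal standalone sketch of the filename scheme this commit introduces (hypothetical article data; it mirrors the get_article_filename and extract_* helpers added in the diff below):

import base64

# A sample record shaped like the FRAMES KB rows (hypothetical data)
article = {"link": "https://en.wikipedia.org/wiki/Alan_Turing", "frames_prompt_id": [3, 41]}

# Encode as "<prompt-ids>_<urlsafe-base64(url)>.txt"
encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode()
filename = "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt"

# Decode: recover the prompt ids and the article URL from the filename alone
prompt_ids = set(map(int, filename.split("_", 1)[0].split("-")))
url = base64.urlsafe_b64decode(filename.split("_", 1)[1].rsplit(".", 1)[0]).decode()
assert prompt_ids == {3, 41} and url == article["link"]

Splitting on the first "_" is safe because the prompt-id prefix contains only digits and dashes, even though urlsafe base64 output can itself contain "_" characters.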

File tree

1 file changed: +186 -2 lines changed


tests/evals/eval.py

+186 -2
@@ -1,10 +1,13 @@
 import argparse
+import base64
 import concurrent.futures
+import hashlib
 import json
 import logging
 import os
 import re
 import time
+import uuid
 from datetime import datetime
 from functools import partial
 from io import StringIO
@@ -14,9 +17,16 @@

 import pandas as pd
 import requests
+import yaml
 from datasets import Dataset, load_dataset
+from tqdm import tqdm

-from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
+from khoj.utils.helpers import (
+    batcher,
+    get_cost_of_chat_message,
+    is_none_or_empty,
+    timer,
+)

 # Configure root logger
 logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -61,6 +71,116 @@ def get(self):
 running_total_count = Counter(0)


+def get_article_filename(article: dict[str, str]) -> str:
+    """Create a unique filename for a Wikipedia article"""
+    # Construct filename from the frames prompt ids associated with each article and its url
+    encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode()
+    return "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt"
+
+
+def extract_prompt_ids_from_filename(filename: str) -> set[int]:
+    """Extract frames prompt ids from an indexed file name"""
+    return set(map(int, filename.split("_", 1)[0].split("-")))
+
+
+def extract_article_url_from_filename(filename: str) -> str:
+    """Decode article URL from an indexed file name"""
+    encoded_url = filename.split("_", 1)[1].rsplit(".", 1)[0]
+    return base64.urlsafe_b64decode(encoded_url).decode()
+
+
+def get_articles_by_prompt_id(prompt_id: int):
+    """Get all Wikipedia articles relevant to a specific FRAMES prompt ID"""
+    try:
+        # Load dataset
+        dataset = load_dataset("parasail-ai/frames-benchmark-wikipedia")
+
+        # Filter function to check if prompt_id exists in sequence
+        def has_prompt_id(example):
+            return prompt_id in example["frames_prompt_id"]
+
+        # Filter dataset and return matching rows
+        filtered_dataset = dataset["train"].filter(has_prompt_id)
+        return filtered_dataset
+
+    except Exception as e:
+        logger.error(f"Error filtering dataset for prompt {prompt_id}: {e}")
+        return None
+
+
+def load_frames_kb():
+    """
+    Load the Wikipedia articles used as the knowledge base by the FRAMES benchmark dataset from HuggingFace
+
+    FRAMES is a benchmark dataset to evaluate the retrieval and answering capabilities of agents.
+    It contains ~800 prompts requiring multi-hop retrieval and reasoning across various topics from Wikipedia.
+
+    ### Data Fields
+    - link: The link to the Wikipedia article
+    - text: The text content of the Wikipedia article
+    - frames_prompt_id: The list of FRAMES prompt ids for which this article is relevant
+    """
+    try:
+        dataset_name = "parasail-ai/frames-benchmark-wikipedia"
+        dataset = load_dataset(dataset_name)
+        return dataset["train"]
+
+    except Exception as e:
+        logger.error(f"Error loading {dataset_name} dataset: {e}")
+        return None
+
+
+def index_frames_kb():
+    """Index Wikipedia articles from the FRAMES dataset into Khoj"""
+    try:
+        # Load dataset
+        dataset = load_frames_kb()
+        dataset_files = set(map(get_article_filename, dataset))
+
+        # Get indexed files from Khoj API
+        headers = {"Authorization": f"Bearer {KHOJ_API_KEY}"} if KHOJ_API_KEY else {}
+        try:
+            response = requests.get(f"{KHOJ_URL}/api/content/computer", headers=headers)
+            response.raise_for_status()
+            indexed_files = set(response.json())
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Failed to get indexed files: {e}")
+            return False
+
+        # Find missing files to index
+        missing_files = dataset_files - indexed_files
+        filtered_dataset = [
+            article
+            for article in dataset
+            if get_article_filename(article) in missing_files and not is_none_or_empty(article["text"])
+        ]
+        if not filtered_dataset:
+            return True
+        logger.info(f"Found {len(filtered_dataset)} files to index")
+
+        # Process Wikipedia articles from FRAMES knowledge base in batches
+        batch_size = 300
+        total_batches = len(filtered_dataset) // batch_size + 1
+        for batch in tqdm(batcher(filtered_dataset, batch_size), total=total_batches, desc="Indexing FRAMES KB"):
+            # Create files batch to index
+            files = []
+            for article in batch:
+                filename = get_article_filename(article)
+                files.append(("files", (filename, article["text"], "text/plaintext")))
+            # Send files batch to index
+            try:
+                response = requests.patch(f"{KHOJ_URL}/api/content?client=eval", headers=headers, files=files)
+                response.raise_for_status()
+                time.sleep(SLEEP_SECONDS)  # Rate limiting
+            except Exception as e:
+                logger.error(f"Failed to index batch: {e}")
+                return False
+        return True
+    except Exception as e:
+        logger.error(f"Failed to index KB: {e}")
+        return False
+
+
 def load_frames_dataset():
     """
     Load the Google FRAMES benchmark dataset from HuggingFace
@@ -248,6 +368,62 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
         return {"response": "", "usage": {}, "references": {}}


+def calculate_precision_recall(numerator: int, denominator: int) -> float:
+    """Calculate precision or recall from a numerator and denominator"""
+    if numerator == 0 and denominator == 0:
+        return 1.0
+    elif numerator > 0 and denominator == 0:
+        return 0.0
+    else:
+        return numerator / denominator
+
+
+def calculate_f1(precision: float, recall: float) -> float:
+    """Calculate F1 score from precision and recall"""
+    return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
+
+
+def evaluate_response_for_ir(
+    query: str, agent_response: str, ground_truth: int, agent_references: dict = {}
+) -> tuple[float | None, str, float]:
+    """Evaluate which of the expected articles Khoj retrieved for the benchmark prompt"""
+    try:
+        # Extract referenced files from the agent response
+        referenced_files: list[dict[str, str]] = agent_references.get("context", [])
+        count_of_correct_articles_used_by_agent: int = 0
+        # Count how many of the expected articles the agent actually retrieved from the KB
+        unique_file_refs = {file["file"] for file in referenced_files}
+        referenced_articles = list(map(extract_article_url_from_filename, unique_file_refs))
+        for file in unique_file_refs:
+            frames_ids_for_articles_used_by_agent = extract_prompt_ids_from_filename(file)
+            count_of_correct_articles_used_by_agent += int(ground_truth in frames_ids_for_articles_used_by_agent)
+
+        articles = get_articles_by_prompt_id(ground_truth)
+        precision = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(unique_file_refs))
+        recall = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(articles))
+        f1 = calculate_f1(precision, recall)
+
+        explanation = (
+            f"Information Retrieval F1 Score: {f1:.2%}, Recall: {recall:.2%}, Precision: {precision:.2%}.\n"
+            f"{count_of_correct_articles_used_by_agent} of {len(articles)} correct from {len(unique_file_refs)} total retrievals for {ground_truth}.\n"
+            f"Queries:\n{yaml.dump(sorted([r['query'] for r in referenced_files]))}\n"
+            f"Expected Articles for {ground_truth}:\n{yaml.dump(sorted([a['link'] for a in articles]))}\n"
+            f"Retrieved Articles for {ground_truth}:\n{yaml.dump(referenced_articles)}\n"
+        )
+
+        # Truncate referenced files for logging
+        truncated_refs = [
+            {k: v[:200] + "..." if len(v) > 200 else v for k, v in ref.items()} for ref in referenced_files
+        ]
+        logger.info(f"Retrieved Article Details:\n{yaml.dump(truncated_refs, sort_keys=False)}\n")
+
+        # Return recall score, explanation and cost in structured form
+        return recall, explanation, 0.0
+    except Exception as e:
+        logger.error(f"Error in IR evaluation: {e}")
+        return None, f"Evaluation failed: {str(e)}", 0.0
+
+
 def evaluate_response_with_mcq_match(
     query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
 ) -> tuple[bool | None, str, float]:
@@ -417,7 +593,7 @@ def parse_args():
         "--dataset",
         "-d",
         default="frames",
-        choices=["frames", "simpleqa", "gpqa", "math500"],
+        choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500"],
         help="Dataset to use for evaluation (default: frames)",
     )
     return parser.parse_args()
@@ -438,6 +614,12 @@ def main():
         dataset = load_gpqa_dataset()
     elif args.dataset == "math500":
         dataset = load_math500_dataset()
+    elif args.dataset == "frames_ir":
+        indexed = index_frames_kb()
+        if indexed:
+            dataset = load_frames_dataset()
+            # Copy the index field, 'Unnamed: 0', to 'Answer' for IR evaluation
+            dataset["Answer"] = dataset["Unnamed: 0"]
     if dataset is None:
         return

@@ -450,6 +632,8 @@
         response_evaluator = partial(
             evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
         )
+    elif args.dataset == "frames_ir":
+        response_evaluator = evaluate_response_for_ir
     else:
         response_evaluator = evaluate_response_with_gemini

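index_frames_kb above relies on khoj.utils.helpers.batcher to chunk the article list into groups of 300. The real helper lives in the Khoj repo; a minimal stand-in with the behavior the indexing loop depends on (an assumption for illustration, not the library's exact implementation) could look like:

from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def batcher(iterable: Iterable[T], max_n: int) -> Iterator[list[T]]:
    # Yield successive lists of up to max_n items until the iterable is exhausted
    it = iter(iterable)
    while batch := list(islice(it, max_n)):
        yield batch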

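And a small worked example of the retrieval metrics computed by calculate_precision_recall and calculate_f1 above (hypothetical counts):

# Suppose the agent made 4 unique retrievals, 2 of which are among the
# 3 articles relevant to the FRAMES prompt (hypothetical numbers)
correct, retrieved, expected = 2, 4, 3

precision = correct / retrieved  # 2/4 = 0.50
recall = correct / expected  # 2/3 ≈ 0.67
f1 = 2 * (precision * recall) / (precision + recall)  # ≈ 0.57

# Edge case handled by calculate_precision_recall: 0 correct out of
# 0 retrieved/expected scores 1.0 instead of dividing by zero.
# The IR evaluation is selected via: python tests/evals/eval.py -d frames_ir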