import argparse
+ import base64
import concurrent.futures
+ import hashlib
import json
import logging
import os
import re
import time
+ import uuid
from datetime import datetime
from functools import partial
from io import StringIO

import pandas as pd
import requests
+ import yaml
from datasets import Dataset, load_dataset
+ from tqdm import tqdm

- from khoj.utils.helpers import get_cost_of_chat_message, is_none_or_empty, timer
+ from khoj.utils.helpers import (
+     batcher,
+     get_cost_of_chat_message,
+     is_none_or_empty,
+     timer,
+ )

# Configure root logger
logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -61,6 +71,116 @@ def get(self):
running_total_count = Counter(0)


+ def get_article_filename(article: dict[str, str]) -> str:
+     """Create a unique filename for a Wikipedia article"""
+     # Construct filename from frames prompt ids associated with each article and its url
+     encoded_url = base64.urlsafe_b64encode(article["link"].encode()).decode()
+     return "-".join(map(str, article["frames_prompt_id"])) + f"_{encoded_url}.txt"
+
+
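+ # Filenames created above look like "<id1>-<id2>-..._<urlsafe-b64(link)>.txt"; the two helpers below invert that mapping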
+ def extract_prompt_ids_from_filename(filename: str) -> set[int]:
+     """Extract frames prompt ids from an indexed file name"""
+     return set(map(int, filename.split("_", 1)[0].split("-")))
+
+
+ def extract_article_url_from_filename(filename: str) -> str:
+     """Decode the article URL from an indexed file name"""
+     encoded_url = filename.split("_", 1)[1].rsplit(".", 1)[0]
+     return base64.urlsafe_b64decode(encoded_url).decode()
+
+
+ def get_articles_by_prompt_id(prompt_id: int):
+     """Get all Wikipedia articles relevant to a specific FRAMES prompt ID"""
+     try:
+         # Load dataset
+         dataset = load_dataset("parasail-ai/frames-benchmark-wikipedia")
+
+         # Filter function to check if prompt_id exists in sequence
+         def has_prompt_id(example):
+             return prompt_id in example["frames_prompt_id"]
+
+         # Filter dataset and return matching rows
+         filtered_dataset = dataset["train"].filter(has_prompt_id)
+         return filtered_dataset
+
+     except Exception as e:
+         logger.error(f"Error filtering dataset for prompt {prompt_id}: {e}")
+         return None
+
+
+ def load_frames_kb():
+     """
+     Load Wikipedia articles used as Knowledge Base by the FRAMES benchmark dataset from HuggingFace
+
+     FRAMES is a benchmark dataset to evaluate retrieval and answering capabilities of agents.
+     It contains ~800 prompts requiring multi-hop retrieval and reasoning across various topics from Wikipedia.
+
+     ### Data Fields
+     - link: The link to the Wikipedia article
+     - text: The text content of the Wikipedia article
+     - frames_prompt_id: The list of FRAMES prompt ids for which this article is relevant
+     """
+     try:
+         dataset_name = "parasail-ai/frames-benchmark-wikipedia"
+         dataset = load_dataset(dataset_name)
+         return dataset["train"]
+
+     except Exception as e:
+         logger.error(f"Error loading {dataset_name} dataset: {e}")
+         return None
+
+
+ def index_frames_kb():
+     """Index Wikipedia articles from FRAMES dataset into Khoj"""
+     try:
+         # Load dataset
+         dataset = load_frames_kb()
+         dataset_files = set(map(get_article_filename, dataset))
+
+         # Get indexed files from Khoj API
+         headers = {"Authorization": f"Bearer {KHOJ_API_KEY}"} if KHOJ_API_KEY else {}
+         try:
+             response = requests.get(f"{KHOJ_URL}/api/content/computer", headers=headers)
+             response.raise_for_status()
+             indexed_files = set(response.json())
+         except requests.exceptions.RequestException as e:
+             logger.error(f"Failed to get indexed files: {e}")
+             return False
+
+         # Find missing files to index
+         missing_files = dataset_files - indexed_files
+         filtered_dataset = [
+             article
+             for article in dataset
+             if get_article_filename(article) in missing_files and not is_none_or_empty(article["text"])
+         ]
+         if not filtered_dataset:
+             return True
+         logger.info(f"Found {len(filtered_dataset)} files to index")
+
+         # Process Wikipedia articles from FRAMES knowledge base in batches
+         batch_size = 300
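+         # Estimate batch count for the progress bar (over-counts by one when the dataset size is an exact multiple of batch_size)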
+         total_batches = len(filtered_dataset) // batch_size + 1
+         for batch in tqdm(batcher(filtered_dataset, batch_size), total=total_batches, desc="Indexing FRAMES KB"):
+             # Create files batch to index
+             files = []
+             for article in batch:
+                 filename = get_article_filename(article)
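+                 # Each entry follows requests' multipart upload format: (field, (filename, content, content_type))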
+                 files.append(("files", (filename, article["text"], "text/plaintext")))
+             # Send files batch to index
+             try:
+                 response = requests.patch(f"{KHOJ_URL}/api/content?client=eval", headers=headers, files=files)
+                 response.raise_for_status()
+                 time.sleep(SLEEP_SECONDS)  # Rate limiting
+             except Exception as e:
+                 logger.error(f"Failed to index batch: {e}")
+                 return False
+         return True
+     except Exception as e:
+         logger.error(f"Failed to index KB: {e}")
+         return False
+
+
def load_frames_dataset():
    """
    Load the Google FRAMES benchmark dataset from HuggingFace
@@ -248,6 +368,62 @@ def get_agent_response(prompt: str) -> Dict[str, Any]:
        return {"response": "", "usage": {}, "references": {}}


+ def calculate_precision_recall(numerator: int, denominator: int) -> float:
+     """Calculate a precision or recall ratio from its numerator and denominator"""
+     if numerator == 0 and denominator == 0:
+         return 1.0
+     elif numerator > 0 and denominator == 0:
+         return 0.0
+     else:
+         return numerator / denominator
+
+
+ def calculate_f1(precision: float, recall: float) -> float:
+     """Calculate F1 score from precision and recall"""
+     return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
+
+
+ def evaluate_response_for_ir(
+     query: str, agent_response: str, ground_truth: int, agent_references: dict = {}
+ ) -> tuple[float | None, str, float]:
+     """Evaluate the Wikipedia articles retrieved by Khoj against those expected for the benchmark prompt"""
+     try:
+         # Extract files referenced by the agent from its response context
+         referenced_files: list[dict[str, str]] = agent_references.get("context", [])
+         count_of_correct_articles_used_by_agent: int = 0
+         # Count how many of the expected articles the agent actually retrieved from the KB
+         unique_file_refs = {file["file"] for file in referenced_files}
+         referenced_articles = list(map(extract_article_url_from_filename, unique_file_refs))
+         for file in unique_file_refs:
+             frames_ids_for_articles_used_by_agent = extract_prompt_ids_from_filename(file)
+             count_of_correct_articles_used_by_agent += int(ground_truth in frames_ids_for_articles_used_by_agent)
+
+         articles = get_articles_by_prompt_id(ground_truth)
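+         # Precision is over unique files retrieved; recall is over articles expected for this prompt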
+         precision = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(unique_file_refs))
+         recall = calculate_precision_recall(count_of_correct_articles_used_by_agent, len(articles))
+         f1 = calculate_f1(precision, recall)
+
+         explanation = (
+             f"Information Retrieval F1 Score: {f1:.2%}, Recall: {recall:.2%}, Precision: {precision:.2%}.\n"
+             f"{count_of_correct_articles_used_by_agent} of {len(articles)} correct from {len(unique_file_refs)} total retrievals for {ground_truth}.\n"
+             f"Queries:\n{yaml.dump(sorted([r['query'] for r in referenced_files]))}\n"
+             f"Expected Articles for {ground_truth}:\n{yaml.dump(sorted([a['link'] for a in articles]))}\n"
+             f"Retrieved Articles for {ground_truth}:\n{yaml.dump(referenced_articles)}\n"
+         )
+
+         # Truncate referenced files for logging
+         truncated_refs = [
+             {k: v[:200] + "..." if len(v) > 200 else v for k, v in ref.items()} for ref in referenced_files
+         ]
+         logger.info(f"Retrieved Article Details:\n{yaml.dump(truncated_refs, sort_keys=False)}\n")
+
+         # Return recall as the decision score, with explanation and cost, in structured form
+         return recall, explanation, 0.0
+     except Exception as e:
+         logger.error(f"Error in IR evaluation: {e}")
+         return None, f"Evaluation failed: {str(e)}", 0.0
+
+
def evaluate_response_with_mcq_match(
    query: str, agent_response: str, ground_truth: str, agent_references: dict = {}
) -> tuple[bool | None, str, float]:
@@ -417,7 +593,7 @@ def parse_args():
        "--dataset",
        "-d",
        default="frames",
-         choices=["frames", "simpleqa", "gpqa", "math500"],
+         choices=["frames", "frames_ir", "simpleqa", "gpqa", "math500"],
        help="Dataset to use for evaluation (default: frames)",
    )
    return parser.parse_args()
@@ -438,6 +614,12 @@ def main():
            dataset = load_gpqa_dataset()
        elif args.dataset == "math500":
            dataset = load_math500_dataset()
+         elif args.dataset == "frames_ir":
+             indexed = index_frames_kb()
+             if indexed:
+                 dataset = load_frames_dataset()
+                 # Use the benchmark index field, 'Unnamed: 0', as the 'Answer' for IR evaluation
+                 dataset["Answer"] = dataset["Unnamed: 0"]
    if dataset is None:
        return

@@ -450,6 +632,8 @@ def main():
        response_evaluator = partial(
            evaluate_response_with_gemini, eval_model=os.getenv("GEMINI_EVAL_MODEL", "gemini-1.5-flash-002")
        )
+     elif args.dataset == "frames_ir":
+         response_evaluator = evaluate_response_for_ir
    else:
        response_evaluator = evaluate_response_with_gemini
