
Commit e5548b7

Better error message for lm_eval script (#444)
* Better error message
* change error message and make output prettier
1 parent 77c9304 commit e5548b7

1 file changed: +29 -7 lines changed


scripts/hf_eval.py

Lines changed: 29 additions & 7 deletions
@@ -1,10 +1,18 @@
 import torch
+from tabulate import tabulate
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from lm_eval.models.huggingface import HFLM
-from lm_eval.evaluator import evaluate
-from lm_eval.tasks import get_task_dict
+try:
+    from lm_eval.models.huggingface import HFLM
+    from lm_eval.evaluator import evaluate
+    from lm_eval.tasks import get_task_dict
+except ImportError as e:
+    print("""
+Error: The 'lm_eval' module was not found.
+To install, follow these steps:
+pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git
+""")
+    raise  # Re-raise the ImportError
 
 from torchao.quantization.quant_api import (
     change_linear_weights_to_int4_woqtensors,
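For reference, the guarded-import pattern introduced in this hunk can be tried in isolation. The sketch below uses a made-up module name (not_a_real_module) rather than lm_eval, so running it simply prints the install hint and re-raises the ImportError, which is the behavior this commit adds when the dependency is missing.

# Minimal sketch of the guarded-import pattern from the hunk above.
# "not_a_real_module" and the pip package name are placeholders used only
# to trigger and illustrate the error path.
try:
    import not_a_real_module
except ImportError:
    print("""
Error: The 'not_a_real_module' module was not found.
To install, follow these steps:
    pip install not-a-real-module
""")
    raise  # Re-raise so the script still fails loudly after printing the hint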
@@ -16,6 +24,21 @@
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.fx_graph_cache = True
 
+def pretty_print_nested_results(results, precision: int = 6):
+    def format_value(value):
+        if isinstance(value, float):
+            return f"{value:.{precision}f}"
+        return value
+
+    main_table = []
+    for task, metrics in results["results"].items():
+        subtable = [[k, format_value(v)] for k, v in metrics.items() if k != 'alias']
+        subtable.sort(key=lambda x: x[0])  # Sort metrics alphabetically
+        formatted_subtable = tabulate(subtable, tablefmt='grid')
+        main_table.append([task, formatted_subtable])
+
+    print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))
+
 def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
 
     tokenizer = AutoTokenizer.from_pretrained(repo_id)
@@ -33,7 +56,6 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
         change_linear_weights_to_int4_woqtensors(model.to(device=device))
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
-
     with torch.no_grad():
         result = evaluate(
             HFLM(
@@ -44,8 +66,8 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
             get_task_dict(tasks),
             limit = limit,
         )
-        for task, res in result["results"].items():
-            print(f"{task}: {res}")
+
+        pretty_print_nested_results(result)
 
 
 if __name__ == '__main__':
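
To see the new table output without installing lm_eval or downloading a model, the helper can be exercised on its own with a small, made-up results dictionary; only the tabulate package is required. The task names, metric keys, and values below are illustrative placeholders, not real evaluation results; the dictionary only mimics the results["results"] shape the script iterates over.

from tabulate import tabulate

# Copy of the helper added in this commit (scripts/hf_eval.py).
def pretty_print_nested_results(results, precision: int = 6):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    main_table = []
    for task, metrics in results["results"].items():
        subtable = [[k, format_value(v)] for k, v in metrics.items() if k != 'alias']
        subtable.sort(key=lambda x: x[0])  # Sort metrics alphabetically
        formatted_subtable = tabulate(subtable, tablefmt='grid')
        main_table.append([task, formatted_subtable])

    print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

# Placeholder data shaped like the results dict the script consumes;
# the metric names and numbers are made up for illustration only.
fake_results = {
    "results": {
        "wikitext": {"alias": "wikitext", "word_perplexity": 12.345678, "byte_perplexity": 1.234567},
        "hellaswag": {"alias": "hellaswag", "acc": 0.5, "acc_norm": 0.6},
    }
}
pretty_print_nested_results(fake_results)

Running this prints one outer grid with a row per task and a nested grid of alphabetically sorted metrics in the second column, which is the "prettier" output the commit message refers to.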
