[Minor] Align MMMU evaluation method with official #966

Open: wants to merge 3 commits into main
59 changes: 59 additions & 0 deletions vlmeval/dataset/image_mcq.py
@@ -2,6 +2,7 @@

from .image_base import ImageBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..utils import track_progress_rich
from ..smp import *
import pandas as pd

@@ -348,6 +349,64 @@ def build_prompt(self, line):
        msgs = self.split_MMMU(msgs)
        return msgs

    def evaluate(self, eval_file, **judge_kwargs):
        from .utils.multiple_choice import (
            mmmu_evaluation, report_acc
        )
        nproc = judge_kwargs.pop('nproc', 4)
        suffix = eval_file.split('.')[-1]
        model = judge_kwargs.get('model', 'exact_matching')
        assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
        name_str_map = {'chatgpt-0125': 'openai', 'gpt-4-0125': 'gpt4'}
        name_str = name_str_map[model] if model in name_str_map else model
        result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.{suffix}')
        score_file = eval_file.replace(f'.{suffix}', '_acc.csv')
        tmp_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')

        if model == 'exact_matching':
            model = None
        elif gpt_key_set():
            model = build_judge(**judge_kwargs)
            if not model.working():
                warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                warnings.warn(DEBUG_MESSAGE)
                model = None
        else:
            warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
            model = None

        data = load(eval_file)
        lt = len(data)
        lines = [data.iloc[i] for i in range(lt)]
        tups = [(model, line, self.dataset_name) for line in lines]
        indices = [line['index'] for line in lines]

        # Resume from cached per-sample results if a previous run was interrupted
        ans = {}
        if osp.exists(tmp_file):
            ans = load(tmp_file)
            tups = [x for x, i in zip(tups, indices) if i not in ans]
            indices = [i for i in indices if i not in ans]

        if len(indices):
            _ = track_progress_rich(
                mmmu_evaluation,
                tups,
                nproc=nproc,
                chunksize=nproc,
                keys=indices,
                save=tmp_file,
            )
            ans = load(tmp_file)
        for key, value in ans.items():
            data.loc[data['index'] == key, 'hit'] = value['hit']
            data.loc[data['index'] == key, 'log'] = value['log']
        dump(data, result_file)

        acc = report_acc(data)

        dump(acc, score_file)
        return acc
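For orientation, a minimal driver sketch (my own, not part of the diff; the dataset name and predictions path are hypothetical) showing how this evaluate() entry point is typically invoked:

    # Hypothetical usage; assumes an xlsx of predictions from a prior inference run.
    # 'exact_matching' skips the LLM judge entirely.
    dataset = MMMUDataset(dataset='MMMU_DEV_VAL')
    acc = dataset.evaluate('outputs/MMMU_DEV_VAL_my_model.xlsx', model='exact_matching', nproc=4)
    print(acc)  # the same table is also dumped to the *_acc.csv score file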


class MMMUProDataset(MMMUDataset):

149 changes: 149 additions & 0 deletions vlmeval/dataset/utils/multiple_choice.py
@@ -624,3 +624,152 @@ def get_dimension_rating(data_path):
            results[task]['Avg'] = acc_task
    results['Overall'] = succ_all / sum_all
    return results


def extract_numbers(string):
    """
    Extract all forms of numbers from a string with regex.
    """
    # Pattern for numbers with commas
    pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
    # Pattern for scientific notation
    pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
    # Pattern for simple numbers without commas
    pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'

    # Extract numbers with commas
    numbers_with_commas = re.findall(pattern_commas, string)
    # Extract numbers in scientific notation
    numbers_scientific = re.findall(pattern_scientific, string)
    # Extract simple numbers without commas
    numbers_simple = re.findall(pattern_simple, string)

    # Combine all extracted numbers
    all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
    return all_numbers
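A quick illustration of the combined patterns (my own example, not in the PR):

    # Illustrative only: plain floats and integers fall through to pattern_simple.
    assert extract_numbers('speed was 3.5 then 42') == ['3.5', '42']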


def parse_open_response(response):
    """
    Parse the prediction from the generated response.
    Return a list of predicted strings or numbers.
    """
    def get_key_subresponses(response):
        response = response.strip().strip(".").lower()
        sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
        indicators_of_keys = ['could be ', 'so ', 'is ', 'thus ', 'therefore ', 'final ', 'answer ', 'result ']
        key_responses = []
        for index, resp in enumerate(sub_responses):
            # for the last sub-response, also accept an equation (the entire
            # response can be a single sentence with an equation)
            if index == len(sub_responses) - 1:
                indicators_of_keys.extend(['='])
            # the shortest response that may contain the answer (tail part of the response)
            shortest_key_response = None
            for indicator in indicators_of_keys:
                if indicator in resp:
                    if not shortest_key_response:
                        shortest_key_response = resp.split(indicator)[-1].strip()
                    else:
                        if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
                            shortest_key_response = resp.split(indicator)[-1].strip()

            if shortest_key_response:
                # keep it only if it is not trivial punctuation
                if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", "'"]:
                    key_responses.append(shortest_key_response)
        if len(key_responses) == 0:  # did not find any indicator
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)

    pred_list = key_responses.copy()  # keep the original string responses
    for resp in key_responses:
        pred_list.extend(extract_numbers(resp))

    tmp_pred_list = []
    for i in range(len(pred_list)):
        tmp_pred_list.extend(normalize_str(pred_list[i]))
    pred_list = tmp_pred_list

    # remove duplicates
    pred_list = list(set(pred_list))

    return pred_list
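One concrete trace (my own, not in the PR): with the default indicators, the tail after 'is ' survives as the shortest key response and is then normalized to a float:

    # Illustrative only: numeric strings come back as rounded floats,
    # so '42' and '42.0' compare equal downstream.
    assert parse_open_response('Therefore, the answer is 42.') == [42.0]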


def check_is_number(string):
    """
    Check if the given string is a number (commas between digits are allowed).
    """
    try:
        float(string.replace(',', ''))
        return True
    except ValueError:
        return False
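Two quick checks (my own examples): commas between digits are tolerated, malformed numerals are not:

    assert check_is_number('1,234') is True
    assert check_is_number('3.4.5') is False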


def normalize_str(string):
    """
    Normalize the str to lower case and make them float numbers if possible.
    """
    string = string.strip()

    is_number = check_is_number(string)

    if is_number:
        # if it is a number, convert it to float
        string = string.replace(',', '')
        string = float(string)
        # keep two decimal places
        string = round(string, 2)
        return [string]
    else:  # it's likely to be a string
        # lower it
        string = string.lower()
        if len(string) == 1:
            # pad single characters to avoid trivial substring matches
            return [" " + string, string + " "]
        return [string]
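Sample behavior (my own examples, not in the PR):

    assert normalize_str('1,234') == [1234.0]   # numbers become rounded floats
    assert normalize_str('B') == [' b', 'b ']   # single chars padded to avoid trivial matches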


def mmmu_evaluation(model, line, dataset_name):
    if 'question_type' in line and line['question_type'] == 'open':
        hit = 0
        match_log = 'Failed to match'
        # normalize answers; numbers become floats to avoid trivial substring matches
        if isinstance(line['answer'], list):
            norm_answers = []
            for answer in line['answer']:
                norm_answers.extend(normalize_str(answer))
        else:
            norm_answers = normalize_str(line['answer'])
        parsed_pred = parse_open_response(line['prediction'])
        for pred in parsed_pred:  # preds are already normalized in the parse phase
            if isinstance(pred, str):  # string pred: check whether an answer is a substring of it
                for norm_ans in norm_answers:
                    if isinstance(norm_ans, str) and norm_ans in pred:
                        if not hit:
                            hit = 1
                            match_log = 'answer in pred'
                        break
            else:  # float pred: check exact membership in the normalized answers
                if pred in norm_answers:
                    if not hit:
                        hit = 1
                        match_log = 'pred is float, hit the answer'
                    break
        return dict(hit=hit, log=f'Match Log: {match_log}. ')
    else:
        res = extract_answer_from_item(model, line, dataset_name=dataset_name)
        opt, match_log = res['opt'], res['log']
        if opt == line['answer']:
            return dict(hit=1, log=f'Match Log: {match_log}. ')
        else:
            return dict(hit=0, log=f'Match Log: {match_log}. ')
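Finally, a hypothetical smoke test (mine, not part of the diff) scoring one open-ended row with exact matching, i.e. model=None:

    import pandas as pd
    row = pd.Series({'question_type': 'open', 'answer': '42',
                     'prediction': 'Therefore, the answer is 42.'})
    print(mmmu_evaluation(None, row, 'MMMU_DEV_VAL'))
    # -> {'hit': 1, 'log': 'Match Log: pred is float, hit the answer. '}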