Skip to content

Commit e9df441

Browse files
committed
fix flake8 errors introduced by me
1 parent 496aec6 commit e9df441

File tree

3 files changed

+98
-47
lines changed

3 files changed

+98
-47
lines changed

run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def main():
393393
judge_kwargs['model'] = 'llama31-8b'
394394
elif listinstr(['VideoMMLU_QA', 'VideoMMLU_CAP'], dataset_name):
395395
judge_kwargs['model'] = 'qwen-72b'
396-
elif listinstr(['CAPTURE_real, CAPTURE_synthetic'], dataset_name):
396+
elif listinstr(['CAPTURE_real', 'CAPTURE_synthetic'], dataset_name):
397397
judge_kwargs['model'] = 'llama31-8b'
398398

399399
if RANK == 0:

vlmeval/dataset/image_vqa.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,12 +1940,12 @@ def build_prompt(self, line):
19401940

19411941

19421942
class CAPTURE(ImageBaseDataset):
1943-
TYPE = 'VQA'
1944-
DATASET_URL = {'CAPTURE_real': '',
1945-
'CAPTURE_synthetic': ''}
1946-
DATASET_MD5 = {'CAPTURE_real': '',
1943+
TYPE = ''
1944+
DATASET_URL = {'CAPTURE_real': '',
19471945
'CAPTURE_synthetic': ''}
1948-
1946+
DATASET_MD5 = {'CAPTURE_real': None,
1947+
'CAPTURE_synthetic': None}
1948+
19491949
def create_tsv_from_hf(self):
19501950
pass
19511951

@@ -1958,17 +1958,19 @@ def evaluate(self, eval_file, **judge_kwargs):
19581958
record_file = eval_file.replace(suffix, f'_{model}.{suffix}')
19591959
score_file = eval_file.replace(suffix, '_score.csv')
19601960
nproc = judge_kwargs.pop('nproc', 4)
1961-
system_prompt = "You are an answer extractor. When given someone's answer to some question, you will only extract their final number answer and will respond with just the number. If there is no exact number answer, respond with -1"
1962-
1961+
system_prompt = (
1962+
"You are an answer extractor. When given someone's answer to "
1963+
"some question, you will only extract their final number answer "
1964+
"and will respond with just the number. If there is no exact "
1965+
"number answer, respond with -1"
1966+
)
19631967
if not osp.exists(record_file):
19641968
data = load(eval_file)
19651969
model = build_judge(**judge_kwargs, system_prompt=system_prompt)
1966-
assert model.working(), ('CAPTURE evaluation requires a working {model}\n' + DEBUG_MESSAGE)
19671970
lt = len(data)
19681971
lines = [data.iloc[i] for i in range(lt)]
19691972
tups = [(model, line) for line in lines]
19701973

1971-
19721974
extracted_answers = track_progress_rich(
19731975
CAPTURE_atomeval,
19741976
tups,
@@ -1978,6 +1980,7 @@ def evaluate(self, eval_file, **judge_kwargs):
19781980
data['extracted_answer'] = extracted_answers
19791981
dump(data, record_file)
19801982

1981-
score = CAPTURE_smape(record_file)
1983+
data = load(record_file)
1984+
score = CAPTURE_smape(data)
19821985
dump(score, score_file)
1983-
return score
1986+
return score

vlmeval/dataset/utils/capture.py

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,90 +2,138 @@
22
import zipfile
33
import os
44
import json
5-
5+
import tqdm
66
from ...smp import *
77

8+
89
def create_csv_from_meta(meta_file, object_key, data_dir, out_file):
    """Build a VLMEvalKit-style TSV from a CAPTURe metadata JSON file.

    Parameters
    ----------
    meta_file : str
        Path to the metadata JSON: a list of entries, each with
        ``image_file``, ``ground_truth`` and the object-name key.
    object_key : str
        Key in each entry holding the name of the thing to count
        (``"object"`` for the real split, ``"dot_shape"`` for synthetic).
    data_dir : str
        Directory containing the extracted image files.
    out_file : str
        Destination path; written tab-separated despite the function name.
    """
    with open(meta_file, "r") as fp:
        meta = json.load(fp)

    data = []
    # BUG FIX: the module is imported with ``import tqdm``, so the progress
    # wrapper must be referenced as ``tqdm.tqdm``; calling the module object
    # directly raises ``TypeError: 'module' object is not callable``.
    for entry in tqdm.tqdm(meta):
        image_file = entry["image_file"]
        image_path = osp.join(data_dir, image_file)
        image = encode_image_file_to_base64(image_path)
        object_name = entry[object_key]
        question = (
            f"Count the exact number of {object_name} in the image. "
            f"Assume the pattern of {object_name} continues behind any "
            f"black box. Provide the total number of {object_name} as if "
            f"the black box were not there. Only count {object_name} that "
            f"are visible within the frame (or would be visible without "
            f"the occluding box). If {object_name} are partially in the "
            f"frame (i.e. if any part of {object_name} are visible), "
            f"count it. If the {object_name} would be partially in the "
            f"frame without the occluding box, count it."
        )
        answer = str(entry["ground_truth"])
        data.append(
            dict(
                image=image,
                question=question,
                answer=answer,
                image_file=image_file,
            )
        )
    # Sort by file name so the generated TSV is stable across regenerations.
    df = pd.DataFrame(data).sort_values(by="image_file")
    df.to_csv(out_file, index=True, index_label="index", sep="\t")
41+
2242

2343
def create_tsv_real():
    """Download and unpack the CAPTURe real split, then emit its TSV.

    Fetches the image archive and metadata from the ``atinp/CAPTURe``
    HuggingFace dataset repo, extracts the images under
    ``<root>/capture/real_dataset`` and writes ``CAPTURE_real.tsv``
    into the LMU data root.

    Returns
    -------
    str
        Path of the generated TSV file.
    """
    data_root = LMUDataRoot()
    data_dir = osp.join(data_root, "capture")
    os.makedirs(data_root, exist_ok=True)

    archive_path = hf_hub_download(
        repo_id="atinp/CAPTURe",
        filename="real_dataset.zip",
        repo_type="dataset",
    )
    with zipfile.ZipFile(archive_path, "r") as zf:
        zf.extractall(data_dir)
    # The archive's top-level folder is called "dataset"; rename it so the
    # real and synthetic splits can coexist under the same parent directory.
    os.rename(f"{data_dir}/dataset", f"{data_dir}/real_dataset")

    meta_path = hf_hub_download(
        repo_id="atinp/CAPTURe",
        filename="real_metadata.json",
        repo_type="dataset",
    )
    out_file = os.path.join(data_root, "CAPTURE_real.tsv")
    create_csv_from_meta(
        meta_path, "object", f"{data_dir}/real_dataset", out_file
    )
    return out_file
38-
68+
69+
3970
def create_tsv_synthetic():
    """Download and unpack the CAPTURe synthetic split, then emit its TSV.

    Mirrors :func:`create_tsv_real` for the synthetic images: the archive
    already extracts to ``synthetic_dataset``, so no rename is needed.

    Returns
    -------
    str
        Path of the generated TSV file.
    """
    data_root = LMUDataRoot()
    data_dir = osp.join(data_root, "capture")
    os.makedirs(data_root, exist_ok=True)

    archive_path = hf_hub_download(
        repo_id="atinp/CAPTURe",
        filename="synthetic_dataset.zip",
        repo_type="dataset",
    )
    with zipfile.ZipFile(archive_path, "r") as zf:
        zf.extractall(data_dir)

    meta_path = hf_hub_download(
        repo_id="atinp/CAPTURe",
        filename="synthetic_metadata.json",
        repo_type="dataset",
    )
    out_file = os.path.join(data_root, "CAPTURE_synthetic.tsv")
    create_csv_from_meta(
        meta_path, "dot_shape", f"{data_dir}/synthetic_dataset", out_file
    )
    return out_file
5293

94+
5395
def safe_string_to_int(s):
    """Parse *s* as an integer, returning -1 when it cannot be parsed.

    The judge model is prompted to reply with a bare number; any reply
    that does not parse maps to the sentinel -1, which downstream scoring
    treats as "no answer extracted".
    """
    try:
        return int(s)
    except (ValueError, TypeError):
        # TypeError covers non-numeric, non-string inputs (e.g. None),
        # which would otherwise crash the scoring pass.
        return -1
58-
100+
101+
59102
def CAPTURE_atomeval(model, line):
    """Extract one numeric answer from a single prediction row.

    Feeds the raw model prediction to the judge and coerces the judge's
    reply to an int, falling back to -1 when no number can be parsed.
    """
    reply = model.generate_str(line["prediction"])
    return safe_string_to_int(reply)
62105

106+
63107
def CAPTURE_smape(data):
    """Score extracted answers with symmetric MAPE (sMAPE).

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain an ``answer`` column (ground-truth count, parseable
        as int) and an ``extracted_answer`` column (int; -1 means the
        judge could not extract a number).

    Returns
    -------
    pandas.DataFrame
        Single-row frame with ``SMAPE`` (mean symmetric percentage error
        over all rows, lower is better) and ``skip`` (number of rows with
        no extractable answer, each penalized at 100%).
    """
    total_percentage_error = 0
    count = 0
    skip = 0

    for i in range(len(data)):
        row = data.iloc[i]
        ground_truth = int(row["answer"])
        answer = row["extracted_answer"]

        # No number could be extracted: count the row as a full 100% miss.
        if answer == -1:
            skip += 1
            total_percentage_error += 100
            count += 1
            continue

        # Compute sMAPE (Symmetric Mean Absolute Percentage Error).
        numerator = abs(answer - ground_truth)
        denominator = abs(answer) + abs(ground_truth)
        # BUG FIX: when both prediction and ground truth are 0 the
        # denominator is 0; sMAPE conventionally scores this exact match
        # as 0% error instead of raising ZeroDivisionError.
        smape = (numerator / denominator) * 100 if denominator else 0

        # Add to total percentage error
        total_percentage_error += smape
        count += 1

    # Mean over all rows (0 when the frame is empty).
    mape = total_percentage_error / count if count != 0 else 0
    return pd.DataFrame([dict(SMAPE=mape, skip=skip)])
135+
136+
137+
if __name__ == "__main__":
    # Regenerate both CAPTURe TSVs when this module is run as a script.
    create_tsv_real()
    create_tsv_synthetic()

0 commit comments

Comments
 (0)