
Commit 76d6258

[Model] Add support for varco-vision-2.0-14b (#1177)
* add varco vision
* lint
1 parent a2c1913 commit 76d6258

File tree

3 files changed: 277 additions & 1 deletion

vlmeval/config.py
vlmeval/vlm/__init__.py
vlmeval/vlm/varco_vision.py


vlmeval/config.py

Lines changed: 7 additions & 1 deletion
@@ -720,9 +720,15 @@
     "llava_video_qwen2_72b": partial(
         LLaVA_OneVision, model_path="lmms-lab/LLaVA-Video-72B-Qwen2"
     ),
+}
+
+varco_vision_series = {
     "varco-vision-hf": partial(
         LLaVA_OneVision_HF, model_path="NCSOFT/VARCO-VISION-14B-HF"
     ),
+    "varco-vision-2-14b": partial(
+        VarcoVision, model_path="NCSOFT/VARCO-VISION-2.0-14B"
+    ),
 }
 
 vita_series = {
@@ -1528,7 +1534,7 @@
     aria_series, smolvlm_series, sail_series, valley_series, vita_series,
     ross_series, emu_series, ola_series, ursa_series, gemma_series,
     long_vita_series, ristretto_series, kimi_series, aguvis_series, hawkvl_series,
-    flash_vl, kimi_vllm_series, oryx_series, treevgr_series
+    flash_vl, kimi_vllm_series, oryx_series, treevgr_series, varco_vision_series
 ]
 
 for grp in model_groups:
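
For context (an editorial sketch, not part of the diff): each registry entry is a functools.partial constructor, and, assuming the usual VLMEvalKit pattern in which the loop above merges every series dict into a supported_VLM registry, the new name resolves roughly like this:

from vlmeval.config import supported_VLM  # assumption: registry populated from model_groups

# Calling the registered partial builds the wrapper lazily with its preset checkpoint.
model = supported_VLM['varco-vision-2-14b']()  # -> VarcoVision(model_path="NCSOFT/VARCO-VISION-2.0-14B")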

vlmeval/vlm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -102,3 +102,4 @@
 from .oryx import Oryx
 from .treevgr import TreeVGR
 from .glm4_1v import GLM4_1v
+from .varco_vision import VarcoVision
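
Assuming the usual wildcard import of vlmeval.vlm at the top of config.py, this re-export is what lets the registry entry above refer to VarcoVision by name; it also makes the class importable directly:

from vlmeval.vlm import VarcoVision  # resolves via the line added above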

vlmeval/vlm/varco_vision.py

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
import pandas as pd
import string
import torch
from PIL import Image
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE, DATASET_MODALITY


class VarcoVision(BaseModel):
    INSTALL_REQ = True
    INTERLEAVE = True
    VIDEO_LLM = True
    DEFAULT_IMAGE_TOKEN = "<image>"
    IMAGE_TOKEN_INDEX = -200

    def __init__(self, model_path="NCSOFT/VARCO-VISION-2.0-14B", **kwargs):
        from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
        assert model_path is not None, "Model path must be provided."
        self.model = LlavaOnevisionForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
        ).to('cuda')
        self.processor = AutoProcessor.from_pretrained(model_path)

        id_prompt = "You are VARCO-VISION, created by NC AI. "
        self.processor.chat_template = self.processor.chat_template.replace(id_prompt, "")
        self.processor.tokenizer.chat_template = self.processor.tokenizer.chat_template.replace(id_prompt, "")

        self.video_kwargs = kwargs.get("video_kwargs", {})
        self.force_sample = self.video_kwargs.get("force_sample", False)
        self.nframe = kwargs.get("nframe", 8)
        self.fps = 1
        self.model_path = model_path

    def set_ratio(self, n):
        config = self.model.config
        processor = self.processor
        processor.vision_aspect_ratio = config.vision_aspect_ratio = f"anyres_max_{n}"

    def set_grid(self, n, reduced=False):
        config = self.model.config
        image_processor = self.processor.image_processor
        size = min(image_processor.size.values())
        grid = []
        for i in range(1, n + 1):
            for j in range(1, n + 1):
                if reduced:
                    if i * j <= n and i != n and j != n:
                        grid.append([i * size, j * size])
                else:
                    grid.append([i * size, j * size])
        image_processor.image_grid_pinpoints = config.image_grid_pinpoints = grid

    def set_res(self, dataset):
        res_4_datasets = [
            'ChartQA_TEST', 'MMMU_DEV_VAL', 'MMMU_TEST',
            'MME-RealWorld', 'VCR_EN', 'VCR_ZH', 'OCRVQA',
            'BMMR', 'MMStar', 'HallusionBench', 'MMVet',
            'AI2D_MINI', 'AI2D_TEST', 'AI2D_TEST_NO_MASK']
        res_16_datasets = [
            'InfoVQA_VAL', 'InfoVQA_TEST', 'OCRBench',
            'HRBench4K', 'HRBench8K', 'MathVista', 'LLaVABench']
        self.set_ratio(9)
        self.set_grid(6)
        if listinstr(res_4_datasets, dataset):
            self.set_ratio(4)
            self.set_grid(4, reduced=True)
        elif listinstr(res_16_datasets, dataset):
            self.set_ratio(16)
            self.set_grid(8)

    def use_custom_prompt(self, dataset):
        if DATASET_TYPE(dataset) == 'Y/N':
            return True
        if DATASET_TYPE(dataset) == 'MCQ':
            return True
        if DATASET_TYPE(dataset) == 'VQA' and not dataset.startswith('OCRBench'):
            return True
        return False

    def build_prompt(self, line, dataset=None):
        assert self.use_custom_prompt(dataset)
        assert isinstance(dataset, str)
        tgt_path = self.dump_image(line, dataset)

        if dataset.startswith('MathVista_'):
            prompt = self.build_mathvista_prompt(line, dataset)
        elif dataset.startswith('MMMU_'):
            prompt = self.build_mmmu_prompt(line, dataset)
        elif DATASET_TYPE(dataset) == 'Y/N':
            prompt = self.build_yorn_prompt(line, dataset)
        elif DATASET_TYPE(dataset) == 'MCQ':
            prompt = self.build_multi_choice_prompt(line, dataset)
        elif DATASET_TYPE(dataset) == 'VQA':
            prompt = self.build_vqa_prompt(line, dataset)
        else:
            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
        message = []
        message.extend([dict(type='image', value=s) for s in tgt_path])
        message.append(dict(type='text', value=prompt))

        # interleave dataset
        if dataset.startswith('MMMU_'):
            from .. import MMMUDataset
            message = MMMUDataset.split_MMMU(message)

        return message

    def build_yorn_prompt(self, line, dataset=None):
        prompt = line['question']
        prompt += '\nAnswer the question using a single word or phrase.'
        return prompt

    def build_multi_choice_prompt(self, line, dataset=None):
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }

        hint = ''
        if 'hint' in line and not pd.isna(line['hint']):
            hint = f"{line['hint']}\n"
        elif options:
            hint = 'Make sure your answer is in the given choice list.\n'

        prompt = f"{hint}{line['question']}"
        if options:
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'\n{key}. {item}'
            prompt += f"{options_prompt}\nAnswer with the option's letter directly."
        else:
            prompt += '\nAnswer the question directly.'
        return prompt

    def build_mathvista_prompt(self, line, dataset=None):
        prompt = line['question']
        if 'Choices:' in prompt:
            for i in range(1, 7):
                prompt = prompt.replace(f'({chr(64 + i)})', f'{chr(64 + i)}.')
        else:
            prompt += '\nAnswer the question directly.'
        return prompt

    def build_mmmu_prompt(self, line, dataset=None):
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }

        hint = ''
        if 'hint' in line and not pd.isna(line['hint']):
            hint = f"Hint: {line['hint']}\n"

        prompt = f"{hint}Question: {line['question']}"
        if options:
            options_prompt = '\nOptions:'
            for key, item in options.items():
                options_prompt += f'\n{key}. {item}'
            prompt += f'{options_prompt}\nAnswer the preceding question.'
        else:
            prompt += ' Preserve details.'
        return prompt

    def build_vqa_prompt(self, line, dataset=None):
        prompt = line['question']
        prompt += ' Preserve details.'
        return prompt

    def generate_inner_image(self, message, dataset=None):
        content, images = "", []
        image_sizes = []

        for msg in message:
            if msg["type"] == "text":
                content += msg["value"]
            elif msg["type"] == "image":
                img = Image.open(msg["value"]).convert("RGB")
                images.append(img)
                image_sizes.append(img.size)
                content += f"{self.DEFAULT_IMAGE_TOKEN}\n"

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": content},
                ],
            }
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = self.processor(images=images, text=prompt, return_tensors="pt").to('cuda', torch.float16)

        output = self.model.generate(**inputs, max_new_tokens=512)
        return self.processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    def generate_inner_video(self, message, dataset=None):
        content, text_content, visual_content, videos = "", "", "", []

        for msg in message:
            if msg["type"] == "text":
                text_content += msg["value"]
            elif msg["type"] == "video":
                videos.append(msg["value"])
                visual_content += f"{self.DEFAULT_IMAGE_TOKEN}\n"

        if len(videos) > 1:
            raise ValueError("LLaVA-OneVision does not support multiple videos as input.")

        video_frames, frame_time, video_time = self.load_video(
            videos[0], self.nframe, fps=1, force_sample=self.force_sample
        )

        time_instruction = (
            f"The video lasts for {video_time:.2f} seconds, "
            f"and {len(video_frames)} frames are uniformly sampled from it. "
            f"These frames are located at {frame_time}. "
            f"Please answer the following questions related to this video.\n"
        )

        content = visual_content + time_instruction + text_content
        conversation = [
            {
                "role": "user",
                "content": [{"type": "text", "text": content}, {"type": "video"}],
            }
        ]
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

        inputs = self.processor(videos=video_frames, text=prompt, return_tensors="pt").to('cuda', torch.float16)
        output = self.model.generate(**inputs, max_new_tokens=512)
        return self.processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    def load_video(self, video_path, max_frames_num, fps=1, force_sample=False):
        from decord import VideoReader, cpu
        import numpy as np

        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        avg_fps = vr.get_avg_fps()

        if avg_fps == 0:
            raise ValueError(f"Video '{video_path}' has an average FPS of 0, which is invalid.")
        if fps <= 0:
            raise ValueError("FPS argument must be greater than 0.")

        effective_fps = round(avg_fps / fps)
        frame_idx = list(range(0, total_frame_num, effective_fps))
        frame_time = [i / avg_fps for i in frame_idx]

        if len(frame_idx) > max_frames_num or force_sample:
            uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
            frame_time = [i / avg_fps for i in frame_idx]

        frame_time_str = ", ".join([f"{t:.2f}s" for t in frame_time])
        video_frames = vr.get_batch(frame_idx).asnumpy()
        video_time = total_frame_num / avg_fps

        return video_frames, frame_time_str, video_time

    def generate_inner(self, message, dataset=None):
        self.set_res(dataset)
        if DATASET_MODALITY(dataset) == "VIDEO" and "megabench" not in dataset.lower():
            return self.generate_inner_video(message, dataset)
        else:
            return self.generate_inner_image(message, dataset)
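
For reference, a minimal editorial sketch (not part of the commit) of driving the class above directly; the file paths and the video dataset name are placeholders, and generate() is assumed to be the BaseModel entry point that dispatches to generate_inner():

from vlmeval.vlm import VarcoVision

model = VarcoVision()  # defaults to NCSOFT/VARCO-VISION-2.0-14B

# generate_inner() first calls set_res(), which picks the anyres ratio and tiling
# grid per dataset; ChartQA_TEST, for example, is listed in res_4_datasets above.
msgs = [
    dict(type='image', value='chart.png'),                       # placeholder path
    dict(type='text', value='What is the highest value shown?'),
]
print(model.generate(message=msgs, dataset='ChartQA_TEST'))

# Video benchmarks route through generate_inner_video(), which samples self.nframe
# frames via load_video() and prepends a timing instruction to the prompt.
video_msgs = [
    dict(type='video', value='clip.mp4'),                        # placeholder path
    dict(type='text', value='What happens in the video?'),
]
print(model.generate(message=video_msgs, dataset='Video-MME'))   # placeholder dataset name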
