
Commit 7ff5c3c

args
1 parent cd009bd commit 7ff5c3c

1 file changed

src/training/vqa_fine_tune.py

Lines changed: 52 additions & 30 deletions
@@ -15,16 +15,17 @@
 import evaluate
 from sklearn import preprocessing
 import numpy as np
+import sys

 from datasets import load_dataset_builder
 from datasets import load_dataset

 class VQATextDataset(Dataset):
-    def __init__(self, df, split, transforms, labelencoder, tokenizer=None):
+    def __init__(self, df, split, transforms, answer_set, tokenizer=None):
         self.df = df
         self.transforms = transforms
         self.tokenize = tokenizer
-        self.labels = labelencoder.transform(df['multiple_choice_answer'])
+        self.num_classes = len(answer_set)
     def __len__(self):
         return len(self.df)

@@ -34,13 +35,20 @@ def __getitem__(self, idx):
         image = Image.open(str(img_path))
         text = item["question"]
-        label = self.labels[idx]
+        target = np.zeros(self.num_classes)
+        for i in range(len(item['answer_list'])):
+            target[item['answer_list'][i]] = item['answer_weights'][i]
+
         return {
             'image': self.transforms(image),
             'text': self.tokenize([text])[0],
-            'label': torch.tensor(label)
+            'target': torch.tensor(target)
         }

-def get_task_dataloaders(path, transforms, labelencoder, args):
+def get_score(count: int) -> float:
+    return min(1.0, count / 3)
+
+
+def get_task_dataloaders(path, transforms, labelencoder, answer_set, args):
     tokenizer = get_tokenizer(args.model)
     dataloaders = {}

@@ -52,29 +60,43 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
         questions = []
         images = []
         answers = []
+        weights = []
         for index, row in dataset_df.iterrows():
-            if(row['multiple_choice_answer'] in answer_set):
+            answer_count = {}
+            for answer in row['answers']:
+                answer_ = answer["answer"]
+                answer_count[answer_] = answer_count.get(answer_, 0) + 1
+            labels = []
+            scores = []
+            for answer in answer_count:
+                if answer not in answer_set:
+                    continue
+                labels.append(labelencoder.transform([answer])[0])
+                score = get_score(answer_count[answer])
+                scores.append(score)
+            if(len(labels) == 0):
+                continue
             class_id.append(row['question_id'])
             questions.append(row['question'])
             images.append(row['image'])
-            answers.append(row['multiple_choice_answer'])
+            answers.append(labels)
+            weights.append(scores)
+
         class_id = np.array(class_id)
         questions = np.array(questions)
         images = np.array(images)
-        answers = np.array(answers)
-
-        dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'multiple_choice_answer': answers})
+        dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'answer_list': answers, 'answer_weights': weights})
         #dataset_df = dataset_df[0:12800]
         b_size = args.batch_size
         if(split == "validation"):
             b_size = args.batch_size * 20
             dataset_df = dataset_df[0:12800]
-        dataset = VQATextDataset(dataset_df,
-                                 split,
-                                 transforms,
-                                 labelencoder,
-                                 tokenizer=tokenizer,
-                                 )
+        dataset = VQATextDataset(dataset_df,
+                                 split,
+                                 transforms,
+                                 answer_set,
+                                 tokenizer=tokenizer,
+                                 )
         dataloader = DataLoader(
             dataset,
             batch_size=b_size,
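
Note on the scoring scheme above: get_score follows the VQAv2 soft-accuracy rule. Each question ships with ten human answers, and an answer earns min(1, count/3), so agreement from three or more annotators gives full credit. The sketch below walks one row through the same counting, filtering, and scoring that the loader performs; the toy answers and the answer_to_index dict are illustrative stand-ins for the dataset and the script's LabelEncoder.

import numpy as np

def get_score(count: int) -> float:
    # VQAv2-style soft score: full credit once 3 of the 10 annotators agree.
    return min(1.0, count / 3)

# Hypothetical annotator answers for a single question.
row_answers = ["blue", "blue", "blue", "blue", "light blue", "blue", "navy",
               "blue", "light blue", "blue"]
answer_set = {"blue", "light blue"}             # answers kept in the vocabulary
answer_to_index = {"blue": 0, "light blue": 1}  # stand-in for labelencoder.transform

answer_count = {}
for a in row_answers:
    answer_count[a] = answer_count.get(a, 0) + 1    # {"blue": 7, "light blue": 2, "navy": 1}

labels, scores = [], []
for a, c in answer_count.items():
    if a not in answer_set:                 # "navy" is dropped, as in the loop above
        continue
    labels.append(answer_to_index[a])
    scores.append(get_score(c))             # "blue" -> 1.0, "light blue" -> ~0.67

# __getitem__ then scatters these pairs into a dense vector over the answer vocabulary.
target = np.zeros(len(answer_set))
target[labels] = scores
print(labels, scores, target)               # [0, 1] [1.0, 0.666...] [1.0 0.666...]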
@@ -95,7 +117,7 @@ def __init__(self, encoder, embed_dim, num_labels):

         self.fc1 = nn.Linear(embed_dim * 2, 1536) #size of answer space
         self.lnorm = nn.LayerNorm(1536)
-        self.fc2 = nn.Linear(1536, num_classes)
+        self.fc2 = nn.Linear(1536, num_labels)
     def forward(self, image, text):
         # CLIP doesn't have a multimodal encoder, so we concatenate the features
         text_features = self.encoder.encode_text(text)
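
For orientation: the head operates on the concatenation of CLIP's image and text embeddings, which is why fc1 takes embed_dim * 2 inputs, and the fc2 change simply makes the output width follow the num_labels argument instead of the undefined name num_classes. A shape-only sketch, assuming a 512-dimensional CLIP embedding and an illustrative 3000-answer vocabulary (the real width is len(answer_space), and the actual forward pass may apply activations not visible in this hunk):

import torch
import torch.nn as nn

embed_dim, num_labels, batch = 512, 3000, 4    # illustrative sizes only

fc1 = nn.Linear(embed_dim * 2, 1536)   # concatenated image+text features -> hidden
lnorm = nn.LayerNorm(1536)
fc2 = nn.Linear(1536, num_labels)      # hidden -> one logit per answer class

fused = torch.randn(batch, embed_dim * 2)      # cat([image_features, text_features], dim=-1)
logits = fc2(lnorm(fc1(fused)))
print(logits.shape)                            # torch.Size([4, 3000])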
@@ -136,16 +158,15 @@ def compute_metrics(model, dataloader, device, args):
     metric = evaluate.load("accuracy")
     val_loss = 0
     samples_seen = 0
-    loss_fn = nn.CrossEntropyLoss()
     for batch in dataloader:
         with torch.no_grad():
             image = batch["image"].to(device)
             text = batch["text"].to(device)
-            label = batch["label"].to(device)
+            label = batch["target"].to(device)
             samples_seen += text.shape[0]
             logits = model(image, text)
             predictions = torch.argmax(logits, dim=-1)
-            batch_val_loss = loss_fn(logits, label)
+            batch_val_loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum")
             val_loss += batch_val_loss.item()
             print(val_loss)
             metric.add_batch(
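
Because the target is now a soft vector of per-answer scores rather than a single class index, CrossEntropyLoss (which expects integer labels) is replaced with binary_cross_entropy_with_logits, which puts an independent sigmoid on every answer logit and accepts float targets of the same [batch, num_answers] shape as the logits; the same substitution appears in the two training functions below, while argmax over the logits still feeds the accuracy metric. A minimal sketch with toy tensors:

import torch
import torch.nn.functional as F

batch, num_answers = 2, 5                       # toy sizes
logits = torch.randn(batch, num_answers)        # raw model outputs, one per answer
target = torch.tensor([[1.0, 0.6667, 0.0, 0.0, 0.0],   # soft VQA-style scores in [0, 1]
                       [0.0, 0.0, 1.0, 0.3333, 0.0]])

# Per-class sigmoid cross-entropy. reduction="sum" adds over both the batch and the
# answer dimension, so the printed loss grows with batch size, unlike the default mean.
loss_sum = F.binary_cross_entropy_with_logits(logits, target, reduction="sum")
loss_mean = F.binary_cross_entropy_with_logits(logits, target)
print(loss_sum.item(), loss_mean.item())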
@@ -164,31 +185,29 @@ def train_single_epoch(model, data, optimizer, args):
     for i, batch in enumerate(data["train"]):
         image = batch["image"].to(device)
         text = batch["text"].to(device)
-        label = batch["label"].to(device)
+        label = batch["target"].to(device)

         logits = model(image, text)
         print(label.shape)
         print(logits.shape)
-        loss_fn = nn.CrossEntropyLoss()
-        loss = loss_fn(logits, label)
+        loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum")
         print(loss)
         loss.backward()


 def train_one_epoch(model, data, epoch, optimizer, scheduler, early_stop, device, args):
     model.train()
-    loss_fn = nn.CrossEntropyLoss()
     progress_bar = tqdm(total=len(data["train"]))
     for i, batch in enumerate(data["train"]):
         step = epoch * len(data["train"]) + i
         scheduler(step)

         image = batch["image"].to(device)
         text = batch["text"].to(device)
-        label = batch["label"].to(device)
+        label = batch["target"].to(device)
         logits = model(image, text)

-        loss = loss_fn(logits, label) #should be cross entropy
+        loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction = "sum") #should be cross entropy

         optimizer.zero_grad()
         loss.backward()
@@ -228,7 +247,7 @@ def parse_args(args):
     parser.add_argument(
         "--epochs", type=int, default=10, help="Number of epochs to train for."
     )
-    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate.")
+    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.")
     parser.add_argument("--beta2", type=float, default=0.999, help="Adam beta 2.")
     parser.add_argument("--eps", type=float, default=1e-8, help="Adam epsilon.")
@@ -273,8 +292,8 @@ def parse_args(args):
     args = parser.parse_args(args)
     return args

-if __name__ == "__main__":
-    args = parse_args([])
+def main(args):
+    args = parse_args(args)
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     model, preprocess_train, preprocess_val = open_clip.factory.create_model_and_transforms(
@@ -287,7 +306,7 @@ def parse_args(args):
     embed_dim = model_cfg["embed_dim"]

     answer_space = []
-    with open('answers_vqa.txt') as f:
+    with open('src/training/answers_vqa.txt') as f:
         for line in f:
             answer_space.append(line.strip())
     answer_space = np.array(answer_space)
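
The answer vocabulary is read from answers_vqa.txt (now opened through the repo-relative path src/training/answers_vqa.txt) and fitted with sklearn's LabelEncoder: its classes_ supply the answer_set used to filter annotator answers, and transform supplies the integer indices stored in answer_list. A self-contained sketch of that mapping, with three made-up answers standing in for the real file:

import numpy as np
from sklearn import preprocessing

# Stand-in for the contents of answers_vqa.txt, one answer per line.
answer_space = np.array(["yes", "no", "blue"])

labelencoder = preprocessing.LabelEncoder()
labelencoder.fit(answer_space)

answer_set = set(labelencoder.classes_)     # {"blue", "no", "yes"}
num_classes = len(labelencoder.classes_)    # width of the target vectors and classifier head

print(labelencoder.transform(["blue"])[0])  # 0 -- LabelEncoder sorts classes alphabetically
print(answer_set, num_classes)

Because the dataset sizes its target vector as len(answer_set) while the loader writes indices from labelencoder.transform, both must come from this same fitted encoder for the indices to stay in range.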
@@ -298,7 +317,7 @@ def parse_args(args):

     answer_set = set(labelencoder.classes_)

-    data = get_task_dataloaders("HuggingFaceM4/VQAv2", preprocess_val, labelencoder, args)
+    data = get_task_dataloaders("HuggingFaceM4/VQAv2", preprocess_val, labelencoder, answer_set, args)

     clf_cls = CLIPMultimodalClassifier
     clf = clf_cls(model, embed_dim, num_classes).to(device)
@@ -314,3 +333,6 @@ def parse_args(args):

     for epoch in range(20):
         val_metrics, end_training = train_one_epoch(clf, data, epoch, optim, scheduler, early_stop, device, args)
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
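
Moving the former __main__ body into main(args) keeps the command-line behaviour (the new guard forwards sys.argv[1:], so running python src/training/vqa_fine_tune.py --lr 1e-5 works as before) and also lets the fine-tune be launched from other code with an argv-style list. A hypothetical programmatic call, assuming the repository root is on PYTHONPATH:

# Hypothetical invocation; --lr and --epochs are the flags visible in this diff,
# all other options fall back to the defaults defined in parse_args.
from src.training.vqa_fine_tune import main

main(["--lr", "1e-5", "--epochs", "10"])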
