 import evaluate
 from sklearn import preprocessing
 import numpy as np
+import sys

 from datasets import load_dataset_builder
 from datasets import load_dataset

 class VQATextDataset(Dataset):
-    def __init__(self, df, split, transforms, labelencoder, tokenizer=None):
+    def __init__(self, df, split, transforms, answer_set, tokenizer=None):
         self.df = df
         self.transforms = transforms
         self.tokenize = tokenizer
-        self.labels = labelencoder.transform(df['multiple_choice_answer'])
+        self.num_classes = len(answer_set)

     def __len__(self):
         return len(self.df)

@@ -34,13 +35,20 @@ def __getitem__(self, idx):
         image = Image.open(str(img_path))
         text = item["question"]
-        label = self.labels[idx]
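+        # Build a soft multi-label target: each surviving annotator answer
+        # contributes its precomputed VQA soft score at its class index.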
+        target = np.zeros(self.num_classes)
+        for i in range(len(item['answer_list'])):
+            target[item['answer_list'][i]] = item['answer_weights'][i]
+
         return {
             'image': self.transforms(image),
             'text': self.tokenize([text])[0],
-            'label': torch.tensor(label)
+            'target': torch.tensor(target)
         }

-def get_task_dataloaders(path, transforms, labelencoder, args):
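+# VQA soft accuracy: an answer earns full credit once at least three
+# annotators gave it, and count/3 credit below that.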
+def get_score(count: int) -> float:
+    return min(1.0, count / 3)
+
+def get_task_dataloaders(path, transforms, labelencoder, answer_set, args):
     tokenizer = get_tokenizer(args.model)
     dataloaders = {}

@@ -52,29 +60,43 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
         questions = []
         images = []
         answers = []
+        weights = []
         for index, row in dataset_df.iterrows():
-            if (row['multiple_choice_answer'] in answer_set):
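+            # Tally the raw annotator answers (VQAv2 provides ten per question)
+            # so each unique in-vocabulary answer can be soft-scored.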
+            answer_count = {}
+            for answer in row['answers']:
+                answer_ = answer["answer"]
+                answer_count[answer_] = answer_count.get(answer_, 0) + 1
+            labels = []
+            scores = []
+            for answer in answer_count:
+                if answer not in answer_set:
+                    continue
+                labels.append(labelencoder.transform([answer])[0])
+                score = get_score(answer_count[answer])
+                scores.append(score)
+            if (len(labels) == 0):
+                continue
             class_id.append(row['question_id'])
             questions.append(row['question'])
             images.append(row['image'])
-            answers.append(row['multiple_choice_answer'])
+            answers.append(labels)
+            weights.append(scores)
+
         class_id = np.array(class_id)
         questions = np.array(questions)
         images = np.array(images)
-        answers = np.array(answers)
-
-        dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'multiple_choice_answer': answers})
+        dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'answer_list': answers, 'answer_weights': weights})
         #dataset_df = dataset_df[0:12800]
         b_size = args.batch_size
         if (split == "validation"):
             b_size = args.batch_size * 20
             dataset_df = dataset_df[0:12800]
-        dataset = VQATextDataset(dataset_df,
-                                 split,
-                                 transforms,
-                                 labelencoder,
-                                 tokenizer=tokenizer,
-                                 )
+        dataset = VQATextDataset(dataset_df,
+                                 split,
+                                 transforms,
+                                 answer_set,
+                                 tokenizer=tokenizer,
+                                 )
         dataloader = DataLoader(
             dataset,
             batch_size=b_size,
@@ -95,7 +117,7 @@ def __init__(self, encoder, embed_dim, num_labels):

         self.fc1 = nn.Linear(embed_dim * 2, 1536)  # size of answer space
         self.lnorm = nn.LayerNorm(1536)
-        self.fc2 = nn.Linear(1536, num_classes)
+        self.fc2 = nn.Linear(1536, num_labels)
     def forward(self, image, text):
         # CLIP doesn't have a multimodal encoder, so we concatenate the features
         text_features = self.encoder.encode_text(text)
@@ -136,16 +158,15 @@ def compute_metrics(model, dataloader, device, args):
     metric = evaluate.load("accuracy")
     val_loss = 0
     samples_seen = 0
-    loss_fn = nn.CrossEntropyLoss()
     for batch in dataloader:
         with torch.no_grad():
             image = batch["image"].to(device)
             text = batch["text"].to(device)
-            label = batch["label"].to(device)
+            label = batch["target"].to(device)
             samples_seen += text.shape[0]
             logits = model(image, text)
             predictions = torch.argmax(logits, dim=-1)
-            batch_val_loss = loss_fn(logits, label)
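+            # The multi-hot soft targets are scored with element-wise BCE
+            # rather than a single-class cross entropy.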
+            batch_val_loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum")
             val_loss += batch_val_loss.item()
             print(val_loss)
             metric.add_batch(
@@ -164,31 +185,29 @@ def train_single_epoch(model, data, optimizer, args):
     for i, batch in enumerate(data["train"]):
         image = batch["image"].to(device)
         text = batch["text"].to(device)
-        label = batch["label"].to(device)
+        label = batch["target"].to(device)

         logits = model(image, text)
         print(label.shape)
         print(logits.shape)
-        loss_fn = nn.CrossEntropyLoss()
-        loss = loss_fn(logits, label)
+        loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum")
         print(loss)
         loss.backward()


 def train_one_epoch(model, data, epoch, optimizer, scheduler, early_stop, device, args):
     model.train()
-    loss_fn = nn.CrossEntropyLoss()
     progress_bar = tqdm(total=len(data["train"]))
     for i, batch in enumerate(data["train"]):
         step = epoch * len(data["train"]) + i
         scheduler(step)

         image = batch["image"].to(device)
         text = batch["text"].to(device)
-        label = batch["label"].to(device)
+        label = batch["target"].to(device)
         logits = model(image, text)

-        loss = loss_fn(logits, label)  # should be cross entropy
+        loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum")  # soft-target BCE over the answer vocabulary

         optimizer.zero_grad()
         loss.backward()
@@ -228,7 +247,7 @@ def parse_args(args):
     parser.add_argument(
         "--epochs", type=int, default=10, help="Number of epochs to train for."
     )
-    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate.")
+    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.")
     parser.add_argument("--beta2", type=float, default=0.999, help="Adam beta 2.")
     parser.add_argument("--eps", type=float, default=1e-8, help="Adam epsilon.")
@@ -273,8 +292,8 @@ def parse_args(args):
     args = parser.parse_args(args)
     return args

-if __name__ == "__main__":
-    args = parse_args([])
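+# Taking an argv list explicitly keeps the entry point importable (e.g., for tests).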
+def main(args):
+    args = parse_args(args)
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     model, preprocess_train, preprocess_val = open_clip.factory.create_model_and_transforms(
@@ -287,7 +306,7 @@ def parse_args(args):
     embed_dim = model_cfg["embed_dim"]

     answer_space = []
-    with open('answers_vqa.txt') as f:
+    with open('src/training/answers_vqa.txt') as f:
         for line in f:
             answer_space.append(line.strip())
     answer_space = np.array(answer_space)
@@ -298,7 +317,7 @@ def parse_args(args):

     answer_set = set(labelencoder.classes_)

-    data = get_task_dataloaders("HuggingFaceM4/VQAv2", preprocess_val, labelencoder, args)
+    data = get_task_dataloaders("HuggingFaceM4/VQAv2", preprocess_val, labelencoder, answer_set, args)

     clf_cls = CLIPMultimodalClassifier
     clf = clf_cls(model, embed_dim, num_classes).to(device)
@@ -314,3 +333,6 @@ def parse_args(args):

     for epoch in range(20):
         val_metrics, end_training = train_one_epoch(clf, data, epoch, optim, scheduler, early_stop, device, args)
+
+if __name__ == "__main__":
+    main(sys.argv[1:])