Skip to content

Commit 008bfc5

Browse files
authored
Update vqa_fine_tune.py
1 parent 55e1501 commit 008bfc5

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/training/vqa_fine_tune.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
6565

6666
dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'multiple_choice_answer': answers})
6767
#dataset_df = dataset_df[0:12800]
68-
68+
b_size = args.batch_size
69+
if(split == "validation"):
70+
b_size = args.batch_size
6971
dataset = VQATextDataset(dataset_df,
7072
split,
7173
transforms,
@@ -74,7 +76,7 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
7476
)
7577
dataloader = DataLoader(
7678
dataset,
77-
batch_size=args.batch_size,
79+
batch_size=b_size,
7880
shuffle=True,
7981
num_workers=args.workers,
8082
pin_memory=True,

0 commit comments

Comments
 (0)