Skip to content

Commit cd009bd

Browse files
authored
Update vqa_fine_tune.py
1 parent 008bfc5 commit cd009bd

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/training/vqa_fine_tune.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,14 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
6767
#dataset_df = dataset_df[0:12800]
6868
b_size = args.batch_size
6969
if(split == "validation"):
70-
b_size = args.batch_size
71-
dataset = VQATextDataset(dataset_df,
72-
split,
73-
transforms,
74-
labelencoder,
75-
tokenizer=tokenizer,
76-
)
70+
b_size = args.batch_size * 20
71+
dataset_df = dataset_df[0:12800]
72+
dataset = VQATextDataset(dataset_df,
73+
split,
74+
transforms,
75+
labelencoder,
76+
tokenizer=tokenizer,
77+
)
7778
dataloader = DataLoader(
7879
dataset,
7980
batch_size=b_size,
@@ -222,7 +223,7 @@ def parse_args(args):
222223
"--workers", type=int, default=2, help="Number of dataloader workers per GPU."
223224
)
224225
parser.add_argument(
225-
"--batch-size", type=int, default=256, help="Batch size per GPU."
226+
"--batch-size", type=int, default=128, help="Batch size per GPU."
226227
)
227228
parser.add_argument(
228229
"--epochs", type=int, default=10, help="Number of epochs to train for."

0 commit comments

Comments
 (0)