This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit e729e9a

Added GQA reader (#4832)
* Adds reader for GQA dataset. Will download questions from https://cs.stanford.edu/people/dorarad/gqa/download.html.
* Cleaned up GQA reader tests
1 parent 52e9dd9 commit e729e9a

File tree

9 files changed (+380, -0 lines)

CHANGELOG.md (+1)

@@ -33,6 +33,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
 - Added abstraction and concrete implementation for region detectors.
 - Transformer toolkit to plug and play with modular components of transformer architectures.
 - `VisionReader` and `VisionTextModel` base classes added. `VisualEntailment` and `VQA` inherit from these.
+- Added reader for the GQA dataset

 ### Changed
allennlp/data/dataset_readers/__init__.py (+1)

@@ -23,6 +23,7 @@
     from allennlp.data.dataset_readers.vision_reader import VisionReader
     from allennlp.data.dataset_readers.vqav2 import VQAv2Reader
     from allennlp.data.dataset_readers.visual_entailment import VisualEntailmentReader
+    from allennlp.data.dataset_readers.gqa import GQAReader
 except ModuleNotFoundError as err:
     if err.name not in ("detectron2", "torchvision"):
         raise
allennlp/data/dataset_readers/gqa.py (+160)

@@ -0,0 +1,160 @@
from os import PathLike
from typing import (
    Dict,
    Union,
    Optional,
    Tuple,
)
import json
import os

from overrides import overrides
import torch
from torch import Tensor

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import ArrayField, LabelField, TextField
from allennlp.data.image_loader import ImageLoader
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Tokenizer
from allennlp.modules.vision.grid_embedder import GridEmbedder
from allennlp.modules.vision.region_detector import RegionDetector
from allennlp.data.dataset_readers.vision_reader import VisionReader


@DatasetReader.register("gqa")
class GQAReader(VisionReader):
    """
    Parameters
    ----------
    image_dir: `str`
        Path to directory containing `jpg` image files.
    image_featurizer: `GridEmbedder`
        The backbone image processor (like a ResNet), whose output will be passed to the region
        detector for finding object boxes in the image.
    region_detector: `RegionDetector`
        For pulling out regions of the image (both coordinates and features) that will be used by
        downstream models.
    data_dir: `str`
        Path to directory containing text files for each dataset split. These files contain
        the sentences and metadata for each task instance.
    tokenizer: `Tokenizer`, optional
    token_indexers: `Dict[str, TokenIndexer]`
    """

    def __init__(
        self,
        image_dir: Union[str, PathLike],
        image_loader: ImageLoader,
        image_featurizer: GridEmbedder,
        region_detector: RegionDetector,
        *,
        feature_cache_dir: Optional[Union[str, PathLike]] = None,
        data_dir: Optional[Union[str, PathLike]] = None,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        cuda_device: Optional[Union[int, torch.device]] = None,
        max_instances: Optional[int] = None,
        image_processing_batch_size: int = 8,
        skip_image_feature_extraction: bool = False,
    ) -> None:
        super().__init__(
            image_dir,
            image_loader,
            image_featurizer,
            region_detector,
            feature_cache_dir=feature_cache_dir,
            tokenizer=tokenizer,
            token_indexers=token_indexers,
            cuda_device=cuda_device,
            max_instances=max_instances,
            image_processing_batch_size=image_processing_batch_size,
            skip_image_feature_extraction=skip_image_feature_extraction,
        )
        self.data_dir = data_dir

    @overrides
    def _read(self, split_or_filename: str):
        if not self.data_dir:
            self.data_dir = "https://nlp.stanford.edu/data/gqa/questions1.2.zip!"

        splits = {
            "challenge_all": f"{self.data_dir}challenge_all_questions.json",
            "challenge_balanced": f"{self.data_dir}challenge_balanced_questions.json",
            "test_all": f"{self.data_dir}test_all_questions.json",
            "test_balanced": f"{self.data_dir}test_balanced_questions.json",
            "testdev_all": f"{self.data_dir}testdev_all_questions.json",
            "testdev_balanced": f"{self.data_dir}testdev_balanced_questions.json",
            "train_balanced": f"{self.data_dir}train_balanced_questions.json",
            "train_all": f"{self.data_dir}train_all_questions",
            "val_all": f"{self.data_dir}val_all_questions.json",
            "val_balanced": f"{self.data_dir}val_balanced_questions.json",
        }

        filename = splits.get(split_or_filename, split_or_filename)

        # If we're considering a directory of files (such as train_all),
        # loop through each file in the generator.
        if os.path.isdir(filename):
            files = [os.path.join(filename, file_path) for file_path in os.listdir(filename)]
        else:
            files = [filename]

        for data_file in files:
            with open(cached_path(data_file, extract_archive=True)) as f:
                questions_with_annotations = json.load(f)

            # It would be much easier to just process one image at a time, but it's faster to
            # process them in batches. So this code gathers up instances until it has enough to
            # fill up a batch that needs processing, and then processes them all.
            question_dicts = list(
                self.shard_iterable(
                    questions_with_annotations[q_id] for q_id in questions_with_annotations
                )
            )

            processed_images = self._process_image_paths(
                self.images[f"{question_dict['imageId']}.jpg"] for question_dict in question_dicts
            )

            for question_dict, processed_image in zip(question_dicts, processed_images):
                answer = {
                    "answer": question_dict["answer"],
                }
                yield self.text_to_instance(question_dict["question"], processed_image, answer)

    @overrides
    def text_to_instance(
        self,  # type: ignore
        question: str,
        image: Union[str, Tuple[Tensor, Tensor]],
        answer: Dict[str, str] = None,
        *,
        use_cache: bool = True,
    ) -> Instance:
        tokenized_question = self._tokenizer.tokenize(question)
        question_field = TextField(tokenized_question, None)
        if isinstance(image, str):
            features, coords = next(self._process_image_paths([image], use_cache=use_cache))
        else:
            features, coords = image

        fields = {
            "box_features": ArrayField(features),
            "box_coordinates": ArrayField(coords),
            "question": question_field,
        }

        if answer:
            fields["label"] = LabelField(answer["answer"], label_namespace="answer")

        return Instance(fields)

    @overrides
    def apply_token_indexers(self, instance: Instance) -> None:
        instance["question"].token_indexers = self._token_indexers  # type: ignore
@@ -0,0 +1,43 @@
{
    "202218649": {
        "semantic": [
            {
                "operation": "select",
                "dependencies": [],
                "argument": "chalkboard (0)"
            },
            {
                "operation": "relate",
                "dependencies": [0],
                "argument": "_,hanging above,s (12)"
            },
            {
                "operation": "query",
                "dependencies": [1],
                "argument": "name"
            }
        ],
        "entailed": ["202218648"],
        "equivalent": ["202218649"],
        "question": "What is hanging above the chalkboard?",
        "imageId": "n578564",
        "isBalanced": true,
        "groups": {
            "global": "thing",
            "local": "14-chalkboard_hanging above,s"
        },
        "answer": "picture",
        "semanticStr": "select: chalkboard (0)->relate: _,hanging above,s (12) [0]->query: name [1]",
        "annotations": {
            "answer": {"0": "12"},
            "question": {},
            "fullAnswer": {"1": "12", "6": "0"}
        },
        "types": {
            "detailed": "relS",
            "semantic": "rel",
            "structural": "query"
        },
        "fullAnswer": "The picture is hanging above the chalkboard."
    }
}
@@ -0,0 +1,51 @@
{
    "20240871": {
        "semantic": [
            {
                "operation": "select",
                "dependencies": [],
                "argument": "water (4)"
            },
            {
                "operation": "relate",
                "dependencies": [0],
                "argument": "table,below,s (11)"
            },
            {
                "operation": "verify shape",
                "dependencies": [1],
                "argument": "round"
            },
            {
                "operation": "verify material",
                "dependencies": [1],
                "argument": "wood "
            },
            {
                "operation": "and",
                "dependencies": [2, 3],
                "argument": ""
            }
        ],
        "entailed": ["20240900", "20240892", "20240891", "20240890", "20240879", "20240896", "20240895", "20240894", "20240875", "20240897", "20240899", "20240898", "20240870", "20240878", "20240910", "20240877", "20240909", "20240886", "20240887", "20240882", "20240911", "20240872", "20240888", "20240889"],
        "equivalent": ["20240871", "20240870"],
        "question": "Does the table below the water look wooden and round?",
        "imageId": "n166008",
        "isBalanced": false,
        "groups": {
            "global": null,
            "local": "05-round_wood"
        },
        "answer": "yes",
        "semanticStr": "select: water (4)->relate: table,below,s (11) [0]->verify shape: round [1]->verify material: wood [1]->and: [2, 3]",
        "annotations": {
            "answer": {},
            "question": {"2": "11", "5": "4"},
            "fullAnswer": {"2": "11"}
        },
        "types": {
            "detailed": "verifyAttrs",
            "semantic": "attr",
            "structural": "logical"
        },
        "fullAnswer": "Yes, the table is wooden and round."
    }
}

test_fixtures/data/gqa/questions.json (+43)

@@ -0,0 +1,43 @@
{
    "202218649": {
        "semantic": [
            {
                "operation": "select",
                "dependencies": [],
                "argument": "chalkboard (0)"
            },
            {
                "operation": "relate",
                "dependencies": [0],
                "argument": "_,hanging above,s (12)"
            },
            {
                "operation": "query",
                "dependencies": [1],
                "argument": "name"
            }
        ],
        "entailed": ["202218648"],
        "equivalent": ["202218649"],
        "question": "What is hanging above the chalkboard?",
        "imageId": "n578564",
        "isBalanced": true,
        "groups": {
            "global": "thing",
            "local": "14-chalkboard_hanging above,s"
        },
        "answer": "picture",
        "semanticStr": "select: chalkboard (0)->relate: _,hanging above,s (12) [0]->query: name [1]",
        "annotations": {
            "answer": {"0": "12"},
            "question": {},
            "fullAnswer": {"1": "12", "6": "0"}
        },
        "types": {
            "detailed": "relS",
            "semantic": "rel",
            "structural": "query"
        },
        "fullAnswer": "The picture is hanging above the chalkboard."
    }
}
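This fixture has the shape `GQAReader._read` iterates over: a top-level mapping from question ID to an annotation dict with at least `question`, `imageId`, and `answer`. A minimal sketch of walking it directly with the standard library (the fixture path is the file added above):

    import json

    # Load the GQA question fixture and walk it the same way the reader does:
    # each value in the top-level mapping is one question annotation.
    with open("test_fixtures/data/gqa/questions.json") as f:
        questions_with_annotations = json.load(f)

    for q_id, question_dict in questions_with_annotations.items():
        image_filename = f"{question_dict['imageId']}.jpg"  # the reader keys images by this name
        print(q_id, image_filename, question_dict["question"], "->", question_dict["answer"])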
