Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit b659e66

Browse files
dirkgrmatt-gardnerjiasenluj-minsanjays
authored
VQAv2 (#4639)
* albert works, but bert-base-uncased still gives zero gradients * Note * Formatting * Adds Registrable base classes for image operations * Adds a real example of a image2image module * Run the new code (without implementation) in the nlvr2 reader * Solve some issue involving circular imports * add new modules for vilbert * add parameters for detectron image loader. * push current code on implementing proposal generator. * push current progress on proposal generator * Update FasterRCNNProposalGenerator & Merge Detectron2 config * Loading of weights should now work * black, flake, mypy * Run detectron pipeline pieces one at a time This is unfinished and will not run this way. * Fix the data format for the backbone * Handle image sizes separately * remove drop and mask functionality from reader * make comment better * remove proposal_embedder, and finish proposal generator * working on grid embedder * added simple test for resnet backbone, which passes * Got proposal generator test passing * Change default number of detections per image: 100 => 36 * Fix detectron config hierarchy: test_detectron_per_image * Make number of detections configurable & Add test * rename ProposalGenerator to RegionDetector * try to fix makefile * another attempt at makefile * quotes in the pip command... * added a simple test for the dataset reader, made it pass * add feature caching to the dataset reader * another try with the makefile * a better temporary fix for installing detectron * writing files before committing is good... * fix tests * fix (at least part of) the vilbert tests * ok, this makefile change should actually work * add torchvision, try to remove eager import of detectron code * flake * cleanup * more cleanup * mypy, flake * add back code I shouldn't have removed * black * test and flake fixes * fix region_detector for multiple images and add feature and coords padding * fix imports * restore null grid embedder * add back (todo) null region detector * Bring back import changes, to fix circular imports caused by NLVR2 reader * region detector test passing * model test finally passing * update torchvision version * add vqav2 dataset * add gpu support for detectron feature extraction * add lmdbCache to cache feature into lmdb database * fix typo * update vqa jsonnet * fix url adding by cat * Fixes type annotation * Fixes borked error message * New feature cache * Formatting * Fix the tensor cache * Be explicit about our dependencies * Use the new tensor cache * Adds a test using the tensor cache * Run NLVR dataprep on GPU * Tqdm when finding images * Fixes padding in array field * Adjust max_length when truncating in PretrainedTransformerTokenizer * Fewer print statements * remove VQA from this branch and copy default vilbert parameters. * add VQAv2 dataset * Added dataset reader and model tests, which are now passing * Sanjay's vision features cache script (#4633) * Use LMDB cache in NLVR2 dataset reader; fix a few typos * Standalone script for caching image features * Removing reference to LMDB cache in NLVR2 dataset reader * Adding back asterisk in nlvr2 dataset reader * Fixing one variable name mistake * Decreasing batch size and making a few cuda-related changes * Loading images in batches to avoid GPU OOM error * Pedantic changes for consistency * Run the pre-processing with the models and not the data loading * Filter out paths of images already cached * Add image extensions other than png * Fixes import error * Makes the vision features script work alongside other scripts or training runs Co-authored-by: sanjays <[email protected]> Co-authored-by: sanjays <[email protected]> Co-authored-by: Sanjay Subramanian <[email protected]> Co-authored-by: Sanjay Subramanian <[email protected]> * Adds missing imports * Makes TensorCache into a real MutableMapping * Formatting * Changelog * Fix typecheck * Makes the NLVR2 reader work with Pete's new code * Fix type annotation * Formatting * Backwards compatibility * Restore NLVR to former glory * Types and multi-process reading for VQAv2 * Formatting * Fix tests * Fix broken config * Update grid embedder test * Fix vilbert_from_huggingface configuration * Don't run the vilbert_from_huggingface test anymore * Remove unused test fixtures * Fix the region detector test * Fix vilbert-from-huggingface and bring it back * Fuck the linter * Fix for VQA test * Why was this metric disabled? * Black and flake * Re-add VQA reader * Image featurizers now need to be called with sizes * Run the region detector test on GPU * Run more stuff on GPU The CPU test runner doesn't have enough memory. * Depend on newer version of Detectron * Reinstall Detectron before running tests * Just force CUDA to be on, instead of reinstalling Detecton2 * Fixes VQA2 DatasetReader * Fix documentation * Detectron needs CUDA_HOME to be set during install At least this thing fails quickly. * Try a different way of wrangling the detectron installer * Try a different way of wrangling the detectron installer * Bring back amp * Refactored VQA reader * More training paths * Remove debug code * Don't check in debug code * Auto-detect GPU to use * Apply indexers later * Fix typo * Register the model * Fields live on CPU. Only batches get GPUs. * black * black, flake * mypy * more flake * More realistic training config * Adds a basic Predictor for VQAv2 * Make vilbert output human-readable * Forgot to enumerate * Use the right namspace * Trying to make tests faster, and passing * add image prefix when loading coco image * fix vqav2 dataset reader and config file * use two regions, to make tests pass * black * Output probabilities in addition to logits * Make it possible to turn off the cache * Turn off the cache in the predictor * Fix the VQA predictor * change the experiment to the defualt vilbert hyperparams. * add default experiment_from_huggingface.json * fix typos in vqa reader * Proper probabilities * Formatting * Remove unused variable * Make mypy happy * Fixed loss function, metric, and got tests to pass * Updates the big training config * Put real settings into the vilbert_vqa config * Strings are lists in Python * Make mypy happy * Formatting * Unsatisfying mypy * Config changes to make this run * Fix dimensionality of embeddings * clean the code and add the image_num_heads and combine_num_heads * fix answer vocab and add save and load from pre-extracted vocab * fix loss and update save_answer_vocab script * Typo * Fixed fusion method * Tweaking the VQA config some more * Moved the from_huggingface config * 20 epochs * Set up the learning rate properly * Simplify * Hardcoded answer vocab * Don't be lazy * Steps per epoch cannot be None * Let's chase the right score * Fixing some parameter names * Fields are stored on CPUs * Bigger batch size, easier distributed training * Don't run the debug code by default * VQA with the Transformer Toolkit (#4729) * transformer toolkit: BertEmbeddings * transformer toolkit: BertSelfAttention * transformer toolkit: BertSelfOutput * transformer toolkit: BertAttention * transformer toolkit: BertIntermediate * transformer toolkit: BertOutput * transformer toolkit: BertLayer * transformer toolkit: BertBiAttention * transformer toolkit: BertEmbeddings * transformer toolkit: BertSelfAttention * transformer toolkit: BertSelfOutput * transformer toolkit: BertAttention * transformer toolkit: BertIntermediate * transformer toolkit: BertOutput * transformer toolkit: BertLayer * transformer toolkit: BertBiAttention * Attention scoring functions * merging output and self output * utility to replicate layers, further cleanup * adding sinusoidal positional encoding * adding activation layer * adding base class for generic loading of pretrained weights * further generalizing, adding tests * updates * adding bimodal encoder, kwargs in from_pretrained_module * vilbert using transformer toolkit * fixing test function * changing to torch.allclose * fixing attention score api * bug fix in bimodal output * changing to older attention modules * _construct_default_mapping returns mapping * adding kwargs to _get_input_arguments, adding examples * using cached_transformers * making transformer_encoder more general * added get_relevant_module, loading by name * fixing constructor name * undoing failure after merge * misc minor changes * Transformer toolkit (#4577) * transformer toolkit: BertEmbeddings * transformer toolkit: BertSelfAttention * transformer toolkit: BertSelfOutput * transformer toolkit: BertAttention * transformer toolkit: BertIntermediate * transformer toolkit: BertOutput * transformer toolkit: BertLayer * transformer toolkit: BertBiAttention * transformer toolkit: BertEmbeddings * transformer toolkit: BertSelfAttention * transformer toolkit: BertSelfOutput * transformer toolkit: BertAttention * transformer toolkit: BertIntermediate * transformer toolkit: BertOutput * transformer toolkit: BertLayer * transformer toolkit: BertBiAttention * Attention scoring functions * merging output and self output * utility to replicate layers, further cleanup * adding sinusoidal positional encoding * adding activation layer * adding base class for generic loading of pretrained weights * further generalizing, adding tests * updates * adding bimodal encoder, kwargs in from_pretrained_module * vilbert using transformer toolkit * fixing test function * changing to torch.allclose * fixing attention score api * bug fix in bimodal output * changing to older attention modules * _construct_default_mapping returns mapping * adding kwargs to _get_input_arguments, adding examples * using cached_transformers * making transformer_encoder more general * added get_relevant_module, loading by name * fixing constructor name * undoing failure after merge * misc minor changes Co-authored-by: Dirk Groeneveld <[email protected]> * separate num_attention_heads for both modalities, default arguments * adding tests for toolkit examples * debug statements for failing test * removing debug statements, reordering * Typo * Some compatibility with the transformer toolkit * Reorganize the image inputs * More transformer toolkit compatibility * Debug settings * Let's be more tolerant * Fix how VilBERT runs Co-authored-by: Akshita Bhagia <[email protected]> * Make the region detector and region embedder lazy * Fix references to the model * Make various automated tests pass * Formatting * More logging * One more logging statement * Read answer vocab from vocab file instead of determining it automatically * Don't keep the files open so long * Use most of the validation set for training as well * Get ready to be lazy * Upgrade paths * Be lazy * Keep unanswerable questions only during test time * Fix the from_huggingface config * Fixes the VQA score * VQA specific metric * Fixes some tests * Tests pass! * Formatting * Use the correct directory * Use the region detector that's meant for testing * Read the test split properly * Be a little more verbose while discovering images * Modernize Vilbert VQA * Update NLVR, but it still doesn't run * Formatting * Remove NLVR * Fix the last test * Formatting * Conditionally export the VilbertVqaPredictor * ModuleNotFoundError is a type of ImportError * Fix test-install * Try the broken test with a fixed seed * Try a bunch of seeds * Smaller model to get bigger magnitudes * Now that the test works, we don't need to specify the seeds anymore Co-authored-by: Matt Gardner <[email protected]> Co-authored-by: jiasenlu <[email protected]> Co-authored-by: Jaemin Cho <[email protected]> Co-authored-by: jiasenlu <[email protected]> Co-authored-by: sanjays <[email protected]> Co-authored-by: sanjays <[email protected]> Co-authored-by: Sanjay Subramanian <[email protected]> Co-authored-by: Sanjay Subramanian <[email protected]> Co-authored-by: Akshita Bhagia <[email protected]> Co-authored-by: Evan Pete Walsh <[email protected]>
1 parent c787230 commit b659e66

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+1764
-689
lines changed

CHANGELOG.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
### Added
1212

13+
- Added `TensorCache` class for caching tensors on disk
14+
- Added reader for the NLVR2 dataset
15+
- Added cache for Detectron models that we might re-use several times in the code base
16+
- Added abstraction and concrete implementation for image loading
17+
- Added abstraction and concrete implementation for `GridEmbedder`
18+
- Added abstraction and demo implementation for an image augmentation module.
19+
- Added abstraction and concrete implementation for region detectors.
1320
- A new high-performance default `DataLoader`: `MultiProcessDataLoading`.
1421
- A `MultiTaskModel` and abstractions to use with it, including `Backbone` and `Head`. The
1522
`MultiTaskModel` first runs its inputs through the `Backbone`, then passes the result (and
@@ -33,7 +40,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
3340
- The `DataLoader` now decides whether to load instances lazily or not.
3441
With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with
3542
the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting.
36-
- `TensorField` is now implemented in terms of torch tensors, not numpy.
43+
- `ArrayField` is now called `TensorField`, and implemented in terms of torch tensors, not numpy.
3744

3845

3946
## Unreleased (1.x branch)

allennlp/commands/train.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ def _train_worker(
483483
return None
484484

485485

486+
DataPath = Union[str, List[str], Dict[str, str]]
487+
488+
486489
class TrainModel(Registrable):
487490
"""
488491
This class exists so that we can easily read a configuration file with the `allennlp train`
@@ -554,16 +557,16 @@ def from_partial_objects(
554557
serialization_dir: str,
555558
local_rank: int,
556559
dataset_reader: DatasetReader,
557-
train_data_path: str,
560+
train_data_path: DataPath,
558561
model: Lazy[Model],
559562
data_loader: Lazy[DataLoader],
560563
trainer: Lazy[Trainer],
561564
vocabulary: Lazy[Vocabulary] = Lazy(Vocabulary),
562565
datasets_for_vocab_creation: List[str] = None,
563566
validation_dataset_reader: DatasetReader = None,
564-
validation_data_path: str = None,
567+
validation_data_path: DataPath = None,
565568
validation_data_loader: Lazy[DataLoader] = None,
566-
test_data_path: str = None,
569+
test_data_path: DataPath = None,
567570
evaluate_on_test: bool = False,
568571
batch_weight_key: str = "",
569572
) -> "TrainModel":

allennlp/common/testing/model_test_case.py

+6
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def ensure_model_can_train_save_and_load(
7373
metric_terminal_value: float = None,
7474
metric_tolerance: float = 1e-4,
7575
disable_dropout: bool = True,
76+
seed: int = None,
7677
):
7778
"""
7879
# Parameters
@@ -108,6 +109,11 @@ def ensure_model_can_train_save_and_load(
108109
If True we will set all dropout to 0 before checking gradients. (Otherwise, with small
109110
datasets, you may get zero gradients because of unlucky dropout.)
110111
"""
112+
if seed is not None:
113+
random.seed(seed)
114+
numpy.random.seed(seed)
115+
torch.manual_seed(seed)
116+
111117
save_dir = self.TEST_DIR / "save_and_load_test"
112118
archive_file = save_dir / "model.tar.gz"
113119
model = train_model_from_file(param_file, save_dir, overrides=overrides)

allennlp/common/util.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Tuple,
2828
TypeVar,
2929
Union,
30+
Sequence,
3031
)
3132

3233
import numpy
@@ -143,7 +144,7 @@ def lazy_groups_of(iterable: Iterable[A], group_size: int) -> Iterator[List[A]]:
143144

144145

145146
def pad_sequence_to_length(
146-
sequence: List,
147+
sequence: Sequence,
147148
desired_length: int,
148149
default_value: Callable[[], Any] = lambda: 0,
149150
padding_on_right: bool = True,
@@ -174,6 +175,7 @@ def pad_sequence_to_length(
174175
175176
padded_sequence : `List`
176177
"""
178+
sequence = list(sequence)
177179
# Truncates the sequence to the desired length.
178180
if padding_on_right:
179181
padded_sequence = sequence[:desired_length]
@@ -342,8 +344,8 @@ def import_module_and_submodules(package_name: str) -> None:
342344
# Import at top level
343345
try:
344346
module = importlib.import_module(package_name)
345-
except ModuleNotFoundError as err:
346-
if err.name in ("detectron2", "torchvision"):
347+
except ImportError as err:
348+
if err.name in {"detectron2", "torchvision"}:
347349
logger.warning(
348350
"vision module '%s' is unavailable since '%s' is not installed",
349351
package_name,
@@ -651,6 +653,21 @@ def format_size(size: int) -> str:
651653
return f"{size}B"
652654

653655

656+
def nan_safe_tensor_divide(numerator, denominator):
657+
"""Performs division and handles divide-by-zero.
658+
659+
On zero-division, sets the corresponding result elements to zero.
660+
"""
661+
result = numerator / denominator
662+
mask = denominator == 0.0
663+
if not mask.any():
664+
return result
665+
666+
# remove nan
667+
result[mask] = 0.0
668+
return result
669+
670+
654671
def shuffle_iterable(i: Iterable[T], pool_size: int = 1024) -> Iterable[T]:
655672
import random
656673

allennlp/data/data_loaders/multi_process_data_loader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class MultiProcessDataLoader(DataLoader):
8181
max_instances_in_memory: `int`, optional (default = `None`)
8282
If not specified, all instances will be read and cached in memory for the duration
8383
of the data loader's life. This is generally ideal when your data can fit in memory
84-
during training. However, when you're datasets are too big, using this option
84+
during training. However, when your datasets are too big, using this option
8585
will turn on lazy loading, where only `max_instances_in_memory` instances are processed
8686
at a time.
8787

allennlp/data/dataset_readers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from allennlp.data.dataset_readers.text_classification_json import TextClassificationJsonReader
2121

2222
try:
23-
from allennlp.data.dataset_readers.nlvr2 import Nlvr2Reader
23+
from allennlp.data.dataset_readers.vqav2 import VQAv2Reader
2424
except ModuleNotFoundError as err:
2525
if err.name not in ("detectron2", "torchvision"):
2626
raise

allennlp/data/dataset_readers/dataset_reader.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22
import itertools
33
from os import PathLike
4-
from typing import Iterable, Iterator, Optional, Union, TypeVar
4+
from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List
55
import logging
66
import warnings
77

@@ -58,6 +58,9 @@ class DistributedInfo:
5858

5959
_T = TypeVar("_T")
6060

61+
PathOrStr = Union[PathLike, str]
62+
DatasetReaderInput = Union[PathOrStr, List[PathOrStr], Dict[str, PathOrStr]]
63+
6164

6265
class DatasetReader(Registrable):
6366
"""
@@ -178,14 +181,19 @@ def __init__(
178181
if util.is_distributed():
179182
self._distributed_info = DistributedInfo(dist.get_world_size(), dist.get_rank())
180183

181-
def read(self, file_path: Union[PathLike, str]) -> Iterator[Instance]:
184+
def read(self, file_path: DatasetReaderInput) -> Iterator[Instance]:
182185
"""
183186
Returns an iterator of instances that can be read from the file path.
184187
"""
185188
if not isinstance(file_path, str):
186-
file_path = str(file_path)
187-
188-
for instance in self._multi_worker_islice(self._read(file_path)):
189+
if isinstance(file_path, list):
190+
file_path = [str(f) for f in file_path]
191+
elif isinstance(file_path, dict):
192+
file_path = {k: str(v) for k, v in file_path.items()}
193+
else:
194+
file_path = str(file_path)
195+
196+
for instance in self._multi_worker_islice(self._read(file_path)): # type: ignore
189197
if self._worker_info is None:
190198
# If not running in a subprocess, it's safe to apply the token_indexers right away.
191199
self.apply_token_indexers(instance)

0 commit comments

Comments
 (0)