Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit c2ffb10

Browse files
leo-liuzy and epwalsh authored
Add influence functions to interpret module (#4988)
* creating a new functionality to fields and instances to support outputting instances to json files * creating tests for the new functionality * fixing docs * Delete __init__.py * Delete influence_interpreter.py * Delete use_if.py * Delete simple_influence_test.py * fixing docs * finishing up SimpleInfluence * passing lint * passing format * making small progress in coding * Delete fast_influence.py Submit to the wrong branch * Delete faiss_utils.py wrong branch * Delete gpt2_bug.py not sure why it's included * Delete text_class.py not sure why it's included * adding test file * adding testing files * deleted unwanted files * deleted unwanted files and rearrange test files * small bug * adjust function call to save instance in json * Update allennlp/interpret/influence_interpreters/influence_interpreter.py Co-authored-by: Evan Pete Walsh <[email protected]> * Update allennlp/interpret/influence_interpreters/influence_interpreter.py Co-authored-by: Evan Pete Walsh <[email protected]> * Update allennlp/interpret/influence_interpreters/influence_interpreter.py Co-authored-by: Evan Pete Walsh <[email protected]> * move some documentation of parameters to base class * delete one comment * delete one deprecated abstract method * changing interface * formatting * formatting err * passing mypy * passing mypy * passing mypy * passing mypy * passing integration test * passing integration test * adding a new option to the do-all function * modifying the callable function to the interface * update API, fixes * doc fixes * add `from_path` and `from_archive` methods * fix docs, improve logging * add test * address @matt-gardner's comments * fixes to documentation * update docs Co-authored-by: Evan Pete Walsh <[email protected]> Co-authored-by: Evan Pete Walsh <[email protected]>
1 parent 0c7d60b commit c2ffb10

File tree

12 files changed

+814
-12
lines changed

12 files changed

+814
-12
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12+
- Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`.
13+
- Added a `quiet` parameter to the `MultiProcessDataLoader` that disables `Tqdm` progress bars.
1214
- The test for distributed metrics now takes a parameter specifying how often you want to run it.
1315

1416

allennlp/data/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
TensorDict,
44
allennlp_collate,
55
)
6-
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
6+
from allennlp.data.dataset_readers.dataset_reader import DatasetReader, DatasetReaderInput
77
from allennlp.data.fields.field import DataArray, Field
88
from allennlp.data.fields.text_field import TextFieldTensors
99
from allennlp.data.instance import Instance

allennlp/data/data_loaders/multiprocess_data_loader.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from multiprocessing.process import BaseProcess
44
import random
55
import traceback
6-
from typing import List, Iterator, Optional, Iterable, Union
6+
from typing import List, Iterator, Optional, Iterable, Union, TypeVar
77

88
from overrides import overrides
99
import torch
@@ -23,6 +23,9 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26+
_T = TypeVar("_T")
27+
28+
2629
@DataLoader.register("multiprocess")
2730
class MultiProcessDataLoader(DataLoader):
2831
"""
@@ -118,6 +121,9 @@ class MultiProcessDataLoader(DataLoader):
118121
will automatically call [`set_target_device()`](#set_target_device) before iterating
119122
over batches.
120123
124+
quiet : `bool`, optional (default = `False`)
125+
If `True`, tqdm progress bars will be disabled.
126+
121127
# Best practices
122128
123129
- **Large datasets**
@@ -200,6 +206,7 @@ def __init__(
200206
max_instances_in_memory: int = None,
201207
start_method: str = "fork",
202208
cuda_device: Optional[Union[int, str, torch.device]] = None,
209+
quiet: bool = False,
203210
) -> None:
204211
# Do some parameter validation.
205212
if num_workers is not None and num_workers < 0:
@@ -240,6 +247,7 @@ def __init__(
240247
self.collate_fn = allennlp_collate
241248
self.max_instances_in_memory = max_instances_in_memory
242249
self.start_method = start_method
250+
self.quiet = quiet
243251
self.cuda_device: Optional[torch.device] = None
244252
if cuda_device is not None:
245253
if not isinstance(cuda_device, torch.device):
@@ -346,7 +354,7 @@ def iter_instances(self) -> Iterator[Instance]:
346354

347355
if self.num_workers <= 0:
348356
# Just read all instances in main process.
349-
for instance in Tqdm.tqdm(
357+
for instance in self._maybe_tqdm(
350358
self.reader.read(self.data_path), desc="loading instances"
351359
):
352360
self.reader.apply_token_indexers(instance)
@@ -365,7 +373,7 @@ def iter_instances(self) -> Iterator[Instance]:
365373
workers = self._start_instance_workers(queue, ctx)
366374

367375
try:
368-
for instance in Tqdm.tqdm(
376+
for instance in self._maybe_tqdm(
369377
self._gather_instances(queue), desc="loading instances"
370378
):
371379
if self.max_instances_in_memory is None:
@@ -569,6 +577,11 @@ def _instances_to_batches(
569577
break
570578
yield tensorize(batch)
571579

580+
def _maybe_tqdm(self, iterator: Iterable[_T], **tqdm_kwargs) -> Iterable[_T]:
581+
if self.quiet:
582+
return iterator
583+
return Tqdm.tqdm(iterator, **tqdm_kwargs)
584+
572585

573586
class WorkerError(Exception):
574587
"""

allennlp/data/data_loaders/simple_data_loader.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77

88
from allennlp.common.util import lazy_groups_of
9+
from allennlp.common.tqdm import Tqdm
910
from allennlp.data.data_loaders.data_loader import DataLoader, allennlp_collate, TensorDict
1011
from allennlp.data.dataset_readers import DatasetReader
1112
from allennlp.data.instance import Instance
@@ -37,6 +38,8 @@ def __init__(
3738
self._batch_generator: Optional[Iterator[TensorDict]] = None
3839

3940
def __len__(self) -> int:
41+
if self.batches_per_epoch is not None:
42+
return self.batches_per_epoch
4043
return math.ceil(len(self.instances) / self.batch_size)
4144

4245
@overrides
@@ -87,6 +90,10 @@ def from_dataset_reader(
8790
batch_size: int,
8891
shuffle: bool = False,
8992
batches_per_epoch: Optional[int] = None,
93+
quiet: bool = False,
9094
) -> "SimpleDataLoader":
91-
instances = list(reader.read(data_path))
95+
instance_iter = reader.read(data_path)
96+
if not quiet:
97+
instance_iter = Tqdm.tqdm(instance_iter, desc="loading instances")
98+
instances = list(instance_iter)
9299
return cls(instances, batch_size, shuffle=shuffle, batches_per_epoch=batches_per_epoch)

allennlp/interpret/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from allennlp.interpret.attackers.attacker import Attacker
22
from allennlp.interpret.saliency_interpreters.saliency_interpreter import SaliencyInterpreter
3+
from allennlp.interpret.influence_interpreters.influence_interpreter import InfluenceInterpreter
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from allennlp.interpret.influence_interpreters.influence_interpreter import InfluenceInterpreter
2+
from allennlp.interpret.influence_interpreters.simple_influence import SimpleInfluence

0 commit comments

Comments
 (0)