1
- from typing import Dict , Mapping , Iterable , Union
1
+ from typing import Dict , Mapping , Iterable , Union , Optional
2
2
import json
3
3
4
+ from overrides import overrides
5
+
4
6
from allennlp .common .checks import ConfigurationError
5
- from allennlp .data .dataset_readers .dataset_reader import DatasetReader , PathOrStr
7
+ from allennlp .data .dataset_readers .dataset_reader import (
8
+ DatasetReader ,
9
+ PathOrStr ,
10
+ WorkerInfo ,
11
+ DistributedInfo ,
12
+ )
6
13
from allennlp .data .fields import MetadataField
7
14
from allennlp .data .instance import Instance
8
15
@@ -52,6 +59,18 @@ def __init__(
52
59
raise ConfigurationError (f"invalid scheme: { scheme } " )
53
60
self ._scheme = scheme
54
61
62
@overrides
def _set_worker_info(self, info: Optional[WorkerInfo]) -> None:
    """
    Forward `info` to the parent implementation, then to each reader held in
    `self._readers`, so this reader and the readers it wraps receive the same
    worker info object.
    """
    super()._set_worker_info(info)
    # Same `info` object is handed to every wrapped reader, in mapping order.
    for wrapped_reader in self._readers.values():
        wrapped_reader._set_worker_info(info)
67
+
68
@overrides
def _set_distributed_info(self, info: Optional[DistributedInfo]) -> None:
    """
    Forward `info` to the parent implementation, then to each reader held in
    `self._readers`, so this reader and the readers it wraps receive the same
    distributed info object.
    """
    super()._set_distributed_info(info)
    # Same `info` object is handed to every wrapped reader, in mapping order.
    for wrapped_reader in self._readers.values():
        wrapped_reader._set_distributed_info(info)
73
+
55
74
def _read_round_robin (self , datasets : Mapping [str , Iterable [Instance ]]) -> Iterable [Instance ]:
56
75
remaining = set (datasets )
57
76
dataset_iterators = {key : iter (dataset ) for key , dataset in datasets .items ()}
@@ -72,6 +91,7 @@ def _read_all_at_once(self, datasets: Mapping[str, Iterable[Instance]]) -> Itera
72
91
instance .fields [self ._dataset_field_name ] = MetadataField (key )
73
92
yield instance
74
93
94
+ @overrides
75
95
def _read (self , file_path : Union [str , Dict [str , PathOrStr ]]) -> Iterable [Instance ]:
76
96
if isinstance (file_path , str ):
77
97
try :
@@ -97,6 +117,11 @@ def _read(self, file_path: Union[str, Dict[str, PathOrStr]]) -> Iterable[Instanc
97
117
else :
98
118
raise RuntimeError ("impossible to get here" )
99
119
100
- def text_to_instance (self ) -> Instance : # type: ignore
120
@overrides
def text_to_instance(self, dataset_key: str, *args, **kwargs) -> Instance:  # type: ignore
    """
    Build an instance by delegating to the reader registered under
    `dataset_key` in `self._readers`; all remaining positional and keyword
    arguments are passed through to that reader's `text_to_instance` unchanged.
    """
    chosen_reader = self._readers[dataset_key]
    return chosen_reader.text_to_instance(*args, **kwargs)  # type: ignore[call-arg]
101
123
102
- raise RuntimeError ("text_to_instance doesn't make sense here" )
124
@overrides
def apply_token_indexers(self, instance: Instance) -> None:
    """
    Look up which dataset `instance` belongs to (the `metadata` of its
    `self._dataset_field_name` field) and let the matching reader in
    `self._readers` apply its token indexers to the instance.
    """
    dataset_field = instance.fields[self._dataset_field_name]
    dataset_key = dataset_field.metadata  # type: ignore[attr-defined]
    self._readers[dataset_key].apply_token_indexers(instance)
0 commit comments