class TextClassificationJsonReader(DatasetReader):
    """
    Reads tokens and their labels from a labeled text classification dataset.
-    Expects a "text" field and a "label" field in JSON format.

    The output of `read` is a list of `Instance` s with the fields:
        tokens : `TextField` and
@@ -44,6 +43,10 @@ class TextClassificationJsonReader(DatasetReader):
    skip_label_indexing : `bool`, optional (default = `False`)
        Whether or not to skip label indexing. You might want to skip label indexing if your
        labels are numbers, so the dataset reader doesn't re-number them starting from 0.
+    text_key : `str`, optional (default = `"text"`)
+        The key name of the source field in the JSON data file.
+    label_key : `str`, optional (default = `"label"`)
+        The key name of the target field in the JSON data file.
    """

    def __init__(
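For illustration only, a minimal sketch of the kind of JSON-lines record the two new parameters target; the key names "review" and "sentiment" are hypothetical, not library defaults:

    import json

    # One record from a hypothetical JSON-lines file that does not use the
    # default "text"/"label" keys.
    line = '{"review": "A quiet, moving film.", "sentiment": "positive"}'
    items = json.loads(line)

    text = items["review"]          # what the reader would select via text_key="review"
    label = items.get("sentiment")  # what the reader would select via label_key="sentiment"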
@@ -53,6 +56,8 @@ def __init__(
        segment_sentences: bool = False,
        max_sequence_length: int = None,
        skip_label_indexing: bool = False,
+        text_key: str = "text",
+        label_key: str = "label",
        **kwargs,
    ) -> None:
        super().__init__(
@@ -63,6 +68,8 @@ def __init__(
        self._max_sequence_length = max_sequence_length
        self._skip_label_indexing = skip_label_indexing
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
+        self._text_key = text_key
+        self._label_key = label_key
        if self._segment_sentences:
            self._sentence_segmenter = SpacySentenceSplitter()

@@ -73,8 +80,8 @@ def _read(self, file_path):
            if not line:
                continue
            items = json.loads(line)
-            text = items["text"]
-            label = items.get("label")
+            text = items[self._text_key]
+            label = items.get(self._label_key)
            if label is not None:
                if self._skip_label_indexing:
                    try:
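A hedged usage sketch of the new arguments, assuming the reader is importable as allennlp.data.dataset_readers.TextClassificationJsonReader; the file name "reviews.jsonl" and its "review"/"sentiment" keys are hypothetical:

    from allennlp.data.dataset_readers import TextClassificationJsonReader

    # Point the reader at non-default JSON keys instead of "text" and "label".
    reader = TextClassificationJsonReader(text_key="review", label_key="sentiment")

    # read() still yields Instances with "tokens" and "label" fields as before;
    # only the lookup keys used on each JSON record change.
    instances = list(reader.read("reviews.jsonl"))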