Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit c5c9edf

Browse files
RyujiTamakiepwalsh
andauthored
Add text_key and label_key to TextClassificationJsonReader (#5005)
* Add text_key and label_key to TextClassificationJsonReader * Update CHANGELOG.md * Remove unnecessary test * Apply suggestions from code review Co-authored-by: Evan Pete Walsh <[email protected]> Co-authored-by: Evan Pete Walsh <[email protected]>
1 parent a02f67d commit c5c9edf

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
- Added a way to specify extra parameters to the predictor in an `allennlp predict` call.
1919
- Added a way to initialize a `Vocabulary` from transformers models.
2020
- Added an example for fields of type `ListField[TextField]` to `apply_token_indexers` API docs.
21+
- Added `text_key` and `label_key` parameters to `TextClassificationJsonReader` class.
2122

2223
### Fixed
2324

allennlp/data/dataset_readers/text_classification_json.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
class TextClassificationJsonReader(DatasetReader):
1818
"""
1919
Reads tokens and their labels from a labeled text classification dataset.
20-
Expects a "text" field and a "label" field in JSON format.
2120
2221
The output of `read` is a list of `Instance` s with the fields:
2322
tokens : `TextField` and
@@ -44,6 +43,10 @@ class TextClassificationJsonReader(DatasetReader):
4443
skip_label_indexing : `bool`, optional (default = `False`)
4544
Whether or not to skip label indexing. You might want to skip label indexing if your
4645
labels are numbers, so the dataset reader doesn't re-number them starting from 0.
46+
text_key: `str`, optional (default=`"text"`)
47+
The key name of the source field in the JSON data file.
48+
label_key: `str`, optional (default=`"label"`)
49+
The key name of the target field in the JSON data file.
4750
"""
4851

4952
def __init__(
@@ -53,6 +56,8 @@ def __init__(
5356
segment_sentences: bool = False,
5457
max_sequence_length: int = None,
5558
skip_label_indexing: bool = False,
59+
text_key: str = "text",
60+
label_key: str = "label",
5661
**kwargs,
5762
) -> None:
5863
super().__init__(
@@ -63,6 +68,8 @@ def __init__(
6368
self._max_sequence_length = max_sequence_length
6469
self._skip_label_indexing = skip_label_indexing
6570
self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
71+
self._text_key = text_key
72+
self._label_key = label_key
6673
if self._segment_sentences:
6774
self._sentence_segmenter = SpacySentenceSplitter()
6875

@@ -73,8 +80,8 @@ def _read(self, file_path):
7380
if not line:
7481
continue
7582
items = json.loads(line)
76-
text = items["text"]
77-
label = items.get("label")
83+
text = items[self._text_key]
84+
label = items.get(self._label_key)
7885
if label is not None:
7986
if self._skip_label_indexing:
8087
try:

0 commit comments

Comments
 (0)