Skip to content

Commit d305e43

Browse files
authored
with_format docstring (#7203)
1 parent cdb1d32 commit d305e43

File tree

3 files changed

+109
-20
lines changed

3 files changed

+109
-20
lines changed

src/datasets/arrow_dataset.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -2649,12 +2649,32 @@ def with_format(
26492649
'format_kwargs': {},
26502650
'output_all_columns': False,
26512651
'type': None}
2652-
>>> ds = ds.with_format(type='tensorflow', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
2652+
>>> ds = ds.with_format("torch")
26532653
>>> ds.format
2654-
{'columns': ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
2654+
{'columns': ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
26552655
'format_kwargs': {},
26562656
'output_all_columns': False,
2657-
'type': 'tensorflow'}
2657+
'type': 'torch'}
2658+
>>> ds[0]
2659+
{'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
2660+
'label': tensor(1),
2661+
'input_ids': tensor([ 101, 18027, 16310, 16001, 1103, 9321, 178, 11604, 7235, 6617,
2662+
1742, 2165, 2820, 1206, 6588, 22572, 12937, 1811, 2153, 1105,
2663+
1147, 12890, 19587, 6463, 1105, 15026, 1482, 119, 102, 0,
2664+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2665+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2666+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2667+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2668+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2669+
0, 0, 0, 0]),
2670+
'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2671+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2672+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2673+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
2674+
'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2675+
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2676+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2677+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
26582678
```
26592679
"""
26602680
dataset = copy.deepcopy(self)

src/datasets/dataset_dict.py

+51-13
Original file line numberDiff line numberDiff line change
@@ -693,12 +693,32 @@ def with_format(
693693
'format_kwargs': {},
694694
'output_all_columns': False,
695695
'type': None}
696-
>>> ds = ds.with_format(type='tensorflow', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
696+
>>> ds = ds.with_format("torch")
697697
>>> ds["train"].format
698-
{'columns': ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
698+
{'columns': ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
699699
'format_kwargs': {},
700700
'output_all_columns': False,
701-
'type': 'tensorflow'}
701+
'type': 'torch'}
702+
>>> ds["train"][0]
703+
{'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
704+
'label': tensor(1),
705+
'input_ids': tensor([ 101, 18027, 16310, 16001, 1103, 9321, 178, 11604, 7235, 6617,
706+
1742, 2165, 2820, 1206, 6588, 22572, 12937, 1811, 2153, 1105,
707+
1147, 12890, 19587, 6463, 1105, 15026, 1482, 119, 102, 0,
708+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
709+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
710+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
711+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
712+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
713+
0, 0, 0, 0]),
714+
'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
715+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
716+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
717+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
718+
'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
719+
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
720+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
721+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
702722
```
703723
"""
704724
dataset = copy.deepcopy(self)
@@ -1801,25 +1821,43 @@ def with_format(
18011821
) -> "IterableDatasetDict":
18021822
"""
18031823
Return a dataset with the specified format.
1804-
This method only supports the "torch" format for now.
1805-
The format is set to all the datasets of the dataset dictionary.
1824+
The 'pandas' format is currently not implemented.
18061825
18071826
Args:
1808-
type (`str`, *optional*, defaults to `None`):
1809-
If set to "torch", the returned dataset
1810-
will be a subclass of `torch.utils.data.IterableDataset` to be used in a `DataLoader`.
1827+
1828+
type (`str`, *optional*):
1829+
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`.
1830+
`None` means it returns python objects (default).
18111831
18121832
Example:
18131833
18141834
```py
18151835
>>> from datasets import load_dataset
1816-
>>> ds = load_dataset("rotten_tomatoes", streaming=True)
18171836
>>> from transformers import AutoTokenizer
1818-
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
1819-
>>> def encode(example):
1820-
... return tokenizer(examples["text"], truncation=True, padding="max_length")
1821-
>>> ds = ds.map(encode, batched=True, remove_columns=["text"])
1837+
>>> ds = load_dataset("rotten_tomatoes", split="validation", streaming=True)
1838+
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
1839+
>>> ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)
18221840
>>> ds = ds.with_format("torch")
1841+
>>> next(iter(ds))
1842+
{'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
1843+
'label': tensor(1),
1844+
'input_ids': tensor([ 101, 18027, 16310, 16001, 1103, 9321, 178, 11604, 7235, 6617,
1845+
1742, 2165, 2820, 1206, 6588, 22572, 12937, 1811, 2153, 1105,
1846+
1147, 12890, 19587, 6463, 1105, 15026, 1482, 119, 102, 0,
1847+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1848+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1849+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1850+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1851+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1852+
0, 0, 0, 0]),
1853+
'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1854+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1855+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1856+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
1857+
'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1858+
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1859+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1860+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
18231861
```
18241862
"""
18251863
return IterableDatasetDict({k: dataset.with_format(type=type) for k, dataset in self.items()})

src/datasets/iterable_dataset.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -2178,13 +2178,44 @@ def with_format(
21782178
) -> "IterableDataset":
21792179
"""
21802180
Return a dataset with the specified format.
2181-
Supported formats: "arrow", or None for regular python objects.
2182-
The other formats are currently not implemented.
2181+
The 'pandas' format is currently not implemented.
21832182
21842183
Args:
21852184
2186-
type (`str`, optional, default None): if set to "torch", the returned dataset
2187-
will be a subclass of torch.utils.data.IterableDataset to be used in a DataLoader
2185+
type (`str`, *optional*):
2186+
Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`.
2187+
`None` means it returns python objects (default).
2188+
2189+
Example:
2190+
2191+
```py
2192+
>>> from datasets import load_dataset
2193+
>>> from transformers import AutoTokenizer
2194+
>>> ds = load_dataset("rotten_tomatoes", split="validation", streaming=True)
2195+
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
2196+
>>> ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)
2197+
>>> ds = ds.with_format("torch")
2198+
>>> next(iter(ds))
2199+
{'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
2200+
'label': tensor(1),
2201+
'input_ids': tensor([ 101, 18027, 16310, 16001, 1103, 9321, 178, 11604, 7235, 6617,
2202+
1742, 2165, 2820, 1206, 6588, 22572, 12937, 1811, 2153, 1105,
2203+
1147, 12890, 19587, 6463, 1105, 15026, 1482, 119, 102, 0,
2204+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2205+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2206+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2207+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2208+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2209+
0, 0, 0, 0]),
2210+
'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2211+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2212+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2213+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
2214+
'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2215+
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2216+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2217+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
2218+
```
21882219
"""
21892220
type = get_format_type_from_alias(type)
21902221
# TODO(QL): add format_kwargs

0 commit comments

Comments
 (0)