@@ -693,12 +693,32 @@ def with_format(
'format_kwargs': {},
'output_all_columns': False,
'type': None}
- >>> ds = ds.with_format(type='tensorflow', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
+ >>> ds = ds.with_format("torch")
>>> ds["train"].format
- {'columns': ['input_ids', 'token_type_ids', 'attention_mask', 'label'],
+ {'columns': ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
'format_kwargs': {},
'output_all_columns': False,
- 'type': 'tensorflow'}
+ 'type': 'torch'}
+ >>> ds["train"][0]
+ {'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
+ 'label': tensor(1),
+ 'input_ids': tensor([ 101, 18027, 16310, 16001, 1103, 9321, 178, 11604, 7235, 6617,
+         1742, 2165, 2820, 1206, 6588, 22572, 12937, 1811, 2153, 1105,
+         1147, 12890, 19587, 6463, 1105, 15026, 1482, 119, 102, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0]),
+ 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
```
"""
dataset = copy.deepcopy(self)
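Note on the updated example above: after `with_format("torch")`, every split of the `DatasetDict` returns PyTorch tensors (string columns such as the raw `text` are passed through unchanged), so a split can be fed directly to a `torch.utils.data.DataLoader`. A minimal sketch of that follow-on usage, assuming the same `rotten_tomatoes` dataset and `bert-base-cased` tokenizer as the docstring; the `padding="max_length"`, `remove_columns` and batch-size choices are illustrative and not part of this diff:

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Tokenize first so that tensor-friendly columns exist before formatting.
ds = load_dataset("rotten_tomatoes")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
ds = ds.map(
    lambda x: tokenizer(x["text"], truncation=True, padding="max_length"),
    batched=True,
)

# with_format("torch") is applied to every split of the DatasetDict.
ds = ds.with_format("torch")
print(type(ds["train"][0]["label"]))  # <class 'torch.Tensor'>

# Drop the raw string column; only the tensor columns are needed downstream.
train = ds["train"].remove_columns("text")
loader = DataLoader(train, batch_size=8)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # e.g. torch.Size([8, 512])
```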
@@ -1801,25 +1821,43 @@ def with_format(
) -> "IterableDatasetDict":
"""
Return a dataset with the specified format.
- This method only supports the "torch" format for now.
- The format is set to all the datasets of the dataset dictionary.
+ The 'pandas' format is currently not implemented.

Args:
- type (`str`, *optional*, defaults to `None`):
-     If set to "torch", the returned dataset
-     will be a subclass of `torch.utils.data.IterableDataset` to be used in a `DataLoader`.
+
+ type (`str`, *optional*):
+     Either output type selected in `[None, 'numpy', 'torch', 'tensorflow', 'arrow', 'jax']`.
+     `None` means it returns python objects (default).

Example:

```py
>>> from datasets import load_dataset
- >>> ds = load_dataset("rotten_tomatoes", streaming=True)
>>> from transformers import AutoTokenizer
- >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
- >>> def encode(example):
- ...     return tokenizer(examples["text"], truncation=True, padding="max_length")
- >>> ds = ds.map(encode, batched=True, remove_columns=["text"])
+ >>> ds = load_dataset("rotten_tomatoes", split="validation", streaming=True)
+ >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+ >>> ds = ds.map(lambda x: tokenizer(x['text'], truncation=True, padding=True), batched=True)
>>> ds = ds.with_format("torch")
+ >>> next(iter(ds))
+ {'text': 'compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children .',
+ 'label': tensor(1),
+ 'input_ids': tensor([ 101, 18027, 16310, 16001, 1103, 9321, 178, 11604, 7235, 6617,
+         1742, 2165, 2820, 1206, 6588, 22572, 12937, 1811, 2153, 1105,
+         1147, 12890, 19587, 6463, 1105, 15026, 1482, 119, 102, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0]),
+ 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+ 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}
```
"""
return IterableDatasetDict({k: dataset.with_format(type=type) for k, dataset in self.items()})
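As a follow-on to the streaming example above (not part of the diff): when `load_dataset(..., streaming=True)` is called without `split=`, it returns an `IterableDatasetDict`, and this method sets the format on every split, each of which can then feed a `torch.utils.data.DataLoader`. A minimal sketch under that assumption; `padding="max_length"`, `remove_columns` and the batch size are illustrative choices:

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Loading without `split=` yields an IterableDatasetDict (train/validation/test).
ds = load_dataset("rotten_tomatoes", streaming=True)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
ds = ds.map(
    lambda x: tokenizer(x["text"], truncation=True, padding="max_length"),
    batched=True,
    remove_columns=["text"],  # drop the raw text; keep label and the tokenizer outputs
)

# The "torch" format is applied to every split of the dictionary.
ds = ds.with_format("torch")

# Each formatted split is a torch-compatible IterableDataset.
loader = DataLoader(ds["validation"], batch_size=4)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # e.g. torch.Size([4, 512])
```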