Commit 161f99d

TyTodd and lhoestq authored
Torchcodec decoding (#7616)
* passes all but 1 test case
* Migrated Audio feature to use torchcodec as a backend. Fixed how formatter handles torchcodec objects. Fixed test scripts to work with new Audio backend
* fixed audio and video features so they now pass the test_dataset_with_audio_feature_map_is_decoded test case. Implemented casting for VideoDecoder and AudioDecoder types
* added load dataset test case to test_video.py
* Modified documentation to document new torchcodec implementation of Video and Audio features. Fixed the rest of the test files to be compatible with new Audio and Video features.
* code formatting for torchcodec changes
* Update src/datasets/features/audio.py
  Co-authored-by: Quentin Lhoest <[email protected]>
* added backwards compatibility support and _hf_encoded for Audio feature.
* move AudioDecoder to its own file
* naming
* docs
* style
* update tests
* no torchcodec for windows
* further cleaning
* fix
* install ffmpeg in ci
* fix ffmpeg installation
* fix mono backward compatibility
* fix ffmpeg
* again
* fix mono backward compat
* fix tests
* fix tests
* again

---------

Co-authored-by: Quentin Lhoest <[email protected]>
Co-authored-by: Quentin Lhoest <[email protected]>
1 parent b7819cd commit 161f99d

34 files changed, +879 -602 lines changed

.github/workflows/ci.yml

Lines changed: 15 additions & 3 deletions
@@ -44,15 +44,17 @@ jobs:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 0
+      - name: Setup FFmpeg
+        if: ${{ matrix.os == 'ubuntu-latest' }}
+        run: |
+          sudo apt update
+          sudo apt install -y ffmpeg
       - name: Set up Python 3.9
         uses: actions/setup-python@v5
         with:
           python-version: "3.9"
       - name: Upgrade pip
         run: python -m pip install --upgrade pip
-      - name: Pin setuptools-scm
-        if: ${{ matrix.os == 'ubuntu-latest' }}
-        run: echo "installing pinned version of setuptools-scm to fix seqeval installation on 3.7" && pip install "setuptools-scm==6.4.2"
       - name: Install uv
         run: pip install --upgrade uv
       - name: Install dependencies
@@ -80,6 +82,11 @@ jobs:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 0
+      - name: Setup FFmpeg
+        if: ${{ matrix.os == 'ubuntu-latest' }}
+        run: |
+          sudo apt update
+          sudo apt install -y ffmpeg
       - name: Set up Python 3.11
         uses: actions/setup-python@v5
         with:
@@ -107,6 +114,11 @@ jobs:
       - uses: actions/checkout@v4
         with:
           fetch-depth: 0
+      - name: Setup FFmpeg
+        if: ${{ matrix.os == 'ubuntu-latest' }}
+        run: |
+          sudo apt update
+          sudo apt install -y ffmpeg
       - name: Set up Python 3.11
         uses: actions/setup-python@v5
         with:

docs/source/about_dataset_features.mdx

Lines changed: 8 additions & 11 deletions
@@ -53,7 +53,7 @@ See the [flatten](./process#flatten) section to learn how you can extract the ne
 
 </Tip>
 
-The array feature type is useful for creating arrays of various sizes. You can create arrays with two dimensions using [`Array2D`], and even arrays with five dimensions using [`Array5D`].
+The array feature type is useful for creating arrays of various sizes. You can create arrays with two dimensions using [`Array2D`], and even arrays with five dimensions using [`Array5D`].
 
 ```py
 >>> features = Features({'a': Array2D(shape=(1, 3), dtype='int32')})
@@ -69,9 +69,9 @@ The array type also allows the first dimension of the array to be dynamic. This
 
 Audio datasets have a column with type [`Audio`], which contains three important fields:
 
-* `array`: the decoded audio data represented as a 1-dimensional array.
-* `path`: the path to the downloaded audio file.
-* `sampling_rate`: the sampling rate of the audio data.
+- `array`: the decoded audio data represented as a 1-dimensional array.
+- `path`: the path to the downloaded audio file.
+- `sampling_rate`: the sampling rate of the audio data.
 
 When you load an audio dataset and call the audio column, the [`Audio`] feature automatically decodes and resamples the audio file:
 
@@ -80,10 +80,7 @@ When you load an audio dataset and call the audio column, the [`Audio`] feature
 
 >>> dataset = load_dataset("PolyAI/minds14", "en-US", split="train")
 >>> dataset[0]["audio"]
-{'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
-0. , 0. ], dtype=float32),
-'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
-'sampling_rate': 8000}
+<datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
 ```
 
 <Tip warning={true}>
@@ -92,7 +89,7 @@ Index into an audio dataset using the row index first and then the `audio` colum
 
 </Tip>
 
-With `decode=False`, the [`Audio`] type simply gives you the path or the bytes of the audio file, without decoding it into an `array`,
+With `decode=False`, the [`Audio`] type simply gives you the path or the bytes of the audio file, without decoding it into an torchcodec `AudioDecoder` object,
 
 ```py
 >>> dataset = load_dataset("PolyAI/minds14", "en-US", split="train").cast_column("audio", Audio(decode=False))
@@ -126,7 +123,7 @@ Index into an image dataset using the row index first and then the `image` colum
 
 </Tip>
 
-With `decode=False`, the [`Image`] type simply gives you the path or the bytes of the image file, without decoding it into an `PIL.Image`,
+With `decode=False`, the [`Image`] type simply gives you the path or the bytes of the image file, without decoding it into an `PIL.Image`,
 
 ```py
 >>> dataset = load_dataset("AI-Lab-Makerere/beans", split="train").cast_column("image", Image(decode=False))
@@ -146,4 +143,4 @@ You can also define a dataset of images from numpy arrays:
 And in this case the numpy arrays are encoded into PNG (or TIFF if the pixels values precision is important).
 
 For multi-channels arrays like RGB or RGBA, only uint8 is supported. If you use a larger precision, you get a warning and the array is downcasted to uint8.
-For gray-scale images you can use the integer or float precision you want as long as it is compatible with `Pillow`. A warning is shown if your image integer or float precision is too high, and in this case the array is downcated: an int64 array is downcasted to int32, and a float64 array is downcasted to float32.
+For gray-scale images you can use the integer or float precision you want as long as it is compatible with `Pillow`. A warning is shown if your image integer or float precision is too high, and in this case the array is downcated: an int64 array is downcasted to int32, and a float64 array is downcasted to float32.

docs/source/audio_dataset.mdx

Lines changed: 3 additions & 9 deletions
@@ -10,10 +10,9 @@ dataset = load_dataset("<username>/my_dataset")
 
 There are several methods for creating and sharing an audio dataset:
 
-* Create an audio dataset from local files in python with [`Dataset.push_to_hub`]. This is an easy way that requires only a few steps in python.
-
-* Create an audio dataset repository with the `AudioFolder` builder. This is a no-code solution for quickly creating an audio dataset with several thousand audio files.
+- Create an audio dataset from local files in python with [`Dataset.push_to_hub`]. This is an easy way that requires only a few steps in python.
 
+- Create an audio dataset repository with the `AudioFolder` builder. This is a no-code solution for quickly creating an audio dataset with several thousand audio files.
 
 <Tip>
 
@@ -28,10 +27,7 @@ You can load your own dataset using the paths to your audio files. Use the [`~Da
 ```py
 >>> audio_dataset = Dataset.from_dict({"audio": ["path/to/audio_1", "path/to/audio_2", ..., "path/to/audio_n"]}).cast_column("audio", Audio())
 >>> audio_dataset[0]["audio"]
-{'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
-0. , 0. ], dtype=float32),
-'path': 'path/to/audio_1',
-'sampling_rate': 16000}
+<datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
 ```
 
 Then upload the dataset to the Hugging Face Hub using [`Dataset.push_to_hub`]:
@@ -51,7 +47,6 @@ my_dataset/
 
 ## AudioFolder
 
-
 The `AudioFolder` is a dataset builder designed to quickly load an audio dataset with several thousand audio files without requiring you to write any code.
 
 <Tip>
@@ -101,7 +96,6 @@ If all audio files are contained in a single directory or if they are not on the
 
 </Tip>
 
-
 If there is additional information you'd like to include about your dataset, like text captions or bounding boxes, add it as a `metadata.csv` file in your folder. This lets you quickly create datasets for different computer vision tasks like text captioning or object detection. You can also use a JSONL file `metadata.jsonl` or a Parquet file `metadata.parquet`.
 
 ```

docs/source/audio_load.mdx

Lines changed: 2 additions & 6 deletions
@@ -8,18 +8,14 @@ Audio decoding is based on the [`soundfile`](https://github.com/bastibe/python-s
 To work with audio datasets, you need to have the `audio` dependencies installed.
 Check out the [installation](./installation#audio) guide to learn how to install it.
 
-
 ## Local files
 
 You can load your own dataset using the paths to your audio files. Use the [`~Dataset.cast_column`] function to take a column of audio file paths, and cast it to the [`Audio`] feature:
 
 ```py
 >>> audio_dataset = Dataset.from_dict({"audio": ["path/to/audio_1", "path/to/audio_2", ..., "path/to/audio_n"]}).cast_column("audio", Audio())
 >>> audio_dataset[0]["audio"]
-{'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
-0. , 0. ], dtype=float32),
-'path': 'path/to/audio_1',
-'sampling_rate': 16000}
+<datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
 ```
 
 ## AudioFolder
@@ -99,7 +95,7 @@ For a guide on how to load any type of dataset, take a look at the <a class="und
 
 ## Audio decoding
 
-By default, audio files are decoded sequentially as NumPy arrays when you iterate on a dataset.
+By default, audio files are decoded sequentially as torchcodec [`AudioDecoder`](https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder) objects when you iterate on a dataset.
 However it is possible to speed up the dataset significantly using multithreaded decoding:
 
 ```python

docs/source/audio_process.mdx

Lines changed: 30 additions & 21 deletions
@@ -7,7 +7,6 @@ This guide shows specific methods for processing audio datasets. Learn how to:
 
 For a guide on how to process any type of dataset, take a look at the <a class="underline decoration-sky-400 decoration-2 font-semibold" href="./process">general process guide</a>.
 
-
 ## Cast
 
 The [`~Dataset.cast_column`] function is used to cast a column to another feature to be decoded. When you use this function with the [`Audio`] feature, you can resample the sampling rate:
@@ -22,16 +21,26 @@ The [`~Dataset.cast_column`] function is used to cast a column to another featur
 Audio files are decoded and resampled on-the-fly, so the next time you access an example, the audio file is resampled to 16kHz:
 
 ```py
->>> dataset[0]["audio"]
-{'array': array([ 2.3443763e-05, 2.1729663e-04, 2.2145823e-04, ...,
-3.8356509e-05, -7.3497440e-06, -2.1754686e-05], dtype=float32),
-'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
-'sampling_rate': 16000}
+>>> audio = dataset[0]["audio"]
+<datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
+>>> audio = audio_dataset[0]["audio"]
+>>> samples = audio.get_all_samples()
+>>> samples.data
+tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 2.3447e-06,
+-1.9127e-04, -5.3330e-05]]
+>>> samples.sample_rate
+16000
 ```
 
 <div class="flex justify-center">
-<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/resample.gif"/>
-<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/resample-dark.gif"/>
+<img
+class="block dark:hidden"
+src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/resample.gif"
+/>
+<img
+class="hidden dark:block"
+src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/resample-dark.gif"
+/>
 </div>
 
 ## Map
@@ -40,30 +49,30 @@ The [`~Dataset.map`] function helps preprocess your entire dataset at once. Depe
 
 - For pretrained speech recognition models, load a feature extractor and tokenizer and combine them in a `processor`:
 
-```py
->>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoProcessor
+```py
+>>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoProcessor
 
->>> model_checkpoint = "facebook/wav2vec2-large-xlsr-53"
-# after defining a vocab.json file you can instantiate a tokenizer object:
->>> tokenizer = AutoTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
->>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
->>> processor = AutoProcessor.from_pretrained(feature_extractor=feature_extractor, tokenizer=tokenizer)
-```
+>>> model_checkpoint = "facebook/wav2vec2-large-xlsr-53"
+# after defining a vocab.json file you can instantiate a tokenizer object:
+>>> tokenizer = AutoTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
+>>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)
+>>> processor = AutoProcessor.from_pretrained(feature_extractor=feature_extractor, tokenizer=tokenizer)
+```
 
 - For fine-tuned speech recognition models, you only need to load a `processor`:
 
-```py
->>> from transformers import AutoProcessor
+```py
+>>> from transformers import AutoProcessor
 
->>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
-```
+>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
+```
 
 When you use [`~Dataset.map`] with your preprocessing function, include the `audio` column to ensure you're actually resampling the audio data:
 
 ```py
 >>> def prepare_dataset(batch):
 ...     audio = batch["audio"]
-...     batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
+...     batch["input_values"] = processor(audio.get_all_samples().data, sampling_rate=audio["sampling_rate"]).input_values[0]
 ...     batch["input_length"] = len(batch["input_values"])
 ...     with processor.as_target_processor():
 ...         batch["labels"] = processor(batch["sentence"]).input_ids
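
The hunks above replace the dict-style audio example (`array`, `path`, `sampling_rate`) with a decoder object read through `get_all_samples()`, `.data` and `.sample_rate`. Below is a minimal sketch of a helper that accepts either shape. `extract_audio`, `FakeSamples` and `FakeDecoder` are hypothetical names introduced for illustration and are not part of `datasets` or torchcodec. The stand-ins only mimic the attribute names shown in the doctest above, and they return a nested list where the real decoder returns a tensor.

```python
from dataclasses import dataclass
from typing import Any, Tuple


def extract_audio(audio: Any) -> Tuple[Any, int]:
    """Return (samples, sampling_rate) from either audio representation."""
    if isinstance(audio, dict):
        # Legacy dict format shown on the removed lines of this diff.
        return audio["array"], audio["sampling_rate"]
    # Decoder-style object: attribute names follow the doctest in this diff.
    samples = audio.get_all_samples()
    return samples.data, samples.sample_rate


# Hypothetical stand-ins so the sketch runs without torchcodec installed.
@dataclass
class FakeSamples:
    data: list
    sample_rate: int


class FakeDecoder:
    def get_all_samples(self) -> FakeSamples:
        return FakeSamples(data=[[0.0, 0.1, -0.1]], sample_rate=16000)


legacy = {"array": [0.0, 0.1, -0.1], "path": "path/to/audio_1", "sampling_rate": 16000}
legacy_samples, legacy_rate = extract_audio(legacy)
decoder_samples, decoder_rate = extract_audio(FakeDecoder())
print(legacy_rate, decoder_rate)
```

In this sketch the decoder path returns data with a leading channel dimension, matching the `tensor([[...]])` output above, so downstream code may need to select or average channels before passing samples to a processor.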

docs/source/create_dataset.mdx

Lines changed: 6 additions & 6 deletions
@@ -4,8 +4,8 @@ Sometimes, you may need to create a dataset if you're working with your own data
 
 In this tutorial, you'll learn how to use 🤗 Datasets low-code methods for creating all types of datasets:
 
-* Folder-based builders for quickly creating an image or audio dataset
-* `from_` methods for creating datasets from local files
+- Folder-based builders for quickly creating an image or audio dataset
+- `from_` methods for creating datasets from local files
 
 ## File-based builders
 
@@ -24,10 +24,10 @@ To get the list of supported formats and code examples, follow this guide [here]
 
 There are two folder-based builders, [`ImageFolder`] and [`AudioFolder`]. These are low-code methods for quickly creating an image or speech and audio dataset with several thousand examples. They are great for rapidly prototyping computer vision and speech models before scaling to a larger dataset. Folder-based builders takes your data and automatically generates the dataset's features, splits, and labels. Under the hood:
 
-* [`ImageFolder`] uses the [`~datasets.Image`] feature to decode an image file. Many image extension formats are supported, such as jpg and png, but other formats are also supported. You can check the complete [list](https://github.com/huggingface/datasets/blob/b5672a956d5de864e6f5550e493527d962d6ae55/src/datasets/packaged_modules/imagefolder/imagefolder.py#L39) of supported image extensions.
-* [`AudioFolder`] uses the [`~datasets.Audio`] feature to decode an audio file. Audio extensions such as wav and mp3 are supported, and you can check the complete [list](https://github.com/huggingface/datasets/blob/b5672a956d5de864e6f5550e493527d962d6ae55/src/datasets/packaged_modules/audiofolder/audiofolder.py#L39) of supported audio extensions.
+- [`ImageFolder`] uses the [`~datasets.Image`] feature to decode an image file. Many image extension formats are supported, such as jpg and png, but other formats are also supported. You can check the complete [list](https://github.com/huggingface/datasets/blob/b5672a956d5de864e6f5550e493527d962d6ae55/src/datasets/packaged_modules/imagefolder/imagefolder.py#L39) of supported image extensions.
+- [`AudioFolder`] uses the [`~datasets.Audio`] feature to decode an audio file. Extensions such as wav, mp3, and even mp4 are supported, and you can check the complete [list](https://ffmpeg.org/ffmpeg-formats.html) of supported audio extensions. Decoding is done via ffmpeg.
 
-The dataset splits are generated from the repository structure, and the label names are automatically inferred from the directory name.
+The dataset splits are generated from the repository structure, and the label names are automatically inferred from the directory name.
 
 For example, if your image dataset (it is the same for an audio dataset) is stored like this:
 
@@ -44,7 +44,7 @@ pokemon/test/water/wartortle.png
 Then this is how the folder-based builder generates an example:
 
 <div class="flex justify-center">
-<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/folder-based-builder.png"/>
+<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/folder-based-builder.png" />
 </div>
 
 Create the image dataset by specifying `imagefolder` in [`load_dataset`]:
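
The `AudioFolder` change above states that decoding is done via ffmpeg, and the CI change installs it with `apt` on Ubuntu. A rough pre-flight check, assuming that an `ffmpeg` executable on `PATH` is a reasonable proxy for a usable installation, might look like this. That assumption is not confirmed by the diff: the decoder may depend on the FFmpeg shared libraries rather than the command-line tool. `find_ffmpeg` is a hypothetical helper name.

```python
import shutil
from typing import Optional


def find_ffmpeg() -> Optional[str]:
    """Return the path of an `ffmpeg` executable on PATH, or None if absent."""
    # Assumption: the CLI being present is only a proxy for a working install.
    return shutil.which("ffmpeg")


ffmpeg_path = find_ffmpeg()
if ffmpeg_path is None:
    # The CI steps in this commit run `sudo apt install -y ffmpeg` on Ubuntu.
    print("ffmpeg not found on PATH")
else:
    print(f"ffmpeg found at {ffmpeg_path}")
```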

docs/source/installation.md

Lines changed: 1 addition & 13 deletions
@@ -30,7 +30,7 @@ You should install 🤗 Datasets in a [virtual environment](https://docs.python.
 ```bash
 # Activate the virtual environment
 source .env/bin/activate
-
+
 # Deactivate the virtual environment
 source .env/bin/deactivate
 ```
@@ -65,18 +65,6 @@ To work with audio datasets, you need to install the [`Audio`] feature as an ext
 pip install datasets[audio]
 ```
 
-<Tip warning={true}>
-
-To decode mp3 files, you need to have at least version 1.1.0 of the `libsndfile` system library. Usually, it's bundled with the python [`soundfile`](https://github.com/bastibe/python-soundfile) package, which is installed as an extra audio dependency for 🤗 Datasets.
-For Linux, the required version of `libsndfile` is bundled with `soundfile` starting from version 0.12.0. You can run the following command to determine which version of `libsndfile` is being used by `soundfile`:
-
-```bash
-python -c "import soundfile; print(soundfile.__libsndfile_version__)"
-```
-
-</Tip>
-
-
 ## Vision
 
 To work with image datasets, you need to install the [`Image`] feature as an extra dependency:
