
Commit 76c4d8b

✨ Add PyTorch image classification example (#13134)

* ✨ add pytorch image classification example
* 🔥 remove utils.py
* 💄 fix flake8 style issues
* 🔥 remove unnecessary line
* ✨ limit dataset sizes
* 📌 update reqs
* 🎨 restructure - use datasets lib
* 🎨 import transforms directly
* 📝 add comments
* 💄 style
* 🔥 remove flag
* 📌 update requirement warning
* 📝 add vision README.md
* 📝 update README.md
* 📝 update README.md
* 🎨 add image-classification tag to model card
* 🚚 rename vision ➡️ image-classification
* 📝 update image-classification README.md
1 parent 9bd5d97 commit 76c4d8b

File tree

14 files changed: +529, -0 lines changed
Lines changed: 133 additions & 0 deletions
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Image classification examples

The following examples showcase how to fine-tune a `ViT` for image classification using PyTorch.

## Using datasets from 🤗 `datasets`

Here we show how to fine-tune a `ViT` on the [beans](https://huggingface.co/datasets/beans) dataset.

👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).

```bash
python run_image_classification.py \
    --dataset_name beans \
    --output_dir ./beans_outputs/ \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --push_to_hub \
    --push_to_hub_model_id vit-base-beans \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_strategy steps \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337
```
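
Once training finishes (or using the checkpoint linked above), you can try the model out directly. Below is a minimal inference sketch, not part of the example script; it assumes the `nateraw/vit-base-beans` checkpoint, and the image URL is a placeholder:

```python
# Minimal inference sketch (not part of run_image_classification.py).
# Assumes the fine-tuned checkpoint above; the image URL is a placeholder.
import requests
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

feature_extractor = ViTFeatureExtractor.from_pretrained("nateraw/vit-base-beans")
model = ViTForImageClassification.from_pretrained("nateraw/vit-base-beans")

image = Image.open(requests.get("https://example.com/leaf.jpg", stream=True).raw)

# Preprocess to pixel values, run the model, and map the top logit to a label.
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```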

Here we show how to fine-tune a `ViT` on the [cats_vs_dogs](https://huggingface.co/datasets/cats_vs_dogs) dataset.

👀 See the results here: [nateraw/vit-base-cats-vs-dogs](https://huggingface.co/nateraw/vit-base-cats-vs-dogs).

```bash
python run_image_classification.py \
    --dataset_name cats_vs_dogs \
    --output_dir ./cats_vs_dogs_outputs/ \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --push_to_hub \
    --push_to_hub_model_id vit-base-cats-vs-dogs \
    --fp16 True \
    --learning_rate 2e-4 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --logging_strategy steps \
    --logging_steps 10 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337
```

## Using your own data

To use your own dataset, the training script expects the following directory structure:

```bash
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
```
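
As a rough illustration (this snippet is hypothetical, not part of the script), image-folder loaders in this style treat each immediate subdirectory of the root as one class:

```python
# Hypothetical illustration of how an image-folder layout maps to labels:
# each immediate subdirectory name becomes one class.
from pathlib import Path

root = Path("root")  # e.g. the directory you pass via --train_dir
labels = sorted(d.name for d in root.iterdir() if d.is_dir())
label2id = {name: i for i, name in enumerate(labels)}
print(label2id)  # with the layout above: {'cat': 0, 'dog': 1}
```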

Once you've prepared your dataset, you can run the script like this:

```bash
python run_image_classification.py \
    --dataset_name nateraw/image-folder \
    --train_dir <path-to-train-root> \
    --output_dir ./outputs/ \
    --remove_unused_columns False \
    --do_train \
    --do_eval
```

### 💡 The above will split the train dir into training and evaluation sets

- To control the split amount, use the `--train_val_split` flag (a sketch of the resulting split follows below).
- To provide your own validation split in its own directory, pass the `--validation_dir <path-to-val-root>` flag.
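
For intuition, the split behaves like a `train_test_split` over the training set. A minimal sketch, assuming 🤗 `datasets` semantics; the dataset name and 0.15 ratio here are illustrative, not taken from the script:

```python
# Sketch of the kind of split --train_val_split performs; assumes it behaves
# like datasets' train_test_split. Dataset name and ratio are illustrative.
from datasets import load_dataset

ds = load_dataset("beans", split="train")
splits = ds.train_test_split(test_size=0.15, seed=1337)
train_ds, val_ds = splits["train"], splits["test"]
print(len(train_ds), len(val_ds))
```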

## Sharing your model on 🤗 Hub

0. If you haven't already, [sign up](https://huggingface.co/join) for a 🤗 account.

1. Make sure you have `git-lfs` installed and git set up.

```bash
$ apt install git-lfs
$ git config --global user.email "you@example.com"
$ git config --global user.name "Your Name"
```

2. Log in with your HuggingFace account credentials using `huggingface-cli`.

```bash
$ huggingface-cli login
# ...follow the prompts
```

3. When running the script, pass the following arguments:

```bash
python run_image_classification.py \
    --push_to_hub \
    --push_to_hub_model_id <name-your-model> \
    ...
```
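
As an alternative to the flags above, an already-trained checkpoint can also be pushed programmatically. A minimal sketch, assuming the output directory from the beans run, a `transformers` version where `push_to_hub` is available on models and feature extractors, and that you are already logged in:

```python
# Hedged sketch: pushing a saved checkpoint by hand rather than via flags.
# Assumes ./beans_outputs/ holds the trained model + feature extractor and
# that `huggingface-cli login` has been run; the repo name is illustrative.
from transformers import ViTFeatureExtractor, ViTForImageClassification

model = ViTForImageClassification.from_pretrained("./beans_outputs")
feature_extractor = ViTFeatureExtractor.from_pretrained("./beans_outputs")

model.push_to_hub("vit-base-beans")
feature_extractor.push_to_hub("vit-base-beans")
```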

Lines changed: 2 additions & 0 deletions

torch>=1.9.0
torchvision>=0.10.0
