Sketch์ด๋ฏธ์ง ๋ถ๋ฅ ๊ฒฝ์ง๋ํ๋ ์ฃผ์ด์ง ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ์ฌ ๋ชจ๋ธ์ ์ ์ํ๊ณ ์ด๋ค ๊ฐ์ฒด๋ฅผ ๋ํ๋ด๋์ง ๋ถ๋ฅํ๋ ๋ํ์ ๋๋ค.
Computer Vision์์๋ ๋ค์ํ ํํ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๊ฐ ํ์ฉ๋๊ณ ์์ต๋๋ค. ์ด ์ค, ๋น์ ํ ๋ฐ์ดํฐ์ ์ ํํ ์ธ์๊ณผ ๋ถ๋ฅ๋ ์ฌ์ ํ ํด๊ฒฐํด์ผ ํ ์ฃผ์ ๊ณผ์ ๋ก ์๋ฆฌ์ก๊ณ ์์ต๋๋ค. ํนํ ์ฌ์ง๊ณผ ๊ฐ์ ์ผ๋ฐ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ๊ธฐ๋ฐํ์ฌ ๋ฐ์ ์ ์ด๋ฃจ์ด๋์๊ฐ๊ณ ์์ต๋๋ค.
ํ์ง๋ง ์ผ์์ ์ฌ์ง๊ณผ ๋ค๋ฅด๊ฒ ์ค์ผ์น๋ ์ธ๊ฐ์ ์์๋ ฅ๊ณผ ๊ฐ๋ ์ดํด๋ฅผ ๋ฐ์ํ๋ ์ถ์์ ์ด๊ณ ๋จ์ํ๋ ํํ์ ์ด๋ฏธ์ง์ ๋๋ค. ์ด๋ฌํ ์ค์ผ์น ๋ฐ์ดํฐ๋ ์์, ์ง๊ฐ, ์ธ๋ถ์ ์ธ ํํ๊ฐ ๋น๊ต์ ๊ฒฐ์ฌ๋์ด ์์ผ๋ฉฐ, ๋์ ์ ๊ธฐ๋ณธ์ ์ธ ํํ์ ๊ตฌ์กฐ์ ์ด์ ์ ๋ง์ถฅ๋๋ค. ์ด๋ ์ค์ผ์น๊ฐ ์ค์ ๊ฐ์ฒด์ ๋ณธ์ง์ ํน์ง์ ๊ฐ๊ฒฐํ๊ฒ ํํํ๋๋ฐ์ ์ค์ ์ ๋๊ณ ์๋ค๋ ์ ์ ๋ณด์ฌ์ค๋๋ค.
์ด๋ฌํ ์ค์ผ์น ๋ฐ์ดํฐ์ ํน์ฑ์ ์ดํดํ๊ณ ์ค์ผ์น ์ด๋ฏธ์ง๋ฅผ ํตํด ๋ชจ๋ธ์ด ๊ฐ์ฒด์ ๊ธฐ๋ณธ์ ์ธ ํํ์ ๊ตฌ์กฐ๋ฅผ ํ์ตํ๊ณ ์ธ์ํ๋๋ก ํจ์ผ๋ก์จ, ์ผ๋ฐ์ ์ธ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์์ ์ฐจ์ด์ ์ ์ดํดํ๊ณ ๋ ๋ค๋ฅธ ๊ด์ ์ ๋ํ ๋ชจ๋ธ ๊ฐ๋ฐ ์ญ๋์ ๋์ด๋๋ฐ์ ์ด์ ์ ๋์์ต๋๋ค. ์ด๋ฅผ ํตํด ์ค์ ์ธ๊ณ์ ๋ณต์กํ๊ณ ๋ค์ํ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ๋ํ ์ฐฝ์์ ์ธ ์ ๊ทผ๋ฐฉ๋ฒ๊ณผ ์ฒ๋ฆฌ ๋ฅ๋ ฅ์ ๋์ผ ์ ์์ต๋๋ค. ๋ํ, ์ค์ผ์น ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ๋ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ๋์งํธ ์์ , ๊ฒ์ ๊ฐ๋ฐ, ๊ต์ก ์ฝํ
์ธ ์์ฑ ๋ฑ ๋ค์ํ ๋ถ์ผ์์ ์์ฉ๋ ์ ์์ต๋๋ค.
ํ๋ก์ ํธ ์ ์ฒด ์ผ์
- 2024.09.10 (ํ) 10:00 ~ 2024.09.26 (๋ชฉ) 17:00
- Language : Python
- Environment
- CPU : Intel(R) Xeon(R) Gold 5120
- GPU : Tesla V100-SXM2 32GB ร 1
- Framework : PyTorch
- Collaborative Tool : Git, Wandb, Notion
๐ฆdata
โฃ ๐sample_submission.csv
โฃ ๐test.csv
โฃ ๐train.csv
โฃ ๐test
โ โฃ ๐0.JPEG
โ โฃ ๐1.JPEG
โ โฃ ๐2.JPEG
โ โ ...
โฃ ๐train
โ โฃ ๐n01443537
โ โฃ ๐n01484850
โ โ ...
- ํ์ต์ ์ฌ์ฉํ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ 15,021๊ฐ๋ก data/train/ ์๋์ ๊ฐ ๊ฐ์ฒด๋ณ ํด๋๋ก ๊ตฌ๋ถ๋์ด ์์ต๋๋ค.
- ์ ๊ณต๋๋ ์ด๋ฏธ์ง๋ ์ฃผ๋ก ์ฌ๋์ ์์ผ๋ก ๊ทธ๋ ค์ง ๋๋ก์์ด๋ ์ค์ผ์น๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
- train.csv์ test.csv์๋ ๊ฐ ์ด๋ฏธ์ง๋ณ ํด๋๋ช (class_name), ์ด๋ฏธ์ง ๊ฒฝ๋ก(image_path), ์์ธกํด์ผํ class(target)์ ๋ํ ์ ๋ณด๊ฐ ํฌํจ๋์ด ์์ต๋๋ค.
๐ฆlevel1-imageclassification-cv-05
โฃ ๐.github
โ โ ๐.keep
โฃ ๐data
โ โฃ ๐.DS_Store
โ โฃ ๐._DS_Store
โ โฃ ๐._sample_submission.csv
โ โฃ ๐._test.csv
โ โฃ ๐._train.csv
โ โฃ ๐sample_submission.csv
โ โฃ ๐test.csv
โ โ ๐train.csv
โฃ ๐model_checkpoints
โฃ ๐training_logs
โ โ ๐training_log.txt
โฃ ๐.gitignore
โฃ ๐augmentation.py
โฃ ๐augmentation_list.txt
โฃ ๐dataset.py
โฃ ๐inference.py
โฃ ๐main.py
โฃ ๐model.py
โฃ ๐README.md
โฃ ๐requirements.txt
โฃ ๐seed.py
โฃ ๐timm_list.txt
โ ๐train.py
- ๋ชจ๋ธ ํ์ต์ ์ํํ๋ ํจ์๋ก, ํ์ต๊ณผ ๊ฒ์ฆ ๋ฃจํ๋ฅผ ํฌํจํ์ฌ ์กฐ๊ธฐ ์ข ๋ฃ์ ์ฒดํฌํฌ์ธํธ ์ ์ฅ ๊ธฐ๋ฅ์ด ๊ตฌํ๋ ํ์ผ
- wandb ๋ก๊น , ํ์ต ์์ค ๊ณ์ฐ, ๊ฒ์ฆ, ๋ชจ๋ธ ์ ์ฅ, ๊ทธ๋ฆฌ๊ณ ์ต์ ์ ๋ชจ๋ธ ์ ํ ๋ฐ ์กฐ๊ธฐ ์ข ๋ฃ ๋ก์ง ํฌํจ
- ๋ชจ๋ ๋๋ค ์ฐ์ฐ์์ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ฌํํ ์ ์๋๋ก ์๋๋ฅผ ์ค์ ํ๋ ํ์ผ
- random, numpy, torch ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๊ด๋ จ๋ ์๋ ์ค์ ๋ฐ CUDA ๊ด๋ จ ๊ณ ์ ์ค์
- ConvNext ๋ชจ๋ธ์ ์ ์ํ ํ์ผ๋ก, timm ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ๋ฏธ๋ฆฌ ํ์ต๋ ๋ชจ๋ธ์ ๋ก๋
- ์ ๋ ฅ๋ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ธ์ ์ ๋ฌํ์ฌ ์์ธก์ ์ํํ๋ forward ๋ฉ์๋ ํฌํจ
- ํ์ต๊ณผ ์ถ๋ก ์ ์ํ ๋ฉ์ธ ์คํฌ๋ฆฝํธ๋ก, argparse๋ฅผ ํตํด ์ค์ ๊ฐ์ ๋ฐ์ ๋ชจ๋ธ ํ์ต๊ณผ ์ถ๋ก ์ ์ํ
- ๋ฐ์ดํฐ์ ๋ก๋, ํ์ต/๊ฒ์ฆ ๋ฃจํ, ์ฒดํฌํฌ์ธํธ ๋ก๋ ๋ฐ ์ ์ฅ, ์ถ๋ก ํ ๊ฒฐ๊ณผ ํ์ผ ์์ฑ
- ๋ชจ๋ธ์ ์ฌ์ฉํด ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์ถ๋ก ํ๋ ํจ์์ ๊ฐ์ฅ ์ต๊ทผ์ ์ฒดํฌํฌ์ธํธ ํ์ผ์ ๊ฐ์ ธ์ค๋ ํจ์ ์ ์
- inference ํจ์๋ ์์ธก๊ฐ์ ๋ฐํํ๊ณ , get_latest_checkpoint ํจ์๋ ์ฒดํฌํฌ์ธํธ ๋๋ ํ ๋ฆฌ์์ ๊ฐ์ฅ ์ต๊ทผ ํ์ผ์ ์ ํ
- ํ์ต ๋ฐ ์ถ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๋ CustomDataset ํด๋์ค๋ฅผ ์ ์ํ ํ์ผ
- ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๊ณ , ์ฃผ์ด์ง ๋ณํ(transform)์ ์ ์ฉํ์ฌ ๋ฐํํ๋ฉฐ, ํ์ต ๋๋ ์ถ๋ก ๋ชจ๋์ ๋ฐ๋ผ ๋ผ๋ฒจ๊ณผ ํจ๊ป ๋ฐ์ดํฐ๋ฅผ ๋ฐํ
- ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ๋ค์ํ ๋ฐ์ดํฐ ์ฆ๊ฐ ๊ธฐ๋ฒ์ ์ ์ฉํ๋ SketchAutoAugment ํด๋์ค ์ ์
- ํ์ , ํฌ์คํฐํ, ์ ์น, ์์ ๋ฐ์ ๋ฑ ์ฌ๋ฌ Augmentation ์ ์ฑ ์ ๋๋ค์ผ๋ก ์ ์ฉํ์ฌ ์ด๋ฏธ์ง ๋ณํ
- pandas==2.1.4
- matplotlib==3.8.4
- seaborn==0.13.2
- Pillow==10.3.0
- numpy==1.26.3
- timm==0.9.16
- albumentations==1.4.4
- tqdm==4.66.1
- scikit-learn==1.4.2
- opencv-python==4.9.0.80
- wandb==0.18.0
pip install -r requirements.txt
wget https://aistages-api-public-prod.s3.amazonaws.com/app/Competitions/000307/data/data.tar.gz
python main.py --train_dir ../data/train --train_csv ../data/train.csv --test_dir ../data/test --test_csv ../data/test.csv --batch_size 16 --resize_height 448 --resize_width 448 --learning_rate 1e-4 --max_epochs 50
python main.py --train_dir ../data/train --train_csv ../data/train.csv --test_dir ../data/test --test_csv ../data/test.csv --resume_training --batch_size 16 --resize_height 448 --resize_width 448
ํด๋ฆญํด์ ํผ์น๊ธฐ/์ ๊ธฐ
-
--train_dir
(ํ์ ์ธ์):- ์ค๋ช : ํ์ต ๋ฐ์ดํฐ๊ฐ ์ ์ฅ๋ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--train_dir ../data/train
-
--train_csv
(ํ์ ์ธ์):- ์ค๋ช : ํ์ต ๋ฐ์ดํฐ์ ์ด๋ฏธ์ง ๊ฒฝ๋ก์ ๋ ์ด๋ธ์ด ํฌํจ๋ CSV ํ์ผ ๊ฒฝ๋ก๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--train_csv ../data/train.csv
-
--test_dir
(ํ์ ์ธ์):- ์ค๋ช : ํ ์คํธ ๋ฐ์ดํฐ๊ฐ ์ ์ฅ๋ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--test_dir ../data/test
-
--test_csv
(ํ์ ์ธ์):- ์ค๋ช : ํ ์คํธ ๋ฐ์ดํฐ์ ์ด๋ฏธ์ง ๊ฒฝ๋ก์ ID๊ฐ ํฌํจ๋ CSV ํ์ผ ๊ฒฝ๋ก๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--test_csv ../data/test.csv
-
--save_dir
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:./model_checkpoints
):- ์ค๋ช : ํ์ต๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--save_dir ./checkpoints
-
--log_dir
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:./training_logs
):- ์ค๋ช : ํ์ต ๋ก๊ทธ๋ฅผ ์ ์ฅํ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--log_dir ./logs
-
--batch_size
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:32
):- ์ค๋ช : ํ์ต๊ณผ ์ถ๋ก ์ ์ฌ์ฉํ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--batch_size 16
-
--learning_rate
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:1e-5
):- ์ค๋ช : ํ์ต ์ ์ฌ์ฉํ๋ ํ์ต๋ฅ ์ ์ค์ ํฉ๋๋ค.
- ์์:
--learning_rate 0.001
-
--weight_decay
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:0.01
):- ์ค๋ช : AdamW ์ตํฐ๋ง์ด์ ์์ ์ฌ์ฉํ๋ ๊ฐ์ค์น ๊ฐ์๊ฐ์ ์ค์ ํฉ๋๋ค.
- ์์:
--weight_decay 0.001
-
--max_epochs
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:50
):- ์ค๋ช : ํ์ตํ ์ต๋ ์ํฌํฌ ์๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--max_epochs 100
-
--accumulation_steps
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:8
):- ์ค๋ช : ๊ทธ๋๋์ธํธ ๋์ ์ ์ํ ์คํ ์๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--accumulation_steps 4
-
--patience
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:5
):- ์ค๋ช : ํ์ต ์ค ์กฐ๊ธฐ ์ข ๋ฃ(Early Stopping)๋ฅผ ์ํ patience๋ฅผ ์ค์ ํฉ๋๋ค. ์ด ๊ฐ์ ๊ฒ์ฆ ์์ค์ด ๊ฐ์ ๋์ง ์์ ๋ ๋ช ๋ฒ์ ์ํฌํฌ๋ฅผ ๋ ์คํํ ์ง ๊ฒฐ์ ํฉ๋๋ค.
- ์์:
--patience 10
-
--resume_training
(์ ํ์ ์ธ์):- ์ค๋ช : ๊ฐ์ฅ ์ต๊ทผ์ ์ฒดํฌํฌ์ธํธ์์ ํ์ต์ ์ฌ๊ฐํ ์ง ์ฌ๋ถ๋ฅผ ์ค์ ํฉ๋๋ค. ์ด ํ๋๊ทธ๋ฅผ ์ถ๊ฐํ๋ฉด, ํ์ต์ด ์ค๋จ๋ ์ฒดํฌํฌ์ธํธ์์ ์ด์ด์ ํ์ต์ด ๊ฐ๋ฅํฉ๋๋ค.
- ์์:
--resume_training
-
--resize_height
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:448
):- ์ค๋ช : ์ด๋ฏธ์ง ๋ณํ ์ ์ด๋ฏธ์ง์ ๋์ด๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--resize_height 512
-
--resize_width
(์ ํ์ ์ธ์, ๊ธฐ๋ณธ๊ฐ:448
):- ์ค๋ช : ์ด๋ฏธ์ง ๋ณํ ์ ์ด๋ฏธ์ง์ ๋๋น๋ฅผ ์ค์ ํฉ๋๋ค.
- ์์:
--resize_width 512