
Commit ff9cf22

authored
Merge pull request #33 from johndpope/feat/32-cycle-consistency
DRAFT - Feat/32 cycle consistency
2 parents ace2981 + 9eb3a11 commit ff9cf22

26 files changed

+6022
-271
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions

```diff
@@ -6,3 +6,8 @@ models/gaze_model_pytorch_vgg16_prl_mpii_allsubjects1.model
 *.dat
 *.pth
 # output_images/*.*
+*.png
+junk/-2KGPYEFnsU_11_nobg.mp4
+*.png
+*.png
+*.npz
```

EmoDataset.md

Lines changed: 108 additions & 0 deletions

## EMODataset Class Summary

### Overview

The `EMODataset` class is a PyTorch dataset for processing and augmenting video frames, with functionality to remove backgrounds, warp and crop faces, and save/load processed frames efficiently. The class is designed to handle large video datasets and includes methods that streamline the preprocessing pipeline.

### Dependencies

The class relies on the following libraries:

- `moviepy.editor`: video editing and processing.
- `PIL.Image`: image processing.
- `torch`: PyTorch tensor operations and model support.
- `torchvision.transforms`: image transformations.
- `decord`: efficient video reading.
- `rembg`: background removal.
- `face_recognition`: face detection.
- `skimage.transform`: image warping.
- `cv2`: video writing with OpenCV.
- `numpy`: array operations.
- `io`, `os`, `json`, `Path`, `subprocess`: standard-library utilities for file handling and I/O.
- `tqdm`: progress visualization.

### Initialization

The `__init__` method sets up the dataset with the following parameters:

- `use_gpu`, `sample_rate`, `n_sample_frames`, `width`, `height`, `img_scale`, `img_ratio`, `video_dir`, `drop_ratio`, `json_file`, `stage`, `transform`, `remove_background`, `use_greenscreen`, `apply_crop_warping`

It also loads video metadata from the provided JSON file and initializes decord for video reading with PyTorch tensor output.

### Methods

#### `__len__`

Returns the length of the dataset, determined by the number of video IDs.

#### `warp_and_crop_face`

Processes an image tensor to detect, warp, and crop the face region:

- Converts the tensor to a PIL image.
- Removes the background.
- Detects face locations.
- Crops the face region.
- Optionally applies thin-plate-spline warping.
- Converts the processed image back to a tensor and returns it.

#### `load_and_process_video`

Loads and processes video frames:

- Checks whether a processed tensor file already exists; if so, loads the tensors directly.
- Otherwise, processes the video frames, applies augmentation, and saves the frames as PNG images and tensors.
- Saves the processed tensors as compressed NumPy arrays for efficient loading.
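The caching pattern described above can be sketched with `np.savez_compressed`; the helper name, key format, and `process_fn` callable are illustrative stand-ins for the class's actual processing step.

```python
import os
import numpy as np

def load_or_process(cache_path: str, process_fn):
    """Return cached frames if the .npz file exists, else compute and cache.

    `process_fn` stands in for the real frame-processing pipeline and
    should return a list of per-frame arrays.
    """
    if os.path.exists(cache_path):
        with np.load(cache_path) as data:
            # Zero-padded keys keep the frames in order when sorted.
            return [data[k] for k in sorted(data.files)]
    frames = process_fn()
    np.savez_compressed(cache_path,
                        **{f"frame_{i:05d}": f for i, f in enumerate(frames)})
    return frames
```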

#### `augmentation`

Applies transformations and optional background removal to the provided images:

- Supports both single images and lists of images.
- Returns transformed tensors.

#### `remove_bg`

Removes the background from the provided image using `rembg`:

- Optionally composites the result over a green-screen background.
- Converts the image to RGB format and returns it.
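The green-screen option can be sketched with PIL alone: `rembg` returns the foreground as an RGBA cutout, and compositing it over solid green then converting to RGB mimics the behavior described above. The helper name is illustrative.

```python
from PIL import Image

def composite_green_screen(rgba_img: Image.Image) -> Image.Image:
    """Paste an RGBA cutout (e.g. rembg output) onto a green background.

    The cutout's alpha channel is used as the paste mask; the result is
    converted to RGB before returning.
    """
    background = Image.new("RGBA", rgba_img.size, (0, 255, 0, 255))
    background.paste(rgba_img, (0, 0), mask=rgba_img)
    return background.convert("RGB")
```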

#### `save_video`

Saves a list of frames as a video file, using OpenCV's `VideoWriter` to write the frames.

#### `process_video`

Processes all frames of a video by delegating to `process_video_frames`.

#### `process_video_frames`

Reads a video's frames with decord, applies augmentation, and returns the processed frames.
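The frame-selection logic implied by the `sample_rate` and `n_sample_frames` parameters can be sketched without the decord dependency. This stride-based interpretation is an assumption; the real method reads the selected frames with decord and augments them.

```python
def sample_frame_indices(total_frames: int, sample_rate: int, n_sample_frames: int):
    """Pick every `sample_rate`-th frame index, capped at `n_sample_frames`.

    An assumed reading of the dataset's sampling parameters: stride
    through the video, then truncate to the requested frame count.
    """
    indices = list(range(0, total_frames, sample_rate))
    return indices[:n_sample_frames]
```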

#### `__getitem__`

Returns a sample from the dataset:

- Loads and processes source and driving videos.
- Returns a dictionary containing video IDs and frames.
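The dictionary sample structure can be illustrated with a toy dataset; the keys, IDs, and frame shapes below are placeholders, not the real class's values.

```python
import torch
from torch.utils.data import Dataset

class ToySampleDataset(Dataset):
    """Illustrates dict-shaped samples like those EMODataset returns.

    Key names and tensor shapes are illustrative placeholders.
    """
    def __init__(self, video_ids):
        self.video_ids = video_ids

    def __len__(self):
        # Dataset length is the number of video IDs, as in EMODataset.
        return len(self.video_ids)

    def __getitem__(self, idx):
        return {
            "video_id": self.video_ids[idx],
            "frames": torch.zeros(16, 3, 64, 64),  # placeholder frame stack
        }
```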

### Usage

To use the `EMODataset` class:

1. Initialize the dataset with appropriate parameters.
2. Wrap it in a PyTorch `DataLoader` to iterate over the dataset and retrieve samples.
3. Process the frames as needed for training or inference in a machine learning model.
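Step 2 can be sketched with a stand-in dataset (`_DictDataset` is hypothetical); what it demonstrates is real PyTorch behavior: the default collation stacks tensor values and gathers string IDs into a list when batching dict samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class _DictDataset(Dataset):
    """Stand-in dataset returning dict samples, as EMODataset does."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {"video_id": f"vid_{idx}", "frames": torch.zeros(3, 8, 8)}

loader = DataLoader(_DictDataset(), batch_size=2)
batch = next(iter(loader))
# Default collation stacks the frame tensors along a new batch dimension
# and collects the video IDs into a list of strings.
```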

### Example

```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

dataset = EMODataset(
    use_gpu=False,
    sample_rate=5,
    n_sample_frames=16,
    width=512,
    height=512,
    img_scale=(0.9, 1.0),
    video_dir="path/to/videos",
    json_file="path/to/metadata.json",
    transform=transform,
    remove_background=True,
    use_greenscreen=False,
    apply_crop_warping=True,
)

for sample in dataset:
    print(sample)
```

This class provides a comprehensive pipeline for processing video data, making it suitable for tasks such as training deep learning models on video datasets.
