
Commit 6687b07

Author: fengyu32
Commit message: add-faceparsing
1 parent db0b087

File tree: 12 files changed, +716 -0 lines changed
+21
@@ -0,0 +1,21 @@
# Face Parsing
This repo hosts the face_parsing implementation of the CVPR 2022 paper "General Facial Representation Learning in a Visual-Linguistic Manner".

# Some Results by FaRL
![image](Data/images/face_parsing.jpg)

# Requirements
* python >= 3.7.1
* pytorch >= 1.9.1

# Pre-trained Model
[face_parsing.farl.lapa](https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt)

Please put the pre-trained model under FaceX-Zoo/face_sdk/models/face_parsing/face_parsing_1.0/.

# Usage
```sh
cd ../../face_sdk
python api_usage/face_parsing.py
```

# Reference
This project is mainly inspired by [FaRL](https://github.com/FacePerceiver/FaRL).

data/images/face_parsing.jpg

39 KB

face_sdk/README.md

+1
@@ -23,6 +23,7 @@ python api_usage/face_alignment.py # Get 106 landmark of a face
 python api_usage/face_crop.py # Get cropped face from an image
 python api_usage/face_feature.py # Get features of a face
 python api_usage/face_pipline.py # Run face recognition pipeline
+python api_usage/face_parsing.py # Run face parsing pipeline
 ```
 The results will be saved at [api_usage/temp](api_usage/temp)
 ## Update the models

face_sdk/api_usage/face_parsing.py

+108
@@ -0,0 +1,108 @@
import sys
sys.path.append('.')
import logging
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)
import logging.config
logging.config.fileConfig("config/logging.conf")
logger = logging.getLogger('api')

import yaml
import cv2
import numpy as np
import torch
from utils.show import show_bchw
from utils.draw import draw_bchw
from core.model_loader.face_parsing.FaceParsingModelLoader import FaceParsingModelLoader
from core.model_handler.face_parsing.FaceParsingModelHandler import FaceParsingModelHandler
from core.model_loader.face_detection.FaceDetModelLoader import FaceDetModelLoader
from core.model_handler.face_detection.FaceDetModelHandler import FaceDetModelHandler
from core.model_loader.face_alignment.FaceAlignModelLoader import FaceAlignModelLoader
from core.model_handler.face_alignment.FaceAlignModelHandler import FaceAlignModelHandler

with open('config/model_conf.yaml') as f:
    model_conf = yaml.load(f, Loader=yaml.FullLoader)

if __name__ == '__main__':
    # common setting for all models, need not modify.
    model_path = 'models'

    # face detection model setting.
    scene = 'non-mask'
    model_category = 'face_detection'
    model_name = model_conf[scene][model_category]
    logger.info('Start to load the face detection model...')
    try:
        faceDetModelLoader = FaceDetModelLoader(model_path, model_category, model_name)
        model, cfg = faceDetModelLoader.load_model()
        faceDetModelHandler = FaceDetModelHandler(model, 'cuda:0', cfg)
    except Exception as e:
        logger.error('Failed to load the face detection model.')
        logger.error(e)
        sys.exit(-1)
    else:
        logger.info('Success!')

    # face landmark model setting.
    model_category = 'face_alignment'
    model_name = model_conf[scene][model_category]
    logger.info('Start to load the face landmark model...')
    try:
        faceAlignModelLoader = FaceAlignModelLoader(model_path, model_category, model_name)
        model, cfg = faceAlignModelLoader.load_model()
        faceAlignModelHandler = FaceAlignModelHandler(model, 'cuda:0', cfg)
    except Exception as e:
        logger.error('Failed to load the face landmark model.')
        logger.error(e)
        sys.exit(-1)
    else:
        logger.info('Success!')

    # face parsing model setting.
    scene = 'non-mask'
    model_category = 'face_parsing'
    model_name = model_conf[scene][model_category]
    logger.info('Start to load the face parsing model...')
    try:
        faceParsingModelLoader = FaceParsingModelLoader(model_path, model_category, model_name)
        model, cfg = faceParsingModelLoader.load_model()
        faceParsingModelHandler = FaceParsingModelHandler(model, 'cuda:0', cfg)
    except Exception as e:
        logger.error('Failed to load the face parsing model.')
        logger.error(e)
        sys.exit(-1)
    else:
        logger.info('Success!')

    # read the image, detect faces and parse each detected face.
    image_path = 'api_usage/test_images/test1.jpg'
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    try:
        dets = faceDetModelHandler.inference_on_image(image)
        face_nums = dets.shape[0]
        with torch.no_grad():
            landmarks_list = []
            for i in range(face_nums):
                landmarks = faceAlignModelHandler.inference_on_image(image, dets[i])
                # keep the five landmark points used to align each face crop.
                landmarks_list.append(torch.from_numpy(landmarks[[104, 105, 54, 84, 90]]).float())
            # face_nums x 5 x 2 tensor of landmark coordinates, one row of five points per face.
            landmarks_five = torch.stack(landmarks_list, dim=0)
            print(landmarks_five.shape)

            faces = faceParsingModelHandler.inference_on_image(face_nums, image, landmarks_five)
            seg_logits = faces['seg']['logits']
            seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
            show_bchw(draw_bchw(image, faces))
    except Exception as e:
        logger.error('Parsing failed!')
        logger.error(e)
        sys.exit(-1)
    else:
        logger.info('Success!')
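
For reference, a minimal sketch (not part of this commit) of how the parsing result produced above could be consumed further; it only assumes the `faces` dict returned by `FaceParsingModelHandler.inference_on_image`, with logits of shape nfaces x nclasses x h x w as noted in the script.

```python
# Hypothetical post-processing sketch, not part of the commit. It assumes `faces`
# comes from FaceParsingModelHandler.inference_on_image as in the script above.
import torch

seg_logits = faces['seg']['logits']          # nfaces x nclasses x h x w
label_names = faces['seg']['label_names']    # ['background', 'face', ..., 'hair']
seg_probs = seg_logits.softmax(dim=1)        # per-pixel class probabilities
label_map = seg_probs.argmax(dim=1)          # nfaces x h x w integer class ids

# e.g. a binary hair mask for the first detected face:
hair_mask = (label_map[0] == label_names.index('hair'))
print(label_map.shape, hair_mask.sum().item())
```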

face_sdk/config/model_conf.yaml

+1
@@ -2,6 +2,7 @@ non-mask:
   face_detection: face_detection_1.0
   face_alignment: face_alignment_1.0
   face_recognition: face_recognition_1.0
+  face_parsing: face_parsing_1.0
 mask:
   face_detection: face_detection_2.0
   face_alignment: face_alignment_2.0
face_sdk/core/model_handler/face_parsing/FaceParsingModelHandler.py

+92
@@ -0,0 +1,92 @@
# based on:
# https://github.com/FacePerceiver/facer/blob/main/facer/face_parsing/farl.py
import functools
import logging.config
logging.config.fileConfig("config/logging.conf")
logger = logging.getLogger('sdk')

import torch
import torch.nn.functional as F
import numpy as np
import torch.backends.cudnn as cudnn

from core.model_handler.BaseModelHandler import BaseModelHandler
from utils.transform import *

pretrain_settings = {
    'lapa/448': {
        'matrix_src_tag': 'points',
        'get_matrix_fn': functools.partial(get_face_align_matrix,
                                           target_shape=(448, 448), target_face_scale=1.0),
        'get_grid_fn': functools.partial(make_tanh_warp_grid,
                                         warp_factor=0.8, warped_shape=(448, 448)),
        'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
                                             warp_factor=0.8, warped_shape=(448, 448)),
        'label_names': ['background', 'face', 'rb', 'lb', 're',
                        'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
    }
}


class FaceParsingModelHandler(BaseModelHandler):
    def __init__(self, model=None, device=None, cfg=None):
        super().__init__(model, device, cfg)
        self.model = model.to(self.device)

    def _preprocess(self, image, face_nums):
        """Preprocess the image, such as standardization and other operations.

        Returns:
            A tensor of shape face_nums x 3 x h x w (the image repeated once per face).
        """
        if not isinstance(image, np.ndarray):
            logger.error('The input should be the ndarray read by cv2!')
            raise InputError()
        img = np.float32(image)
        img = img.transpose(2, 0, 1)
        img = np.expand_dims(img, 0).repeat(face_nums, axis=0)
        return torch.from_numpy(img)

    def inference_on_image(self, face_nums: int, images: torch.Tensor, landmarks):
        """Get the inference of the image and process the inference result.

        Returns:
            A dict {'seg': {'logits', 'label_names'}}, where 'logits' is a
            face_nums x nclasses x h x w tensor in the original image space.
        """
        cudnn.benchmark = True
        try:
            image_pre = self._preprocess(images, face_nums)
        except Exception as e:
            raise e
        setting = pretrain_settings['lapa/448']
        images = image_pre.float() / 255.0
        _, _, h, w = images.shape
        simages = images.to(self.device)
        matrix = setting['get_matrix_fn'](landmarks.to(self.device))
        grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
        inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))

        # warp each face crop into the 448 x 448 space expected by the parsing model.
        w_images = F.grid_sample(
            simages, grid, mode='bilinear', align_corners=False)

        w_seg_logits, _ = self.model(w_images)  # (b*n) x c x h x w

        # warp the logits back to the original image space.
        seg_logits = F.grid_sample(
            w_seg_logits, inv_grid, mode='bilinear', align_corners=False)
        data_pre = {}
        data_pre['seg'] = {'logits': seg_logits,
                           'label_names': setting['label_names']}
        return data_pre

    def _postprocess(self, loc, conf, scale, input_height, input_width):
        """Postprocess the prediction result.

        Decode the detection result, apply the confidence threshold and run NMS
        to keep the appropriate detection boxes.

        Returns:
            A numpy array of shape N x (x, y, w, h, confidence),
            where N is the number of detection boxes.
        """
        pass
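
The handler's `inference_on_image` takes the raw cv2 image plus a face_nums x 5 x 2 tensor of landmark coordinates, as built in api_usage/face_parsing.py. A minimal standalone sketch (not part of this commit), assuming a `faceParsingModelHandler` instance constructed as in that script and placeholder landmark values:

```python
# Hypothetical standalone call, not part of the commit. `faceParsingModelHandler`
# is assumed to be built as in api_usage/face_parsing.py, and the landmark
# coordinates below are placeholders for real five-point (x, y) detections.
import cv2
import torch

image = cv2.imread('api_usage/test_images/test1.jpg', cv2.IMREAD_COLOR)
face_nums = 1
# face_nums x 5 x 2 tensor of (x, y) landmark coordinates, one row of five points per face.
landmarks_five = torch.tensor([[[210.0, 180.0], [290.0, 180.0],
                                [250.0, 230.0], [220.0, 280.0], [280.0, 280.0]]])

faces = faceParsingModelHandler.inference_on_image(face_nums, image, landmarks_five)
print(faces['seg']['logits'].shape)   # expected: face_nums x 11 x h x w (11 LaPa label names)
print(faces['seg']['label_names'])
```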
face_sdk/core/model_loader/face_parsing/FaceParsingModelLoader.py

+28
@@ -0,0 +1,28 @@
import logging.config
logging.config.fileConfig("config/logging.conf")
logger = logging.getLogger('sdk')

import torch

from core.model_loader.BaseModelLoader import BaseModelLoader

class FaceParsingModelLoader(BaseModelLoader):
    def __init__(self, model_path, model_category, model_name, meta_file='model_meta.json'):
        logger.info('Start to analyze the face parsing model, model path: %s, model category: %s, model name: %s' %
                    (model_path, model_category, model_name))
        super().__init__(model_path, model_category, model_name, meta_file)

        self.cfg['input_height'] = self.meta_conf['input_height']
        self.cfg['input_width'] = self.meta_conf['input_width']

    def load_model(self):
        try:
            model = torch.jit.load(self.cfg['model_file_path'])
        except Exception as e:
            logger.error('The model failed to load, please check the model path: %s!'
                         % self.cfg['model_file_path'])
            raise e
        else:
            logger.info('Successfully loaded the face parsing model!')
            return model, self.cfg
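
Since `load_model` is a thin wrapper around `torch.jit.load`, the exported model can also be smoke-tested directly. A minimal sketch (not part of this commit), assuming the file location given in the README, the 448x448 input size recorded in model_meta.json, and the (logits, extra) output tuple that FaceParsingModelHandler unpacks:

```python
# Hypothetical smoke test, not part of the commit. The path and 448x448 input
# size come from the README and model_meta.json; the two-element output tuple
# matches how FaceParsingModelHandler unpacks the model's result.
import torch

model_file = 'models/face_parsing/face_parsing_1.0/face_parsing.farl.lapa.main_ema_136500_jit191.pt'
model = torch.jit.load(model_file, map_location='cpu')
model.eval()

dummy = torch.rand(1, 3, 448, 448)   # a warped, [0, 1]-normalized face crop
with torch.no_grad():
    seg_logits, _ = model(dummy)      # first element: per-class logits
print(seg_logits.shape)
```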
face_sdk/models/face_parsing/face_parsing_1.0/model_meta.json

+8
@@ -0,0 +1,8 @@
{
    "model_type" : "face_parsing.farl.lapa",
    "model_info" : "some model info",
    "model_file" : "face_parsing.farl.lapa.main_ema_136500_jit191.pt",
    "release_date" : "20220226",
    "input_height" : 448,
    "input_width" : 448
}

face_sdk/test.jpg

39.2 KB
