1
+ # based on:
2
+ # https://github.com/FacePerceiver/facer/blob/main/facer/face_parsing/farl.py
3
+ import functools
4
+ import logging .config
5
+ logging .config .fileConfig ("config/logging.conf" )
6
+ logger = logging .getLogger ('sdk' )
7
+
8
+ import torch
9
+ import torch .nn .functional as F
10
+ import numpy as np
11
+ from math import ceil
12
+ from itertools import product as product
13
+ import torch .backends .cudnn as cudnn
14
+
15
+ from core .model_handler .BaseModelHandler import BaseModelHandler
16
+ from utils .transform import *
17
+
18
# Settings for each supported pretrained face-parsing configuration, keyed by
# configuration name. Each entry provides:
#   matrix_src_tag  - tag naming the data the alignment matrix is built from
#   get_matrix_fn   - builds the face alignment matrix (448 x 448 target)
#   get_grid_fn     - builds the sampling grid that warps the image for the model
#   get_inv_grid_fn - builds the grid that warps model output back
#   label_names     - class names, in the order of the model's output channels
#                     (presumably; verify against the checkpoint)
pretrain_settings = {
    'lapa/448': dict(
        matrix_src_tag='points',
        get_matrix_fn=functools.partial(
            get_face_align_matrix,
            target_shape=(448, 448),
            target_face_scale=1.0,
        ),
        get_grid_fn=functools.partial(
            make_tanh_warp_grid,
            warp_factor=0.8,
            warped_shape=(448, 448),
        ),
        get_inv_grid_fn=functools.partial(
            make_inverted_tanh_warp_grid,
            warp_factor=0.8,
            warped_shape=(448, 448),
        ),
        label_names=[
            'background', 'face', 'rb', 'lb', 're', 'le',
            'nose', 'ulip', 'imouth', 'llip', 'hair',
        ],
    ),
}
31
+
32
+
33
class FaceParsingModelHandler(BaseModelHandler):
    """Model handler that runs a face-parsing network on aligned face crops.

    The image is replicated once per face, warped into the model's input
    frame using each face's landmarks, passed through the model, and the
    resulting segmentation logits are warped back to the original image frame.
    """

    def __init__(self, model=None, device=None, cfg=None):
        """Initialise the handler and move the model to the target device.

        Args:
            model: The parsing network. Must not be None, since ``.to`` is
                called on it.
            device: Forwarded to ``BaseModelHandler.__init__``.
            cfg: Forwarded to ``BaseModelHandler.__init__``.
        """
        super().__init__(model, device, cfg)
        # self.device is expected to be set by BaseModelHandler.__init__,
        # which is not visible from this file.
        self.model = model.to(self.device)

    def _preprocess(self, image: np.ndarray, face_nums: int) -> torch.Tensor:
        """Convert an image array into a batch tensor, one copy per face.

        Args:
            image: A 3-D ndarray in height x width x channel order (e.g. as
                read by cv2).
            face_nums: Number of faces; the image is repeated this many times
                along a new leading batch axis.

        Returns:
            A float32 tensor of shape face_nums x C x H x W. Values are not
            rescaled here.

        Raises:
            InputError: If ``image`` is not a numpy ndarray.
        """
        if not isinstance(image, np.ndarray):
            logger.error('The input should be the ndarray read by cv2!')
            # NOTE(review): InputError is not imported explicitly in this
            # module; presumably it arrives via a star import - confirm.
            raise InputError()
        # HWC -> CHW, then replicate along a new batch axis.
        chw = np.float32(image).transpose(2, 0, 1)
        batch = np.repeat(chw[np.newaxis], face_nums, axis=0)
        return torch.from_numpy(batch)

    def inference_on_image(self, face_nums: int, images: np.ndarray,
                           landmarks: torch.Tensor) -> dict:
        """Run face parsing for every face found in a single image.

        Args:
            face_nums: Number of faces in the image.
            images: The source image as a height x width x channel ndarray
                with values on a 0-255 scale (it is divided by 255 here).
            landmarks: Tensor of per-face landmark points used to compute
                the alignment matrices.
                # assumes one row per face, face_nums rows - TODO confirm

        Returns:
            A dict ``{'seg': {'logits': Tensor, 'label_names': list}}`` where
            ``logits`` are the segmentation logits warped back to the
            original image frame.

        Raises:
            InputError: If ``images`` is not a numpy ndarray.
        """
        cudnn.benchmark = True
        setting = pretrain_settings['lapa/448']

        # _preprocess already yields float32, so only the rescale is needed.
        batch = self._preprocess(images, face_nums) / 255.0
        height, width = batch.shape[-2:]
        batch = batch.to(self.device)

        matrix = setting['get_matrix_fn'](landmarks.to(self.device))
        grid = setting['get_grid_fn'](
            matrix=matrix, orig_shape=(height, width))
        inv_grid = setting['get_inv_grid_fn'](
            matrix=matrix, orig_shape=(height, width))

        # Warp the image into the model's input frame.
        warped = F.grid_sample(
            batch, grid, mode='bilinear', align_corners=False)

        # NOTE(review): the forward pass runs with autograd enabled; consider
        # wrapping the call site in torch.no_grad() for pure inference.
        # The model returns a pair; only the first element is used.
        warped_logits, _ = self.model(warped)  # (b*n) x c x h x w

        # Warp the logits back into the original image frame.
        seg_logits = F.grid_sample(
            warped_logits, inv_grid, mode='bilinear', align_corners=False)

        return {
            'seg': {
                'logits': seg_logits,
                'label_names': setting['label_names'],
            },
        }

    def _postprocess(self, loc, conf, scale, input_height, input_width):
        """Placeholder post-processing hook; currently does nothing.

        The arguments are accepted but ignored.

        Returns:
            None.
        """
        pass
0 commit comments