Skip to content

Commit d9de082

Browse files
committed
Updated GenderFaceFilter node
1 parent bb9d0e0 commit d9de082

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ These custom nodes provide a rotation aware face extraction, paste back, and var
55
![Comparison](examples/comparison.jpg)
66

77
## Patch notes
8+
- 2024-05-22 - Updated GenderFaceFilter node.
89
- 2024-05-19 - Added BiSeNetMask and JonathandinuMask nodes. Careful about JonathandinuMask, it's more accurate than BiSeNet, but it takes more memory; you can get out of memory more easily with it.
910
- 2024-03-10 - Added nodes to detect faces using `face_yolov8m` instead of `insightface`.
1011

nodes.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import torch
22
from collections import defaultdict
33
from .utils import *
4+
from transformers import pipeline
45

56
class GenderFaceFilter:
67
@classmethod
78
def INPUT_TYPES(cls):
89
return {
910
'required': {
1011
'faces': ('FACE',),
11-
'gender': (['male', 'female'],)
12+
'gender': (['man', 'woman'],)
1213
}
1314
}
1415

@@ -18,11 +19,16 @@ def INPUT_TYPES(cls):
1819
CATEGORY = 'facetools'
1920

2021
def run(self, faces, gender):
21-
gid = 0 if gender == 'female' else 1
2222
filtered = []
2323
rest = []
24+
pipe = pipeline('image-classification', model='dima806/man_woman_face_image_detection', device=0)
2425
for face in faces:
25-
if face.gender == gid:
26+
_, im = face.crop(224, 1.2)
27+
im = im.permute(0,3,1,2)[0]
28+
im = tv.transforms.functional.resize(im, (224,224))
29+
r = pipe(tv.transforms.functional.to_pil_image(im))
30+
idx = np.argmax([i['score'] for i in r])
31+
if r[idx]['label'] == gender:
2632
filtered.append(face)
2733
else:
2834
rest.append(face)

utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ def crop(self, size, crop_factor):
9696
N = N @ T4 @ S @ T3
9797
crop = cv2.warpAffine(self.img.numpy(), N, (size, size))
9898
crop = torch.from_numpy(crop)[None]
99-
# maskedcrop = cv2.warpAffine(self.img.numpy(), N, (size, size))
10099

101-
return N, crop#, maskedcrop
100+
return N, crop
102101

103102
def detect_faces(img, threshold):
104103
img = pad_to_stride(img, stride=32)

0 commit comments

Comments (0)