import numpy as np
from ..r_nudenet.nudenet import NudeDetector
import os
import torch
import folder_paths as comfy_paths
from folder_paths import models_dir
from typing import Union, List
import json
import logging

logger = logging.getLogger(__name__)

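# Register a "nsfw" folder under ComfyUI's models directory so that detector
# weights (.pt / .onnx) show up in the model_name dropdown below.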
comfy_paths.folder_names_and_paths["nsfw"] = ([os.path.join(models_dir, "nsfw")], {".pt", ".onnx"})

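# Convert a ComfyUI IMAGE tensor (float 0..1, [H, W, C] or [B, H, W, C]) to
# uint8 numpy: a single array for one image, a list of arrays for a batch.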
def tensor2np(tensor: torch.Tensor):
    if len(tensor.shape) == 3:  # single image: [H, W, C]
        return np.clip(255.0 * tensor.cpu().numpy(), 0, 255).astype(np.uint8)
    else:  # batch of images: [B, H, W, C]
        return [np.clip(255.0 * t.cpu().numpy(), 0, 255).astype(np.uint8) for t in tensor]

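# Inverse of tensor2np: accept a single uint8 array or a list of them and
# return a batched float tensor in 0..1.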
def np2tensor(img_np: Union[np.ndarray, List[np.ndarray]]) -> torch.Tensor:
    if isinstance(img_np, list):
        return torch.cat([np2tensor(img) for img in img_np], dim=0)
    return torch.from_numpy(img_np.astype(np.float32) / 255.0).unsqueeze(0)


class DetectorForNSFW:
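    """Flag NSFW content in a batch of images with a NudeNet ONNX detector.

    Any image whose detections exceed the per-label thresholds is replaced by
    the optional alternative_image (or a plain white placeholder), and the raw
    detections are returned as a JSON string.
    """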

    def __init__(self) -> None:
        self.model = None

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "detect_size": ([640, 320], {"default": 320}),
                "provider": (["CPU", "CUDA", "ROCM"],),
            },
            "optional": {
                "model_name": (comfy_paths.get_filename_list("nsfw"), {"default": None}),
                "alternative_image": ("IMAGE",),
                "buttocks_exposed": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05}),
                "female_breast_exposed": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05}),
                "female_genitalia_exposed": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05}),
                "anus_exposed": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.05}),
                "male_genitalia_exposed": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.05}),
            },
        }

    RETURN_TYPES = ("IMAGE", "STRING")
    RETURN_NAMES = ("filtered_image", "detect_result")
    FUNCTION = "filter_exposure"

    CATEGORY = "utils/filter"

    all_labels = [
        "FEMALE_GENITALIA_COVERED",
        "FACE_FEMALE",
        "BUTTOCKS_EXPOSED",
        "FEMALE_BREAST_EXPOSED",
        "FEMALE_GENITALIA_EXPOSED",
        "MALE_BREAST_EXPOSED",
        "ANUS_EXPOSED",
        "FEET_EXPOSED",
        "BELLY_COVERED",
        "FEET_COVERED",
        "ARMPITS_COVERED",
        "ARMPITS_EXPOSED",
        "FACE_MALE",
        "BELLY_EXPOSED",
        "MALE_GENITALIA_EXPOSED",
        "ANUS_COVERED",
        "FEMALE_BREAST_COVERED",
        "BUTTOCKS_COVERED",
    ]
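    # The lowercase form of each label doubles as the name of its optional
    # threshold input in INPUT_TYPES; labels without a threshold input
    # (e.g. FACE_FEMALE) are never filtered.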

    def filter_exposure(self, image, model_name=None, detect_size=320, provider="CPU", alternative_image=None, **kwargs):
        if self.model is None:
            self.init_model(model_name, detect_size, provider)

        if alternative_image is not None:
            alternative_image = tensor2np(alternative_image)
            # tensor2np returns a list for batched input; use the first frame as the placeholder
            if isinstance(alternative_image, list):
                alternative_image = alternative_image[0]

        images = tensor2np(image)
        if not isinstance(images, list):
            images = [images]

        results, result_info = [], []
        for img in images:
            detect_result = self.model.detect(img)

            logger.debug(f"nudenet detect result: {detect_result}")
            filtered_results = []
            for item in detect_result:
                label = item["class"]
                score = item["score"]
                threshold = kwargs.get(label.lower())
                if threshold is not None and score > threshold:
                    filtered_results.append(item)

            info = {"detect_result": detect_result}
            if len(filtered_results) == 0:
                results.append(img)
                info["nsfw"] = False
            else:
                placeholder_image = alternative_image if alternative_image is not None else np.ones_like(img) * 255
                results.append(placeholder_image)
                info["nsfw"] = True

            result_info.append(info)

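        # Stack the per-image results back into one batched tensor and expose
        # the raw detections as a JSON string on the second output.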
        result_tensor = np2tensor(results)
        result_info = json.dumps(result_info)
        return (result_tensor, result_info)

    def init_model(self, model_name, detect_size, provider):
        model_path = comfy_paths.get_full_path("nsfw", model_name) if model_name else None
        self.model = NudeDetector(model_path=model_path, providers=[provider + "ExecutionProvider"], inference_resolution=detect_size)

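# A minimal round-trip sketch for the tensor helpers (illustrative only: the
# relative NudeDetector import above means this module is not meant to be run
# directly, so this is left as a comment rather than a __main__ block):
#
#     batch = torch.rand(2, 64, 64, 3)             # [B, H, W, C] in 0..1
#     round_trip = np2tensor(tensor2np(batch))     # uint8 and back to float32
#     assert round_trip.shape == batch.shape
#     assert torch.allclose(batch, round_trip, atol=1 / 255.0)  # quantization error only
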

NODE_CLASS_MAPPINGS = {
    # Image
    "DetectorForNSFW": DetectorForNSFW,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    # Image
    "DetectorForNSFW": "Detector for NSFW",
}