Skip to content

Commit fdef825

Browse files
brkirchAUTOMATIC1111
authored andcommitted
Add 'interrogate' and 'all' choices to --use-cpu
* Add 'interrogate' and 'all' choices to --use-cpu * Change type for --use-cpu argument to str.lower, so that choices are case insensitive
1 parent fdecb63 commit fdef825

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

modules/devices.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def enable_tf32():
3434

3535
errors.run(enable_tf32, "Enabling TF32")
3636

37-
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
37+
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
3838
dtype = torch.float16
3939
dtype_vae = torch.float16
4040

modules/interrogate.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def load_clip_model(self):
5555

5656
model, preprocess = clip.load(clip_model_name)
5757
model.eval()
58-
model = model.to(shared.device)
58+
model = model.to(devices.device_interrogate)
5959

6060
return model, preprocess
6161

@@ -65,14 +65,14 @@ def load(self):
6565
if not shared.cmd_opts.no_half:
6666
self.blip_model = self.blip_model.half()
6767

68-
self.blip_model = self.blip_model.to(shared.device)
68+
self.blip_model = self.blip_model.to(devices.device_interrogate)
6969

7070
if self.clip_model is None:
7171
self.clip_model, self.clip_preprocess = self.load_clip_model()
7272
if not shared.cmd_opts.no_half:
7373
self.clip_model = self.clip_model.half()
7474

75-
self.clip_model = self.clip_model.to(shared.device)
75+
self.clip_model = self.clip_model.to(devices.device_interrogate)
7676

7777
self.dtype = next(self.clip_model.parameters()).dtype
7878

@@ -99,11 +99,11 @@ def rank(self, image_features, text_array, top_count=1):
9999
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
100100

101101
top_count = min(top_count, len(text_array))
102-
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
102+
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
103103
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
104104
text_features /= text_features.norm(dim=-1, keepdim=True)
105105

106-
similarity = torch.zeros((1, len(text_array))).to(shared.device)
106+
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
107107
for i in range(image_features.shape[0]):
108108
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
109109
similarity /= image_features.shape[0]
@@ -116,7 +116,7 @@ def generate_caption(self, pil_image):
116116
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
117117
transforms.ToTensor(),
118118
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
119-
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
119+
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
120120

121121
with torch.no_grad():
122122
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
@@ -140,7 +140,7 @@ def interrogate(self, pil_image, include_ranks=False):
140140

141141
res = caption
142142

143-
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
143+
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
144144

145145
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
146146
with torch.no_grad(), precision_scope("cuda"):

modules/shared.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
5555
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
5656
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
57-
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
57+
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
5858
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
5959
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
6060
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
@@ -76,8 +76,8 @@
7676

7777
cmd_opts = parser.parse_args()
7878

79-
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
80-
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
79+
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
80+
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
8181

8282
device = devices.device
8383

0 commit comments

Comments
 (0)