Skip to content

Commit 7ba3923

Browse files
committed
move DDIM/PLMS fix for OSX out of the file with inpainting code.
1 parent bb2e2c8 commit 7ba3923

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

modules/sd_hijack.py

+23
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import ldm.modules.attention
1616
import ldm.modules.diffusionmodules.model
17+
import ldm.models.diffusion.ddim
18+
import ldm.models.diffusion.plms
1719

1820
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
1921
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -406,3 +408,24 @@ def conv2d_constructor_circular(self, *args, **kwargs):
406408

407409

408410
model_hijack = StableDiffusionModelHijack()
411+
412+
413+
def register_buffer(self, name, attr):
414+
"""
415+
Fix register buffer bug for Mac OS.
416+
"""
417+
418+
if type(attr) == torch.Tensor:
419+
if attr.device != devices.device:
420+
421+
# would this not break cuda when torch adds has_mps() to main version?
422+
if getattr(torch, 'has_mps', False):
423+
attr = attr.to(device="mps", dtype=torch.float32)
424+
else:
425+
attr = attr.to(devices.device)
426+
427+
setattr(self, name, attr)
428+
429+
430+
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
431+
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer

modules/sd_hijack_inpainting.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
import modules.devices as devices
32

43
from einops import repeat
54
from omegaconf import ListConfig
@@ -315,20 +314,6 @@ def __init__(
315314
self.masked_image_key = masked_image_key
316315
assert self.masked_image_key in concat_keys
317316
self.concat_keys = concat_keys
318-
319-
320-
# =================================================================================================
321-
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
322-
# =================================================================================================
323-
def register_buffer(self, name, attr):
324-
if type(attr) == torch.Tensor:
325-
optimal_type = devices.get_optimal_device()
326-
if attr.device != optimal_type:
327-
if getattr(torch, 'has_mps', False):
328-
attr = attr.to(device="mps", dtype=torch.float32)
329-
else:
330-
attr = attr.to(optimal_type)
331-
setattr(self, name, attr)
332317

333318

334319
def should_hijack_inpainting(checkpoint_info):
@@ -341,8 +326,7 @@ def do_inpainting_hijack():
341326

342327
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
343328
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
344-
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
345329

346330
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
347331
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
348-
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
332+

0 commit comments

Comments
 (0)