Skip to content

Commit bb2e2c8

Browse files
Merge pull request #4233 from thesved/patch-1
Make DDIM and PLMS work on Mac OS
2 parents b8a2e38 + 86b7fc6 commit bb2e2c8

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

modules/sd_hijack_inpainting.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import modules.devices as devices
23

34
from einops import repeat
45
from omegaconf import ListConfig
@@ -314,6 +315,20 @@ def __init__(
314315
self.masked_image_key = masked_image_key
315316
assert self.masked_image_key in concat_keys
316317
self.concat_keys = concat_keys
318+
319+
320+
# =================================================================================================
321+
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
322+
# =================================================================================================
323+
def register_buffer(self, name, attr):
324+
if type(attr) == torch.Tensor:
325+
optimal_type = devices.get_optimal_device()
326+
if attr.device != optimal_type:
327+
if getattr(torch, 'has_mps', False):
328+
attr = attr.to(device="mps", dtype=torch.float32)
329+
else:
330+
attr = attr.to(optimal_type)
331+
setattr(self, name, attr)
317332

318333

319334
def should_hijack_inpainting(checkpoint_info):
@@ -326,6 +341,8 @@ def do_inpainting_hijack():
326341

327342
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
328343
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
344+
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
329345

330346
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
331-
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
347+
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
348+
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer

0 commit comments

Comments
 (0)