1
1
import torch
2
+ import modules .devices as devices
2
3
3
4
from einops import repeat
4
5
from omegaconf import ListConfig
@@ -314,6 +315,20 @@ def __init__(
314
315
self .masked_image_key = masked_image_key
315
316
assert self .masked_image_key in concat_keys
316
317
self .concat_keys = concat_keys
318
+
319
+
320
+ # =================================================================================================
321
+ # Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
322
+ # =================================================================================================
323
+ def register_buffer (self , name , attr ):
324
+ if type (attr ) == torch .Tensor :
325
+ optimal_type = devices .get_optimal_device ()
326
+ if attr .device != optimal_type :
327
+ if getattr (torch , 'has_mps' , False ):
328
+ attr = attr .to (device = "mps" , dtype = torch .float32 )
329
+ else :
330
+ attr = attr .to (optimal_type )
331
+ setattr (self , name , attr )
317
332
318
333
319
334
def should_hijack_inpainting (checkpoint_info ):
@@ -326,6 +341,8 @@ def do_inpainting_hijack():
326
341
327
342
ldm .models .diffusion .ddim .DDIMSampler .p_sample_ddim = p_sample_ddim
328
343
ldm .models .diffusion .ddim .DDIMSampler .sample = sample_ddim
344
+ ldm .models .diffusion .ddim .DDIMSampler .register_buffer = register_buffer
329
345
330
346
ldm .models .diffusion .plms .PLMSSampler .p_sample_plms = p_sample_plms
331
- ldm .models .diffusion .plms .PLMSSampler .sample = sample_plms
347
+ ldm .models .diffusion .plms .PLMSSampler .sample = sample_plms
348
+ ldm .models .diffusion .plms .PLMSSampler .register_buffer = register_buffer
0 commit comments