1
1
import torch
2
- import modules .devices as devices
3
2
4
3
from einops import repeat
5
4
from omegaconf import ListConfig
@@ -315,20 +314,6 @@ def __init__(
315
314
self .masked_image_key = masked_image_key
316
315
assert self .masked_image_key in concat_keys
317
316
self .concat_keys = concat_keys
318
-
319
-
320
- # =================================================================================================
321
- # Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
322
- # =================================================================================================
323
- def register_buffer (self , name , attr ):
324
- if type (attr ) == torch .Tensor :
325
- optimal_type = devices .get_optimal_device ()
326
- if attr .device != optimal_type :
327
- if getattr (torch , 'has_mps' , False ):
328
- attr = attr .to (device = "mps" , dtype = torch .float32 )
329
- else :
330
- attr = attr .to (optimal_type )
331
- setattr (self , name , attr )
332
317
333
318
334
319
def should_hijack_inpainting (checkpoint_info ):
@@ -341,8 +326,7 @@ def do_inpainting_hijack():
341
326
342
327
ldm .models .diffusion .ddim .DDIMSampler .p_sample_ddim = p_sample_ddim
343
328
ldm .models .diffusion .ddim .DDIMSampler .sample = sample_ddim
344
- ldm .models .diffusion .ddim .DDIMSampler .register_buffer = register_buffer
345
329
346
330
ldm .models .diffusion .plms .PLMSSampler .p_sample_plms = p_sample_plms
347
331
ldm .models .diffusion .plms .PLMSSampler .sample = sample_plms
348
- ldm . models . diffusion . plms . PLMSSampler . register_buffer = register_buffer
332
+
0 commit comments