Skip to content

Commit f297a20

Browse files
authored
Merge PR #440 from Kosinkadink/current_device_fix
Remove current_device param from ModelPatcher init
2 parents 106d691 + c1c3bbc commit f297a20

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

animatediff/model_injection.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
class ModelPatcherAndInjector(ModelPatcher):
3737
def __init__(self, m: ModelPatcher):
3838
# replicate ModelPatcher.clone() to initialize ModelPatcherAndInjector
39-
super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
39+
super().__init__(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
4040
self.patches = {}
4141
for k in m.patches:
4242
self.patches[k] = m.patches[k][:]
@@ -439,7 +439,7 @@ def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_pat
439439
class ModelPatcherCLIPHooks(ModelPatcher):
440440
def __init__(self, m: ModelPatcher):
441441
# replicate ModelPatcher.clone() to initialize
442-
super().__init__(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
442+
super().__init__(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
443443
self.patches = {}
444444
for k in m.patches:
445445
self.patches[k] = m.patches[k][:]
@@ -1016,7 +1016,7 @@ def cleanup(self):
10161016

10171017
def clone(self):
10181018
# normal ModelPatcher clone actions
1019-
n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
1019+
n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
10201020
n.patches = {}
10211021
for k in self.patches:
10221022
n.patches[k] = self.patches[k][:]
@@ -1124,7 +1124,7 @@ def get_name_string(self, show_version=False):
11241124

11251125

11261126
def get_vanilla_model_patcher(m: ModelPatcher) -> ModelPatcher:
1127-
model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, m.current_device, weight_inplace_update=m.weight_inplace_update)
1127+
model = ModelPatcher(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update)
11281128
model.patches = {}
11291129
for k in m.patches:
11301130
model.patches[k] = m.patches[k][:]

animatediff/sampling.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,14 @@ def ad_callback(step, x0, x, total_steps):
487487
iter_model = model.model
488488
else:
489489
iter_model = model
490+
current_device = None
491+
if hasattr(model, "current_device"): # backwards compatibility, for now
492+
current_device = model.current_device
493+
else:
494+
current_device = model.model.device
490495
iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler(
491496
iter_model, steps=999, #steps=args[-7],
492-
device=model.current_device, sampler=args[-5],
497+
device=current_device, sampler=args[-5],
493498
scheduler=args[-4], denoise=kwargs.get("denoise", None),
494499
model_options=model.model_options)
495500
del iter_model

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-animatediff-evolved"
33
description = "Improved AnimateDiff integration for ComfyUI."
4-
version = "1.0.11"
4+
version = "1.0.12"
55
license = { file = "LICENSE" }
66
dependencies = []
77

0 commit comments

Comments
 (0)