Skip to content

Commit 63b70f1

Browse files
authored
Merge PR #525 from Kosinkadink/develop
Backend changes + fixes
2 parents 4f1344e + 297f962 commit 63b70f1

22 files changed

+1397
-561
lines changed

animatediff/adapter_hellomeme.py

+535
Large diffs are not rendered by default.

animatediff/adapter_motionctrl.py

+138
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# main code adapted from https://github.com/TencentARC/MotionCtrl/tree/animatediff
2+
from __future__ import annotations
3+
from torch import nn, Tensor
4+
import torch
5+
6+
from comfy.model_patcher import ModelPatcher
7+
import comfy.model_management
8+
import comfy.ops
9+
import comfy.utils
10+
11+
from .adapter_cameractrl import ResnetBlockCameraCtrl
12+
from .ad_settings import AnimateDiffSettings
13+
from .motion_module_ad import AnimateDiffModel
14+
from .model_injection import apply_mm_settings
15+
from .utils_model import get_motion_model_path
16+
17+
18+
# cmcm (Camera Control)
19+
def inject_motionctrl_cmcm(motion_model: AnimateDiffModel, cmcm_name: str, ad_settings: AnimateDiffSettings=None,
20+
apply_non_ccs=True):
21+
cmcm_path = get_motion_model_path(cmcm_name)
22+
state_dict = comfy.utils.load_torch_file(cmcm_path, safe_load=True)
23+
_remove_module_prefix(state_dict)
24+
# if applicable, apply ad_settings to cmcm to match expected behavior
25+
if ad_settings is not None:
26+
state_dict = apply_mm_settings(model_dict=state_dict, mm_settings=ad_settings)
27+
motion_model.init_motionctrl_cc_projections(state_dict=state_dict)
28+
# seperate out PE keys so can be applied separately in case dims don't match
29+
apply_dict = {}
30+
for key in list(state_dict.keys()):
31+
if "cc_projection" in key:
32+
apply_dict[key] = state_dict[key]
33+
state_dict.pop(key)
34+
pe_dict = {}
35+
for key in list(state_dict.keys()):
36+
if "pos_encoder" in key:
37+
pe_dict[key] = state_dict[key]
38+
state_dict.pop(key)
39+
if apply_non_ccs:
40+
apply_dict.update(state_dict)
41+
for key, value in pe_dict.items():
42+
comfy.utils.set_attr(motion_model, key, value)
43+
_, unexpected = motion_model.load_state_dict(apply_dict, strict=False)
44+
if len(unexpected) > 0:
45+
raise Exception(f"MotionCtrl CMCM model had unexpected keys: {unexpected}")
46+
# make sure model is still has proper dtype and offload device
47+
motion_model.to(comfy.model_management.unet_dtype())
48+
motion_model.to(comfy.model_management.unet_offload_device())
49+
50+
51+
# omcm (Object Control)
52+
def load_motionctrl_omcm(omcm_name: str):
53+
omcm_path = get_motion_model_path(omcm_name)
54+
state_dict = comfy.utils.load_torch_file(omcm_path, safe_load=True)
55+
_remove_module_prefix(state_dict)
56+
57+
if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None:
58+
ops = comfy.ops.disable_weight_init
59+
else:
60+
ops = comfy.ops.manual_cast
61+
adapter = MotionCtrlAdapter(ops=ops)
62+
adapter.load_state_dict(state_dict=state_dict, strict=True)
63+
adapter.to(
64+
device = comfy.model_management.unet_offload_device(),
65+
dtype = comfy.model_management.unet_dtype()
66+
)
67+
omcm_modelpatcher = _create_OMCMModelPatcher(model=adapter,
68+
load_device=comfy.model_management.get_torch_device(),
69+
offload_device=comfy.model_management.unet_offload_device())
70+
return omcm_modelpatcher
71+
72+
73+
def _create_OMCMModelPatcher(model, load_device, offload_device) -> ObjectControlModelPatcher:
74+
patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
75+
return patcher
76+
77+
78+
def _remove_module_prefix(state_dict: dict[str, Tensor]):
79+
for key in list(state_dict.keys()):
80+
# remove 'module.' prefix
81+
if key.startswith('module.'):
82+
new_key = key.replace('module.', '')
83+
state_dict[new_key] = state_dict[key]
84+
state_dict.pop(key)
85+
86+
87+
def convert_cameractrl_poses_to_RT(poses: list[list[float]]):
88+
tensors = []
89+
for pose in poses:
90+
new_tensor = torch.tensor(pose[7:])
91+
new_tensor = new_tensor.unsqueeze(0)
92+
tensors.append(new_tensor)
93+
RT = torch.cat(tensors, dim=0)
94+
return RT
95+
96+
97+
class ObjectControlModelPatcher(ModelPatcher):
98+
'''Class only used for type hints.'''
99+
def __init__(self):
100+
self.model: MotionCtrlAdapter
101+
102+
103+
class MotionCtrlAdapter(nn.Module):
104+
def __init__(self,
105+
downscale_factor=8,
106+
channels=[320, 640, 1280, 1280],
107+
nums_rb=2, cin=128, # 2*8*8
108+
ksize=3, sk=True,
109+
use_conv=False,
110+
ops=comfy.ops.disable_weight_init):
111+
super(MotionCtrlAdapter, self).__init__()
112+
self.downscale_factor = downscale_factor
113+
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
114+
self.channels = channels
115+
self.nums_rb = nums_rb
116+
self.body = []
117+
for i in range(len(channels)):
118+
for j in range(nums_rb):
119+
if (i != 0) and (j == 0):
120+
self.body.append(
121+
ResnetBlockCameraCtrl(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops))
122+
else:
123+
self.body.append(
124+
ResnetBlockCameraCtrl(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops))
125+
self.body = nn.ModuleList(self.body)
126+
self.conv_in = ops.Conv2d(cin, channels[0], 3, 1, 1)
127+
128+
def forward(self, x: Tensor):
129+
x = self.unshuffle(x)
130+
# extract features
131+
features = []
132+
x = self.conv_in(x)
133+
for i in range(len(self.channels)):
134+
for j in range(self.nums_rb):
135+
idx = i * self.nums_rb + j
136+
x = self.body[idx](x)
137+
features.append(x)
138+
return features

0 commit comments

Comments
 (0)