|
| 1 | +# main code adapted from https://github.com/TencentARC/MotionCtrl/tree/animatediff |
| 2 | +from __future__ import annotations |
| 3 | +from torch import nn, Tensor |
| 4 | +import torch |
| 5 | + |
| 6 | +from comfy.model_patcher import ModelPatcher |
| 7 | +import comfy.model_management |
| 8 | +import comfy.ops |
| 9 | +import comfy.utils |
| 10 | + |
| 11 | +from .adapter_cameractrl import ResnetBlockCameraCtrl |
| 12 | +from .ad_settings import AnimateDiffSettings |
| 13 | +from .motion_module_ad import AnimateDiffModel |
| 14 | +from .model_injection import apply_mm_settings |
| 15 | +from .utils_model import get_motion_model_path |
| 16 | + |
| 17 | + |
| 18 | +# cmcm (Camera Control) |
| 19 | +def inject_motionctrl_cmcm(motion_model: AnimateDiffModel, cmcm_name: str, ad_settings: AnimateDiffSettings=None, |
| 20 | + apply_non_ccs=True): |
| 21 | + cmcm_path = get_motion_model_path(cmcm_name) |
| 22 | + state_dict = comfy.utils.load_torch_file(cmcm_path, safe_load=True) |
| 23 | + _remove_module_prefix(state_dict) |
| 24 | + # if applicable, apply ad_settings to cmcm to match expected behavior |
| 25 | + if ad_settings is not None: |
| 26 | + state_dict = apply_mm_settings(model_dict=state_dict, mm_settings=ad_settings) |
| 27 | + motion_model.init_motionctrl_cc_projections(state_dict=state_dict) |
| 28 | + # seperate out PE keys so can be applied separately in case dims don't match |
| 29 | + apply_dict = {} |
| 30 | + for key in list(state_dict.keys()): |
| 31 | + if "cc_projection" in key: |
| 32 | + apply_dict[key] = state_dict[key] |
| 33 | + state_dict.pop(key) |
| 34 | + pe_dict = {} |
| 35 | + for key in list(state_dict.keys()): |
| 36 | + if "pos_encoder" in key: |
| 37 | + pe_dict[key] = state_dict[key] |
| 38 | + state_dict.pop(key) |
| 39 | + if apply_non_ccs: |
| 40 | + apply_dict.update(state_dict) |
| 41 | + for key, value in pe_dict.items(): |
| 42 | + comfy.utils.set_attr(motion_model, key, value) |
| 43 | + _, unexpected = motion_model.load_state_dict(apply_dict, strict=False) |
| 44 | + if len(unexpected) > 0: |
| 45 | + raise Exception(f"MotionCtrl CMCM model had unexpected keys: {unexpected}") |
| 46 | + # make sure model is still has proper dtype and offload device |
| 47 | + motion_model.to(comfy.model_management.unet_dtype()) |
| 48 | + motion_model.to(comfy.model_management.unet_offload_device()) |
| 49 | + |
| 50 | + |
| 51 | +# omcm (Object Control) |
| 52 | +def load_motionctrl_omcm(omcm_name: str): |
| 53 | + omcm_path = get_motion_model_path(omcm_name) |
| 54 | + state_dict = comfy.utils.load_torch_file(omcm_path, safe_load=True) |
| 55 | + _remove_module_prefix(state_dict) |
| 56 | + |
| 57 | + if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: |
| 58 | + ops = comfy.ops.disable_weight_init |
| 59 | + else: |
| 60 | + ops = comfy.ops.manual_cast |
| 61 | + adapter = MotionCtrlAdapter(ops=ops) |
| 62 | + adapter.load_state_dict(state_dict=state_dict, strict=True) |
| 63 | + adapter.to( |
| 64 | + device = comfy.model_management.unet_offload_device(), |
| 65 | + dtype = comfy.model_management.unet_dtype() |
| 66 | + ) |
| 67 | + omcm_modelpatcher = _create_OMCMModelPatcher(model=adapter, |
| 68 | + load_device=comfy.model_management.get_torch_device(), |
| 69 | + offload_device=comfy.model_management.unet_offload_device()) |
| 70 | + return omcm_modelpatcher |
| 71 | + |
| 72 | + |
| 73 | +def _create_OMCMModelPatcher(model, load_device, offload_device) -> ObjectControlModelPatcher: |
| 74 | + patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) |
| 75 | + return patcher |
| 76 | + |
| 77 | + |
| 78 | +def _remove_module_prefix(state_dict: dict[str, Tensor]): |
| 79 | + for key in list(state_dict.keys()): |
| 80 | + # remove 'module.' prefix |
| 81 | + if key.startswith('module.'): |
| 82 | + new_key = key.replace('module.', '') |
| 83 | + state_dict[new_key] = state_dict[key] |
| 84 | + state_dict.pop(key) |
| 85 | + |
| 86 | + |
| 87 | +def convert_cameractrl_poses_to_RT(poses: list[list[float]]): |
| 88 | + tensors = [] |
| 89 | + for pose in poses: |
| 90 | + new_tensor = torch.tensor(pose[7:]) |
| 91 | + new_tensor = new_tensor.unsqueeze(0) |
| 92 | + tensors.append(new_tensor) |
| 93 | + RT = torch.cat(tensors, dim=0) |
| 94 | + return RT |
| 95 | + |
| 96 | + |
| 97 | +class ObjectControlModelPatcher(ModelPatcher): |
| 98 | + '''Class only used for type hints.''' |
| 99 | + def __init__(self): |
| 100 | + self.model: MotionCtrlAdapter |
| 101 | + |
| 102 | + |
| 103 | +class MotionCtrlAdapter(nn.Module): |
| 104 | + def __init__(self, |
| 105 | + downscale_factor=8, |
| 106 | + channels=[320, 640, 1280, 1280], |
| 107 | + nums_rb=2, cin=128, # 2*8*8 |
| 108 | + ksize=3, sk=True, |
| 109 | + use_conv=False, |
| 110 | + ops=comfy.ops.disable_weight_init): |
| 111 | + super(MotionCtrlAdapter, self).__init__() |
| 112 | + self.downscale_factor = downscale_factor |
| 113 | + self.unshuffle = nn.PixelUnshuffle(downscale_factor) |
| 114 | + self.channels = channels |
| 115 | + self.nums_rb = nums_rb |
| 116 | + self.body = [] |
| 117 | + for i in range(len(channels)): |
| 118 | + for j in range(nums_rb): |
| 119 | + if (i != 0) and (j == 0): |
| 120 | + self.body.append( |
| 121 | + ResnetBlockCameraCtrl(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops)) |
| 122 | + else: |
| 123 | + self.body.append( |
| 124 | + ResnetBlockCameraCtrl(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv, ops=ops)) |
| 125 | + self.body = nn.ModuleList(self.body) |
| 126 | + self.conv_in = ops.Conv2d(cin, channels[0], 3, 1, 1) |
| 127 | + |
| 128 | + def forward(self, x: Tensor): |
| 129 | + x = self.unshuffle(x) |
| 130 | + # extract features |
| 131 | + features = [] |
| 132 | + x = self.conv_in(x) |
| 133 | + for i in range(len(self.channels)): |
| 134 | + for j in range(self.nums_rb): |
| 135 | + idx = i * self.nums_rb + j |
| 136 | + x = self.body[idx](x) |
| 137 | + features.append(x) |
| 138 | + return features |
0 commit comments