Commit 4d0276f — Motion Predictor
1 parent: 5941433

File tree: 4 files changed (+201 −62 lines)

__init__.py (2 additions, 0 deletions)

@@ -25,6 +25,7 @@
     "IG Path Join": IG_PathJoin,
     "IG Cross Fade Images": IG_CrossFadeImages,
     "IG Interpolate": IG_Interpolate,
+    "IG MotionPredictor": IG_MotionPredictor,
     "IG ZFill": IG_ZFill,
     "IG String List": IG_StringList,
     "IG Float List": IG_FloatList,
@@ -43,6 +44,7 @@
     "IG Path Join": "📂 IG Path Join",
     "IG Cross Fade Images": "🧑🏻‍🧑🏿‍🧒🏽 IG Cross Fade Images",
     "IG Interpolate": "🧑🏻‍🧑🏿‍🧒🏽 IG Interpolate",
+    "IG MotionPredictor": "🏃‍♀️ IG Motion Predictor",
     "IG ZFill": "⌨️ IG ZFill",
     "IG String List": "📃 IG String List",
     "IG Float List": "📃 IG Float List",

motion_predictor/__init__.py (new file, 2 additions)

# __init__.py
from .motion_predictor import MotionPredictor

motion_predictor/motion_predictor.py (new file, 95 additions)

import logging
import math

import torch
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.models import ModelMixin
from einops import rearrange
from torch import nn

logger = logging.getLogger(__name__)

def generate_positional_encodings(length, hidden_dim):
    # Precompute positional encodings once in log space
    position = torch.arange(length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_dim, 2) * -(math.log(10000.0) / hidden_dim))
    pe = torch.zeros(length, hidden_dim)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

class MotionPredictor(ModelMixin, ConfigMixin):
    def __init__(self, token_dim: int = 768, hidden_dim: int = 1024, num_heads: int = 16, num_layers: int = 8, total_frames: int = 16, tokens_per_frame: int = 16):
        super(MotionPredictor, self).__init__()
        self.total_frames = total_frames
        self.tokens_per_frame = tokens_per_frame

        # Initialize layers
        self.input_projection = nn.Linear(token_dim, hidden_dim)    # Project tokens to the hidden dimension
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=num_heads),
            num_layers=num_layers
        )
        self.output_projection = nn.Linear(hidden_dim, token_dim)   # Project back to the token dimension
        # Initialize positional encodings; registered as a non-trainable parameter so they move to the module's device automatically
        self.positional_encodings = generate_positional_encodings(total_frames, hidden_dim)
        self.positional_encodings = nn.Parameter(self.positional_encodings, requires_grad=False)

    def create_attention_mask(self, total_frames, num_tokens):
        # Boolean memory mask: True blocks attention, False allows it.
        # Block everything by default, then allow every position to attend
        # to the first-frame and last-frame tokens (the conditioning frames).
        mask = torch.ones((total_frames * num_tokens, total_frames * num_tokens), dtype=torch.bool, device=self.device)

        # Indices for the first frame tokens and the last frame tokens
        first_frame_indices = torch.arange(0, num_tokens, device=self.device)
        last_frame_indices = torch.arange((total_frames - 1) * num_tokens, total_frames * num_tokens, device=self.device)

        # Allow attention to the first and last frame tokens
        mask[:, first_frame_indices] = False
        mask[:, last_frame_indices] = False

        return mask

    def interpolate_tokens(self, start_tokens: torch.Tensor, end_tokens: torch.Tensor):
        # Linear interpolation in the token space
        interpolation_steps = torch.linspace(0, 1, steps=self.total_frames, device=start_tokens.device, dtype=torch.float16)[:, None, None]
        start_tokens_expanded = start_tokens.unsqueeze(1)  # Shape becomes [batch_size, 1, tokens, token_dim]
        end_tokens_expanded = end_tokens.unsqueeze(1)      # Shape becomes [batch_size, 1, tokens, token_dim]
        interpolated_tokens = start_tokens_expanded * (1 - interpolation_steps) + end_tokens_expanded * interpolation_steps
        return interpolated_tokens  # Shape: [batch_size, total_frames, tokens, token_dim]

    def predict_motion(self, start_tokens: torch.Tensor, end_tokens: torch.Tensor):
        start_tokens = start_tokens.to(self.device)
        end_tokens = end_tokens.to(self.device)

        # Get interpolated tokens
        interpolated_tokens = self.interpolate_tokens(start_tokens, end_tokens).to(self.dtype)

        # Unpack dimensions (frames and tokens are flattened into one sequence below)
        batch_size, total_frames, num_tokens, token_dim = interpolated_tokens.shape
        logger.debug(f"Interpolated tokens {interpolated_tokens.shape}")

        # Apply input projection
        projected_tokens = self.input_projection(interpolated_tokens)

        # Add positional encodings (one per frame, broadcast over the token dimension)
        projected_tokens += self.positional_encodings[:total_frames].unsqueeze(0).unsqueeze(2)

        # Reshape to match the transformer expected input [seq_len, batch_size, hidden_dim]
        projected_tokens = rearrange(projected_tokens, 'b f t d -> (f t) b d')

        # Create an attention mask that only allows attending to the first and last frame
        attention_mask = self.create_attention_mask(total_frames, num_tokens)

        # Transformer predicts the motion along the new sequence dimension
        logger.debug(f"projected_tokens {projected_tokens.shape} attention_mask {attention_mask.shape}")
        motion_tokens = self.transformer(projected_tokens, projected_tokens, memory_mask=attention_mask)

        # Reshape back and apply output projection
        motion_tokens = rearrange(motion_tokens, '(f t) b d -> b f t d', t=num_tokens, f=total_frames)
        motion_tokens = self.output_projection(motion_tokens)

        return motion_tokens

    def forward(self, start_tokens: torch.Tensor, end_tokens: torch.Tensor):
        return self.predict_motion(start_tokens, end_tokens)
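A minimal sketch of exercising the new module on its own (the import path, batch size, and the [16, 768] token shape are assumptions taken from the defaults above, not part of the commit; in the node below the weights come from a checkpoint rather than random initialization):

import torch
from motion_predictor import MotionPredictor  # assumed standalone import path

predictor = MotionPredictor(token_dim=768, total_frames=16, tokens_per_frame=16).eval()
start = torch.randn(1, 16, 768)  # [batch, tokens_per_frame, token_dim] for the first keyframe
end = torch.randn(1, 16, 768)    # same shape for the last keyframe
with torch.inference_mode():
    frames = predictor(start, end)
print(frames.shape)  # torch.Size([1, 16, 16, 768]) -> [batch, total_frames, tokens, token_dim]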

nodes/interpolate.py (102 additions, 62 deletions)

@@ -6,9 +6,12 @@
 import json

 import sys
+from comfy import model_management
 import folder_paths
 from ..common.tree import *
 from ..common.constants import *
+from ..motion_predictor import MotionPredictor
+import comfy.utils

 def crossfade(images_1, images_2, alpha):
     crossfade = (1 - alpha) * images_1 + alpha * images_2
@@ -42,6 +45,99 @@ def exponential_ease_out(t):
     "exponential_ease_out": exponential_ease_out,
 }

+def tensor_to_size(source, dest_size):
+    if isinstance(dest_size, torch.Tensor):
+        dest_size = dest_size.shape[0]
+    source_size = source.shape[0]
+
+    if source_size < dest_size:
+        shape = [dest_size - source_size] + [1] * (source.dim() - 1)
+        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
+    elif source_size > dest_size:
+        source = source[:dest_size]
+
+    return source
+
+class IG_MotionPredictor:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pos_embeds": ("PROJ_EMBEDS",),
+                "neg_embeds": ("PROJ_EMBEDS",),
+                "transitioning_frames": ("INT", {"default": 16, "min": 0, "max": 4096, "step": 1}),
+                "repeat_count": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
+                "mode": (["motion_predict", "interpolate_linear"], ),
+                "motion_predictor_file": (folder_paths.get_filename_list("ipadapter"),),
+            },
+            "optional": {
+                "positive_prompts": ("STRING", {"default": [], "forceInput": True}),
+                "negative_prompts": ("STRING", {"default": [], "forceInput": True}),
+            }
+        }
+
+    RETURN_TYPES = ("PROJ_EMBEDS", "PROJ_EMBEDS", "STRING", "STRING", "INT",)
+    RETURN_NAMES = ("pos_embeds", "neg_embeds", "positive_string", "negative_string", "BATCH_SIZE", )
+    FUNCTION = "main"
+    CATEGORY = TREE_INTERP
+
+    @torch.inference_mode()
+    def main(self, pos_embeds, neg_embeds, transitioning_frames, repeat_count, mode, motion_predictor_file, positive_prompts=None, negative_prompts=None):
+
+        torch_device = model_management.get_torch_device()
+        dtype = model_management.unet_dtype()
+
+        easing_function = easing_functions["linear"]
+
+        print(f"Embed shape {pos_embeds.shape}")
+
+        inbetween_embeds = []
+        # Make sure we have 2 images
+        if len(pos_embeds) > 1:
+            if mode == "motion_predict":
+                motion_predictor = MotionPredictor(total_frames=transitioning_frames).to(torch_device, dtype=dtype)
+                motion_predictor_path = folder_paths.get_full_path("ipadapter", motion_predictor_file)
+                checkpoint = comfy.utils.load_torch_file(motion_predictor_path, safe_load=True)
+                motion_predictor.load_state_dict(checkpoint)
+                for i in range(len(pos_embeds) - 1):
+                    embed1 = pos_embeds[i]
+                    embed2 = pos_embeds[i + 1]
+                    embed1 = embed1.unsqueeze(0)
+                    embed2 = embed2.unsqueeze(0)
+                    inbetween_embeds.extend(motion_predictor(embed1, embed2).squeeze(0))  # collect the predicted frames for this keyframe pair
+            elif mode == "interpolate_linear":
+                # Interpolate embeds
+                for i in range(len(pos_embeds) - 1):
+                    embed1 = pos_embeds[i]
+                    embed2 = pos_embeds[i + 1]
+                    alphas = torch.linspace(0, 1, transitioning_frames)
+                    for alpha in alphas:
+                        eased_alpha = easing_function(alpha.item())
+                        print(f"eased alpha {eased_alpha}")
+                        inbetween_embed = (1 - eased_alpha) * embed1 + eased_alpha * embed2
+                        inbetween_embeds.extend([inbetween_embed])
+
+        inbetween_embeds = [embed for embed in inbetween_embeds for _ in range(repeat_count)]
+        # Find size of batch
+        batch_size = len(inbetween_embeds)
+
+        inbetween_embeds = torch.stack(inbetween_embeds, dim=0)
+
+        # Ensure that cond and uncond have the same batch size
+        neg_embeds = tensor_to_size(neg_embeds, inbetween_embeds.shape[0])
+
+        # Combine and format prompt strings
+        def format_text_prompts(text_prompts):
+            string = ""
+            for i, prompt in enumerate(text_prompts):
+                string += f"\"{i * transitioning_frames * repeat_count - 1}\":\"{prompt}\",\n"
+            return string
+
+        positive_string = format_text_prompts(positive_prompts) if positive_prompts is not None and len(positive_prompts) > 0 else "\"0\":\"\",\n"
+        negative_string = format_text_prompts(negative_prompts) if negative_prompts is not None and len(negative_prompts) > 0 else "\"0\":\"\",\n"
+
+        return (inbetween_embeds, neg_embeds, positive_string, negative_string, batch_size,)
+
 class IG_Interpolate:
     @classmethod
     def INPUT_TYPES(s):
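For reference, a small sketch of how tensor_to_size reconciles the negative embeds with the generated batch (shapes are illustrative assumptions, not taken from the commit): it repeats the last element to pad, or truncates, so cond and uncond end up the same size:

import torch
neg = torch.randn(2, 16, 768)          # two negative embeds
print(tensor_to_size(neg, 33).shape)   # torch.Size([33, 16, 768]) - last embed repeated to pad
print(tensor_to_size(neg, 1).shape)    # torch.Size([1, 16, 768])  - extra embeds truncated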
@@ -68,6 +164,7 @@ def INPUT_TYPES(s):
     FUNCTION = "main"
     CATEGORY = TREE_INTERP

+    @torch.inference_mode()
     def main(self, ipadapter, clip_vision, transitioning_frames, repeat_count, interpolation, buffer, input_images1=None, input_images2=None, input_images3=None, positive_prompts=None, negative_prompts=None):
         if 'ipadapter' in ipadapter:
             ipadapter_model = ipadapter['ipadapter']['model']
@@ -89,12 +186,16 @@ def main(self, ipadapter, clip_vision, transitioning_frames, repeat_count, inter
                 continue
             # Create pos embeds
             img_cond_embeds = clip_vision.encode_image(input_images)
-
+            print(f"penultimate_hidden_states shape {img_cond_embeds.penultimate_hidden_states.shape}")
+            print(f"last_hidden_state shape {img_cond_embeds.last_hidden_state.shape}")
+            print(f"image_embeds shape {img_cond_embeds.image_embeds.shape}")
+
             if is_plus:
                 img_cond_embeds = img_cond_embeds.penultimate_hidden_states
             else:
                 img_cond_embeds = img_cond_embeds.image_embeds
             print(f"Embed shape {img_cond_embeds.shape}")
+
             inbetween_embeds = []
             # Make sure we have 2 images
             if len(img_cond_embeds) > 1:
@@ -187,65 +288,4 @@ def main(self, input_images, transitioning_frames, interpolation, repeat_count):
         # crossfade_images.append(last_image)

         crossfade_images = torch.stack(crossfade_images, dim=0)
-
-        # If not at end, transition image
-
-
-        # for i in range(transitioning_frames):
-        #     alpha = alphas[i]
-        #     image1 = images_1[i + transition_start_index]
-        #     image2 = images_2[i + transition_start_index]
-        #     easing_function = easing_functions.get(interpolation)
-        #     alpha = easing_function(alpha)  # Apply the easing function to the alpha value
-
-        #     crossfade_image = crossfade(image1, image2, alpha)
-        #     crossfade_images.append(crossfade_image)
-
-        # # Convert crossfade_images to tensor
-        # crossfade_images = torch.stack(crossfade_images, dim=0)
-        # # Get the last frame result of the interpolation
-        # last_frame = crossfade_images[-1]
-        # # Calculate the number of remaining frames from images_2
-        # remaining_frames = len(images_2) - (transition_start_index + transitioning_frames)
-        # # Crossfade the remaining frames with the last used alpha value
-        # for i in range(remaining_frames):
-        #     alpha = alphas[-1]
-        #     image1 = images_1[i + transition_start_index + transitioning_frames]
-        #     image2 = images_2[i + transition_start_index + transitioning_frames]
-        #     easing_function = easing_functions.get(interpolation)
-        #     alpha = easing_function(alpha)  # Apply the easing function to the alpha value
-
-        #     crossfade_image = crossfade(image1, image2, alpha)
-        #     crossfade_images = torch.cat([crossfade_images, crossfade_image.unsqueeze(0)], dim=0)
-        # # Append the beginning of images_1
-        # beginning_images_1 = images_1[:transition_start_index]
-        # crossfade_images = torch.cat([beginning_images_1, crossfade_images], dim=0)
         return (crossfade_images, )
-
-
-# class IG_ParseqToWeights:
-
-#     FUNCTION = "main"
-#     CATEGORY = TREE_INTERP
-#     RETURN_TYPES = ("FLOAT",)
-#     RETURN_NAMES = ("weights",)
-
-#     @classmethod
-#     def INPUT_TYPES(s):
-#         return {
-#             "required": {
-#                 "parseq": ("STRING", {"default": '', "multiline": True}),
-#             },
-#         }
-
-#     def main(self, parseq):
-#         # Load the JSON string into a dictionary
-#         data = json.loads(parseq)
-
-#         # Extract the list of frames
-#         frames = data.get('rendered_frames', [])
-
-#         # Extract the prompt_weight_1 from each frame and store it in a list
-#         prompt_weights = [frame['prompt_weight_1'] for frame in frames]
-
-#         return (prompt_weights, )
