import json

import sys
+from comfy import model_management
import folder_paths
from ..common.tree import *
from ..common.constants import *
+from ..motion_predictor import MotionPredictor
+import comfy.utils

def crossfade(images_1, images_2, alpha):
    crossfade = (1 - alpha) * images_1 + alpha * images_2
@@ -42,6 +45,99 @@ def exponential_ease_out(t):
    "exponential_ease_out": exponential_ease_out,
}

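+# Repeat the last element (or truncate) so `source` matches the target batch size along dim 0.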
+def tensor_to_size(source, dest_size):
+    if isinstance(dest_size, torch.Tensor):
+        dest_size = dest_size.shape[0]
+    source_size = source.shape[0]
+
+    if source_size < dest_size:
+        shape = [dest_size - source_size] + [1] * (source.dim() - 1)
+        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
+    elif source_size > dest_size:
+        source = source[:dest_size]
+
+    return source
+
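+# Node that turns a sequence of projection embeds (PROJ_EMBEDS) into per-frame in-between embeds,
+# either with a MotionPredictor checkpoint or by plain linear interpolation.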
+class IG_MotionPredictor:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "pos_embeds": ("PROJ_EMBEDS",),
+                "neg_embeds": ("PROJ_EMBEDS",),
+                "transitioning_frames": ("INT", {"default": 16, "min": 0, "max": 4096, "step": 1}),
+                "repeat_count": ("INT", {"default": 1, "min": 1, "max": 4096, "step": 1}),
+                "mode": (["motion_predict", "interpolate_linear"],),
+                "motion_predictor_file": (folder_paths.get_filename_list("ipadapter"),),
+            },
+            "optional": {
+                "positive_prompts": ("STRING", {"default": [], "forceInput": True}),
+                "negative_prompts": ("STRING", {"default": [], "forceInput": True}),
+            }
+        }
+
+    RETURN_TYPES = ("PROJ_EMBEDS", "PROJ_EMBEDS", "STRING", "STRING", "INT",)
+    RETURN_NAMES = ("pos_embeds", "neg_embeds", "positive_string", "negative_string", "BATCH_SIZE",)
+    FUNCTION = "main"
+    CATEGORY = TREE_INTERP
+
+    @torch.inference_mode()
+    def main(self, pos_embeds, neg_embeds, transitioning_frames, repeat_count, mode, motion_predictor_file, positive_prompts=None, negative_prompts=None):
+
+        torch_device = model_management.get_torch_device()
+        dtype = model_management.unet_dtype()
+
+        easing_function = easing_functions["linear"]
+
+        print(f"Embed shape {pos_embeds.shape}")
+
+        inbetween_embeds = []
+        # Make sure we have at least two embeds to interpolate between
+        if len(pos_embeds) > 1:
+            if mode == "motion_predict":
+                motion_predictor = MotionPredictor(total_frames=transitioning_frames).to(torch_device, dtype=dtype)
+                motion_predictor_path = folder_paths.get_full_path("ipadapter", motion_predictor_file)
+                checkpoint = comfy.utils.load_torch_file(motion_predictor_path, safe_load=True)
+                motion_predictor.load_state_dict(checkpoint)
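+                # Predict transitioning_frames in-between embeds for each consecutive pair of embeds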
+                for i in range(len(pos_embeds) - 1):
+                    embed1 = pos_embeds[i]
+                    embed2 = pos_embeds[i + 1]
+                    embed1 = embed1.unsqueeze(0)
+                    embed2 = embed2.unsqueeze(0)
+                    # Collect predicted frames from every pair instead of overwriting the list each iteration
+                    inbetween_embeds.extend(motion_predictor(embed1, embed2).squeeze(0))
+            elif mode == "interpolate_linear":
+                # Interpolate embeds
+                for i in range(len(pos_embeds) - 1):
+                    embed1 = pos_embeds[i]
+                    embed2 = pos_embeds[i + 1]
+                    alphas = torch.linspace(0, 1, transitioning_frames)
+                    for alpha in alphas:
+                        eased_alpha = easing_function(alpha.item())
+                        print(f"eased alpha {eased_alpha}")
+                        inbetween_embed = (1 - eased_alpha) * embed1 + eased_alpha * embed2
+                        inbetween_embeds.append(inbetween_embed)
+
+        inbetween_embeds = [embed for embed in inbetween_embeds for _ in range(repeat_count)]
+        # Find size of batch
+        batch_size = len(inbetween_embeds)
+
+        inbetween_embeds = torch.stack(inbetween_embeds, dim=0)
+
+        # Ensure that cond and uncond have the same batch size
+        neg_embeds = tensor_to_size(neg_embeds, inbetween_embeds.shape[0])
+
+        # Combine and format prompt strings
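+        # Each prompt becomes a keyframe entry of the form "frame_index":"prompt",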
+        def format_text_prompts(text_prompts):
+            string = ""
+            for i, prompt in enumerate(text_prompts):
+                string += f"\"{i * transitioning_frames * repeat_count - 1}\":\"{prompt}\",\n"
+            return string
+
+        positive_string = format_text_prompts(positive_prompts) if positive_prompts is not None and len(positive_prompts) > 0 else "\"0\":\"\",\n"
+        negative_string = format_text_prompts(negative_prompts) if negative_prompts is not None and len(negative_prompts) > 0 else "\"0\":\"\",\n"
+
+        return (inbetween_embeds, neg_embeds, positive_string, negative_string, batch_size,)
+
class IG_Interpolate:
    @classmethod
    def INPUT_TYPES(s):
@@ -68,6 +164,7 @@ def INPUT_TYPES(s):
    FUNCTION = "main"
    CATEGORY = TREE_INTERP

+    @torch.inference_mode()
    def main(self, ipadapter, clip_vision, transitioning_frames, repeat_count, interpolation, buffer, input_images1=None, input_images2=None, input_images3=None, positive_prompts=None, negative_prompts=None):
        if 'ipadapter' in ipadapter:
            ipadapter_model = ipadapter['ipadapter']['model']
@@ -89,12 +186,16 @@ def main(self, ipadapter, clip_vision, transitioning_frames, repeat_count, inter
                continue
            # Create pos embeds
            img_cond_embeds = clip_vision.encode_image(input_images)
-
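+            # Debug: log the shapes of each CLIP vision output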
+            print(f"penultimate_hidden_states shape {img_cond_embeds.penultimate_hidden_states.shape}")
+            print(f"last_hidden_state shape {img_cond_embeds.last_hidden_state.shape}")
+            print(f"image_embeds shape {img_cond_embeds.image_embeds.shape}")
+
            if is_plus:
                img_cond_embeds = img_cond_embeds.penultimate_hidden_states
            else:
                img_cond_embeds = img_cond_embeds.image_embeds
            print(f"Embed shape {img_cond_embeds.shape}")
+
            inbetween_embeds = []
            # Make sure we have 2 images
            if len(img_cond_embeds) > 1:
@@ -187,65 +288,4 @@ def main(self, input_images, transitioning_frames, interpolation, repeat_count):
        # crossfade_images.append(last_image)

        crossfade_images = torch.stack(crossfade_images, dim=0)
-
-        # If not at end, transition image
-
-
-        # for i in range(transitioning_frames):
-        #     alpha = alphas[i]
-        #     image1 = images_1[i + transition_start_index]
-        #     image2 = images_2[i + transition_start_index]
-        #     easing_function = easing_functions.get(interpolation)
-        #     alpha = easing_function(alpha)  # Apply the easing function to the alpha value
-
-        #     crossfade_image = crossfade(image1, image2, alpha)
-        #     crossfade_images.append(crossfade_image)
-
-        # # Convert crossfade_images to tensor
-        # crossfade_images = torch.stack(crossfade_images, dim=0)
-        # # Get the last frame result of the interpolation
-        # last_frame = crossfade_images[-1]
-        # # Calculate the number of remaining frames from images_2
-        # remaining_frames = len(images_2) - (transition_start_index + transitioning_frames)
-        # # Crossfade the remaining frames with the last used alpha value
-        # for i in range(remaining_frames):
-        #     alpha = alphas[-1]
-        #     image1 = images_1[i + transition_start_index + transitioning_frames]
-        #     image2 = images_2[i + transition_start_index + transitioning_frames]
-        #     easing_function = easing_functions.get(interpolation)
-        #     alpha = easing_function(alpha)  # Apply the easing function to the alpha value
-
-        #     crossfade_image = crossfade(image1, image2, alpha)
-        #     crossfade_images = torch.cat([crossfade_images, crossfade_image.unsqueeze(0)], dim=0)
-        # # Append the beginning of images_1
-        # beginning_images_1 = images_1[:transition_start_index]
-        # crossfade_images = torch.cat([beginning_images_1, crossfade_images], dim=0)
        return (crossfade_images, )
-
-
-# class IG_ParseqToWeights:
-
-#     FUNCTION = "main"
-#     CATEGORY = TREE_INTERP
-#     RETURN_TYPES = ("FLOAT",)
-#     RETURN_NAMES = ("weights",)
-
-#     @classmethod
-#     def INPUT_TYPES(s):
-#         return {
-#             "required": {
-#                 "parseq": ("STRING", {"default": '', "multiline": True}),
-#             },
-#         }
-
-#     def main(self, parseq):
-#         # Load the JSON string into a dictionary
-#         data = json.loads(parseq)
-
-#         # Extract the list of frames
-#         frames = data.get('rendered_frames', [])
-
-#         # Extract the prompt_weight_1 from each frame and store it in a list
-#         prompt_weights = [frame['prompt_weight_1'] for frame in frames]
-
-#         return (prompt_weights, )