@@ -28,10 +28,15 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
28
28
fourcc = cv2 .VideoWriter_fourcc (* 'mp4v' )
29
29
out = cv2 .VideoWriter (output_path , fourcc , fps , (width , height ))
30
30
31
+ # Load the first video frame as the reference and preprocess it for the model
31
32
reference_frame = vr [0 ].asnumpy ()
32
33
reference_frame = Image .fromarray (reference_frame )
33
34
reference_frame = transform (reference_frame ).unsqueeze (0 ).to (device )
34
35
36
+ with torch .no_grad ():
37
+ f_r = model .dense_feature_encoder (reference_frame )
38
+ t_r = model .latent_token_encoder (reference_frame )
39
+
35
40
total_frames = len (vr )
36
41
for i in range (1 , total_frames ):
37
42
if i % (frame_skip + 1 ) != 0 :
@@ -42,7 +47,8 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
42
47
current_frame = transform (current_frame ).unsqueeze (0 ).to (device )
43
48
44
49
with torch .no_grad ():
45
- reconstructed_frame = model (current_frame , reference_frame )
50
+ t_c = model .latent_token_encoder (current_frame )
51
+ reconstructed_frame = model .decode_latent_tokens (f_r , t_r , t_c )
46
52
47
53
reconstructed_frame = reconstructed_frame .squeeze ().cpu ().numpy ().transpose (1 , 2 , 0 )
48
54
reconstructed_frame = (reconstructed_frame * 255 ).astype (np .uint8 )
@@ -77,16 +83,8 @@ def main():
77
83
transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
78
84
])
79
85
80
- if config .input .video_path :
81
- process_video (model , config .input .video_path , config .output .path , transform , device , config .input .frame_skip )
82
- else :
83
- current_frame = load_image (config .input .current_frame_path , transform ).to (device )
84
- reference_frame = load_image (config .input .reference_frame_path , transform ).to (device )
85
-
86
- with torch .no_grad ():
87
- reconstructed_frame = model (current_frame , reference_frame )
86
+ process_video (model , config .input .video_path , config .output .path , transform , device , config .input .frame_skip )
88
87
89
- save_output (reconstructed_frame , config .output .path )
90
88
91
89
if __name__ == "__main__" :
92
90
main ()
0 commit comments