@@ -18,15 +18,12 @@ def load_image(image_path, transform):
18
18
def save_output (tensor , filename ):
19
19
save_image (tensor , filename , normalize = True )
20
20
21
- def process_video (model , video_path , output_path , transform , device , frame_skip = 0 ):
21
+ def process_video (model , video_path , output_dir , transform , device , frame_skip = 0 ):
22
22
ctx = gpu (0 ) if torch .cuda .is_available () else cpu (0 )
23
23
vr = VideoReader (video_path , ctx = ctx )
24
24
25
- fps = vr .get_avg_fps ()
26
- width , height = vr [0 ].shape [1 ], vr [0 ].shape [0 ]
27
-
28
- fourcc = cv2 .VideoWriter_fourcc (* 'mp4v' )
29
- out = cv2 .VideoWriter (output_path , fourcc , fps , (width , height ))
25
+ # Create output directory if it doesn't exist
26
+ # os.makedirs(output_dir, exist_ok=True)
30
27
31
28
# Process reference frame
32
29
reference_frame = vr [0 ].asnumpy ()
@@ -50,13 +47,15 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
50
47
t_c = model .latent_token_encoder (current_frame )
51
48
reconstructed_frame = model .decode_latent_tokens (f_r , t_r , t_c )
52
49
53
- reconstructed_frame = reconstructed_frame .squeeze ().cpu ().numpy ().transpose (1 , 2 , 0 )
54
- reconstructed_frame = (reconstructed_frame * 255 ).astype (np .uint8 )
55
- reconstructed_frame = cv2 .cvtColor (reconstructed_frame , cv2 .COLOR_RGB2BGR )
56
-
57
- out .write (reconstructed_frame )
50
+ # Convert the reconstructed frame to a PIL Image
51
+ reconstructed_frame = reconstructed_frame .squeeze ().cpu ()
52
+ reconstructed_frame = transforms .ToPILImage ()(reconstructed_frame )
53
+
54
+ # Save the reconstructed frame as an image
55
+ output_path = os .path .join (output_dir , f"frame_{ i :04d} .png" )
56
+ reconstructed_frame .save (output_path )
58
57
59
- out . release ( )
58
+ print ( f"Processed { total_frames } frames. Output saved in { output_dir } " )
60
59
61
60
def main ():
62
61
# Load configuration
@@ -83,7 +82,7 @@ def main():
83
82
transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
84
83
])
85
84
86
- process_video (model , config .input .video_path , config .output .path , transform , device , config .input .frame_skip )
85
+ process_video (model , config .input .video_path , config .output .directory , transform , device , config .input .frame_skip )
87
86
88
87
89
88
if __name__ == "__main__" :
0 commit comments