Skip to content

Commit ca2cec9

Browse files
committed
decode_latent_tokens
1 parent 61c1440 commit ca2cec9

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

inference.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
2828
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
2929
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
3030

31+
# Process reference frame
3132
reference_frame = vr[0].asnumpy()
3233
reference_frame = Image.fromarray(reference_frame)
3334
reference_frame = transform(reference_frame).unsqueeze(0).to(device)
3435

36+
with torch.no_grad():
37+
f_r = model.dense_feature_encoder(reference_frame)
38+
t_r = model.latent_token_encoder(reference_frame)
39+
3540
total_frames = len(vr)
3641
for i in range(1, total_frames):
3742
if i % (frame_skip + 1) != 0:
@@ -42,7 +47,8 @@ def process_video(model, video_path, output_path, transform, device, frame_skip=
4247
current_frame = transform(current_frame).unsqueeze(0).to(device)
4348

4449
with torch.no_grad():
45-
reconstructed_frame = model(current_frame, reference_frame)
50+
t_c = model.latent_token_encoder(current_frame)
51+
reconstructed_frame = model.decode_latent_tokens(f_r, t_r, t_c)
4652

4753
reconstructed_frame = reconstructed_frame.squeeze().cpu().numpy().transpose(1, 2, 0)
4854
reconstructed_frame = (reconstructed_frame * 255).astype(np.uint8)
@@ -77,16 +83,8 @@ def main():
7783
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
7884
])
7985

80-
if config.input.video_path:
81-
process_video(model, config.input.video_path, config.output.path, transform, device, config.input.frame_skip)
82-
else:
83-
current_frame = load_image(config.input.current_frame_path, transform).to(device)
84-
reference_frame = load_image(config.input.reference_frame_path, transform).to(device)
85-
86-
with torch.no_grad():
87-
reconstructed_frame = model(current_frame, reference_frame)
86+
process_video(model, config.input.video_path, config.output.path, transform, device, config.input.frame_skip)
8887

89-
save_output(reconstructed_frame, config.output.path)
9088

9189
if __name__ == "__main__":
9290
main()

model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,12 @@ def style_mixing(self, t_c, t_r):
387387
return t_c_mixed, t_r_mixed
388388
return t_c, t_r
389389

390+
def tokens(self, x_current, x_reference):
391+
f_r = self.dense_feature_encoder(x_reference)
392+
t_r = self.latent_token_encoder(x_reference)
393+
t_c = self.latent_token_encoder(x_current)
394+
return f_r,t_r,t_c
395+
390396
def decode_latent_tokens(self,f_r,t_r,t_c):
391397
mix_t_c = t_c
392398
mix_t_r = t_r

0 commit comments

Comments
 (0)