Skip to content

Commit 7940ef1

Browse files
committed
ok
1 parent 3c71c75 commit 7940ef1

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

model.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -478,16 +478,9 @@ def process_tokens(self, t_c, t_r):
478478
return m_c, m_r
479479

480480

481-
class IMFEncoder(nn.Module):
482-
def __init__(self, model):
483-
super(IMFEncoder, self).__init__()
484-
self.model = model
485481

486-
def forward(self, x_current, x_reference):
487-
f_r = self.model.dense_feature_encoder(x_reference)
488-
t_r = self.model.latent_token_encoder(x_reference)
489-
t_c = self.model.latent_token_encoder(x_current)
490-
return f_r, t_r, t_c
482+
483+
491484

492485
class MappingNetwork(nn.Module):
493486
def __init__(self, latent_dim, w_dim, depth):

onnxconv.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@
55
from PIL import Image
66
from torchvision import transforms
77

8+
9+
class IMFDecoder(nn.Module):
10+
def __init__(self, model):
11+
super(IMFDecoder, self).__init__()
12+
self.model = model
13+
14+
def decode_latent_tokens(self,f_r,t_r,t_c):
15+
return self.model.decode_latent_tokens(f_r,t_r,t_c)
16+
17+
18+
819
# Define the IMFEncoder class
920
class IMFEncoder(nn.Module):
1021
def __init__(self, model):

0 commit comments

Comments
 (0)