@@ -5,6 +5,8 @@
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
 
+from modules import sd_hijack
+
 
 #################################################################################################
 ### Core/Utility
@@ -110,9 +112,9 @@ def forward(self, x, mask=None, intermediate_output=None):
 
 
 class CLIPEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
         super().__init__()
-        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
         self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens):
@@ -127,7 +129,7 @@ def __init__(self, config_dict, dtype, device):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
         self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
 
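
In this second hunk, config_dict.get('textual_inversion_key', 'clip_l') keeps older configs, which lack the new field, on the default clip_l table, while a multi-encoder model can route each text encoder to its own slice of an embedding file. A hypothetical config for a secondary encoder might set the field like this (only "textual_inversion_key" is the field the diff reads; the other keys and their values are illustrative, not taken from the diff):

# Hypothetical config dict for an SD3-style secondary text encoder.
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
    # New field: pick the "clip_g" slice of multi-encoder embedding files.
    "textual_inversion_key": "clip_g",
}
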