Skip to content

Commit 000a8ca

Browse files
AUTOMATIC1111 and ruchej
authored and committed
sd3 TI support
1 parent 2ff4cf9 commit 000a8ca

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

modules/models/sd3/other_impls.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from torch import nn
66
from transformers import CLIPTokenizer, T5TokenizerFast
77

8+
from modules import sd_hijack
9+
810

911
#################################################################################################
1012
### Core/Utility
@@ -110,9 +112,9 @@ def forward(self, x, mask=None, intermediate_output=None):
110112

111113

112114
class CLIPEmbeddings(torch.nn.Module):
113-
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
    """Token + position embeddings for the SD3 CLIP text encoder.

    The token embedding is a sd_hijack.TextualInversionEmbeddings rather than a
    plain torch.nn.Embedding, so learned textual-inversion vectors can be
    spliced into the token stream during forward().

    textual_inversion_key selects which sub-vector of a multi-encoder TI
    embedding this encoder consumes ("clip_l" by default; the big-G encoder
    config passes "clip_g" — see the sd3_cond.py config in this commit).
    """
    super().__init__()
    # Drop-in Embedding replacement that applies textual-inversion fixes.
    self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
    self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
117119

118120
def forward(self, input_tokens):
@@ -127,7 +129,7 @@ def __init__(self, config_dict, dtype, device):
127129
intermediate_size = config_dict["intermediate_size"]
128130
intermediate_activation = config_dict["hidden_act"]
129131
super().__init__()
130-
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
132+
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
131133
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
132134
self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
133135

modules/models/sd3/sd3_cond.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __getitem__(self, key):
4040
"intermediate_size": 5120,
4141
"num_attention_heads": 20,
4242
"num_hidden_layers": 32,
43+
"textual_inversion_key": "clip_g",
4344
}
4445

4546
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
@@ -204,7 +205,10 @@ def before_load_weights(self, state_dict):
204205
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
205206

206207
def encode_embedding_init_text(self, init_text, nvpt):
    # Delegate to the combined CLIP model so new textual-inversion embeddings
    # are initialized from real token embeddings; replaces the previous
    # placeholder that returned a constant [[0]] tensor (the "XXX" stub).
    # NOTE(review): assumes self.model_lg exposes the FrozenCLIP-style
    # encode_embedding_init_text(init_text, nvpt) API — not visible in this
    # chunk, confirm against the wrapper class.
    return self.model_lg.encode_embedding_init_text(init_text, nvpt)
209+
210+
def tokenize(self, texts):
    # Delegate tokenization to the same underlying model, so code paths that
    # tokenize prompts (e.g. textual-inversion handling) work for SD3.
    # NOTE(review): self.model_lg is not defined in this chunk — presumably
    # the CLIP text-model wrapper; verify it implements tokenize(texts).
    return self.model_lg.tokenize(texts)
208212

209213
def medvram_modules(self):
210214
return [self.clip_g, self.clip_l, self.t5xxl]

modules/sd_hijack.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,28 @@ def forward(self, input_ids):
359359
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
360360
emb = devices.cond_cast_unet(vec)
361361
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
362-
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
362+
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)
363363

364364
vecs.append(tensor)
365365

366366
return torch.stack(vecs)
367367

368368

369+
class TextualInversionEmbeddings(torch.nn.Embedding):
    """torch.nn.Embedding drop-in that applies textual-inversion fixes.

    Used by the SD3 CLIP encoders so prompts containing textual-inversion
    embedding names get their learned vectors spliced into the token stream.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key: str = 'clip_l', **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)

        # Global hijack object holding the loaded TI embedding database;
        # presumably read by EmbeddingsWithFixes.forward via self.embeddings —
        # TODO confirm (that method's start is outside this view).
        self.embeddings = model_hijack
        # Selects which sub-vector of a multi-encoder TI embedding applies to
        # this encoder; EmbeddingsWithFixes.forward indexes
        # embedding.vec[self.textual_inversion_key] when vec is a dict.
        self.textual_inversion_key = textual_inversion_key

    @property
    def wrapped(self):
        # The plain (un-hijacked) embedding lookup, i.e.
        # torch.nn.Embedding.forward bound to this instance; exposed under the
        # attribute name EmbeddingsWithFixes expects.
        return super().forward

    def forward(self, input_ids):
        # Reuse EmbeddingsWithFixes.forward as an unbound function: this
        # instance duck-types its interface (.wrapped, .embeddings,
        # .textual_inversion_key) without wrapping another module.
        return EmbeddingsWithFixes.forward(self, input_ids)
369384
def add_circular_option_to_conv_2d():
370385
conv2d_constructor = torch.nn.Conv2d.__init__
371386

0 commit comments

Comments
 (0)