Skip to content

Commit 4be7b62

Browse files
Merge pull request #13568 from AUTOMATIC1111/lora_emb_bundle
Add lora-embedding bundle system
2 parents 19f5795 + a8cbe50 commit 4be7b62

File tree

4 files changed

+115
-34
lines changed

4 files changed

+115
-34
lines changed
+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import sys
2+
import copy
3+
import logging
4+
5+
6+
class ColoredFormatter(logging.Formatter):
7+
COLORS = {
8+
"DEBUG": "\033[0;36m", # CYAN
9+
"INFO": "\033[0;32m", # GREEN
10+
"WARNING": "\033[0;33m", # YELLOW
11+
"ERROR": "\033[0;31m", # RED
12+
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
13+
"RESET": "\033[0m", # RESET COLOR
14+
}
15+
16+
def format(self, record):
17+
colored_record = copy.copy(record)
18+
levelname = colored_record.levelname
19+
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
20+
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
21+
return super().format(colored_record)
22+
23+
24+
logger = logging.getLogger("lora")
25+
logger.propagate = False
26+
27+
28+
if not logger.handlers:
29+
handler = logging.StreamHandler(sys.stdout)
30+
handler.setFormatter(
31+
ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
32+
)
33+
logger.addHandler(handler)

extensions-builtin/Lora/network.py

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self, name, network_on_disk: NetworkOnDisk):
9393
self.unet_multiplier = 1.0
9494
self.dyn_dim = None
9595
self.modules = {}
96+
self.bundle_embeddings = {}
9697
self.mtime = None
9798

9899
self.mentioned_name = None

extensions-builtin/Lora/networks.py

+41
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from typing import Union
1717

1818
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
19+
import modules.textual_inversion.textual_inversion as textual_inversion
20+
21+
from lora_logger import logger
1922

2023
module_types = [
2124
network_lora.ModuleTypeLora(),
@@ -151,9 +154,19 @@ def load_network(name, network_on_disk):
151154
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
152155

153156
matched_networks = {}
157+
bundle_embeddings = {}
154158

155159
for key_network, weight in sd.items():
156160
key_network_without_network_parts, network_part = key_network.split(".", 1)
161+
if key_network_without_network_parts == "bundle_emb":
162+
emb_name, vec_name = network_part.split(".", 1)
163+
emb_dict = bundle_embeddings.get(emb_name, {})
164+
if vec_name.split('.')[0] == 'string_to_param':
165+
_, k2 = vec_name.split('.', 1)
166+
emb_dict['string_to_param'] = {k2: weight}
167+
else:
168+
emb_dict[vec_name] = weight
169+
bundle_embeddings[emb_name] = emb_dict
157170

158171
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
159172
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -197,6 +210,14 @@ def load_network(name, network_on_disk):
197210

198211
net.modules[key] = net_module
199212

213+
embeddings = {}
214+
for emb_name, data in bundle_embeddings.items():
215+
embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
216+
embedding.loaded = None
217+
embeddings[emb_name] = embedding
218+
219+
net.bundle_embeddings = embeddings
220+
200221
if keys_failed_to_match:
201222
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
202223

@@ -212,11 +233,15 @@ def purge_networks_from_memory():
212233

213234

214235
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
236+
emb_db = sd_hijack.model_hijack.embedding_db
215237
already_loaded = {}
216238

217239
for net in loaded_networks:
218240
if net.name in names:
219241
already_loaded[net.name] = net
242+
for emb_name, embedding in net.bundle_embeddings.items():
243+
if embedding.loaded:
244+
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
220245

221246
loaded_networks.clear()
222247

@@ -259,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
259284
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
260285
loaded_networks.append(net)
261286

287+
for emb_name, embedding in net.bundle_embeddings.items():
288+
if embedding.loaded is None and emb_name in emb_db.word_embeddings:
289+
logger.warning(
290+
f'Skip bundle embedding: "{emb_name}"'
291+
' as it was already loaded from embeddings folder'
292+
)
293+
continue
294+
295+
embedding.loaded = False
296+
if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
297+
embedding.loaded = True
298+
emb_db.register_embedding(embedding, shared.sd_model)
299+
else:
300+
emb_db.skipped_embeddings[name] = embedding
301+
262302
if failed_to_load_networks:
263303
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
264304

@@ -567,6 +607,7 @@ def infotext_pasted(infotext, params):
567607
available_networks = {}
568608
available_network_aliases = {}
569609
loaded_networks = []
610+
loaded_bundle_embeddings = {}
570611
networks_in_memory = {}
571612
available_network_hash_lookup = {}
572613
forbidden_network_aliases = {}

modules/textual_inversion/textual_inversion.py

+40-34
Original file line numberDiff line numberDiff line change
@@ -181,40 +181,7 @@ def load_from_file(self, path, filename):
181181
else:
182182
return
183183

184-
185-
# textual inversion embeddings
186-
if 'string_to_param' in data:
187-
param_dict = data['string_to_param']
188-
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
189-
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
190-
emb = next(iter(param_dict.items()))[1]
191-
vec = emb.detach().to(devices.device, dtype=torch.float32)
192-
shape = vec.shape[-1]
193-
vectors = vec.shape[0]
194-
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
195-
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
196-
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
197-
vectors = data['clip_g'].shape[0]
198-
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
199-
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
200-
201-
emb = next(iter(data.values()))
202-
if len(emb.shape) == 1:
203-
emb = emb.unsqueeze(0)
204-
vec = emb.detach().to(devices.device, dtype=torch.float32)
205-
shape = vec.shape[-1]
206-
vectors = vec.shape[0]
207-
else:
208-
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
209-
210-
embedding = Embedding(vec, name)
211-
embedding.step = data.get('step', None)
212-
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
213-
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
214-
embedding.vectors = vectors
215-
embedding.shape = shape
216-
embedding.filename = path
217-
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
184+
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
218185

219186
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
220187
self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
313280
return fn
314281

315282

283+
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
284+
if 'string_to_param' in data: # textual inversion embeddings
285+
param_dict = data['string_to_param']
286+
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
287+
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
288+
emb = next(iter(param_dict.items()))[1]
289+
vec = emb.detach().to(devices.device, dtype=torch.float32)
290+
shape = vec.shape[-1]
291+
vectors = vec.shape[0]
292+
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
293+
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
294+
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
295+
vectors = data['clip_g'].shape[0]
296+
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
297+
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
298+
299+
emb = next(iter(data.values()))
300+
if len(emb.shape) == 1:
301+
emb = emb.unsqueeze(0)
302+
vec = emb.detach().to(devices.device, dtype=torch.float32)
303+
shape = vec.shape[-1]
304+
vectors = vec.shape[0]
305+
else:
306+
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
307+
308+
embedding = Embedding(vec, name)
309+
embedding.step = data.get('step', None)
310+
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
311+
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
312+
embedding.vectors = vectors
313+
embedding.shape = shape
314+
315+
if filepath:
316+
embedding.filename = filepath
317+
embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
318+
319+
return embedding
320+
321+
316322
def write_loss(log_directory, filename, step, epoch_len, values):
317323
if shared.opts.training_write_csv_every == 0:
318324
return

0 commit comments

Comments
 (0)