
Commit a8cbe50

remove duplicated code

1 parent: 891ccb7

File tree

2 files changed: +42, -63 lines


extensions-builtin/Lora/networks.py: +2, -29
@@ -15,7 +15,7 @@
 from typing import Union

 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
-from modules.textual_inversion.textual_inversion import Embedding
+import modules.textual_inversion.textual_inversion as textual_inversion

 from lora_logger import logger

@@ -210,34 +210,7 @@ def load_network(name, network_on_disk):

     embeddings = {}
     for emb_name, data in bundle_embeddings.items():
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, emb_name)
-        embedding.vectors = vectors
-        embedding.shape = shape
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
         embedding.loaded = None
         embeddings[emb_name] = embedding

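Note: the branch deleted above is the same format-sniffing logic that this commit moves into modules/textual_inversion/textual_inversion.py (second file below), so the LoRA loader now only forwards each bundled payload to the shared helper. As a rough, self-contained sketch of that call: the SD1-style 'string_to_param' payload, the embedding name and the fake filename are invented for illustration, and a fully initialized webui environment is assumed.

import torch
from modules.textual_inversion import textual_inversion

# Fabricated SD1-style payload: one term holding two 768-dim vectors,
# mimicking an embedding bundled inside a LoRA file.
data = {"string_to_param": {"*": torch.zeros(2, 768)}}

# The filename argument is only used for error messages here, since the
# "file" is an entry inside the LoRA state dict rather than a real path.
embedding = textual_inversion.create_embedding_from_data(
    data, "bundled-emb", filename="some_lora.safetensors/bundled-emb"
)
print(embedding.vectors, embedding.shape)  # 2 768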
modules/textual_inversion/textual_inversion.py: +40, -34
@@ -181,40 +181,7 @@ def load_from_file(self, path, filename):
         else:
             return

-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)

         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn


+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data:  # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return

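Besides the 'string_to_param' case, the extracted helper also recognizes SDXL embeddings (a dict with clip_g and clip_l tensors) and single-tensor diffuser concepts, and it only records a filename and hash when a real filepath is passed, which is why the LoRA call above supplies only a synthetic filename for error messages. A small sketch of the SDXL branch, with dummy tensor widths (1280/768) and an invented embedding name chosen purely for illustration:

import torch
from modules.textual_inversion import textual_inversion

# Fabricated SDXL-style payload: two vectors for each text encoder.
sdxl_data = {
    "clip_g": torch.zeros(2, 1280),
    "clip_l": torch.zeros(2, 768),
}

embedding = textual_inversion.create_embedding_from_data(sdxl_data, "my-sdxl-emb")
print(embedding.vectors)  # 2
print(embedding.shape)    # 2048, i.e. clip_g width + clip_l width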
0 commit comments
