@@ -181,40 +181,7 @@ def load_from_file(self, path, filename):
         else:
             return

-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)

         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn


+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data:  # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
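A minimal sketch of how the extracted helper might be exercised on its own, assuming a plain `.pt` file saved with `torch.save`; the `torch.load` call, the `load_embedding_from_pt` name, and the import path are assumptions for illustration, since the diff itself only shows the call site inside `load_from_file`:

```python
import os
import torch

# Assumed import path for the helper added in this diff.
from modules.textual_inversion.textual_inversion import create_embedding_from_data

def load_embedding_from_pt(path):
    # Derive the embedding name from the file name, as load_from_file does.
    name = os.path.splitext(os.path.basename(path))[0]
    data = torch.load(path, map_location="cpu")  # assumed loader for .pt files
    # filename only feeds the error message; passing filepath additionally
    # records the source file on the Embedding and computes its sha256 hash.
    return create_embedding_from_data(data, name, filename=os.path.basename(path), filepath=path)
```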