1
1
from os import environ
2
2
from pathlib import Path
3
3
import json
4
+ import time , os
4
5
import logging
5
6
from helpers .models .all import model_families
6
7
@@ -82,28 +83,42 @@ def delete_cache_files(
82
83
if cache_path .exists ():
83
84
try :
84
85
cache_path .unlink ()
86
+ logger .warning (f"(rank={ os .environ .get ('RANK' )} ) Deleted cache file: { cache_path } " )
85
87
except :
86
88
pass
87
89
88
90
@classmethod
def _load_from_disk(cls, cache_name, retry_limit: int = 0):
    """Load a JSON cache file from the output directory.

    The file is looked up under ``cls.args.output_dir``. At least one read
    attempt is always made, so the default ``retry_limit=0`` behaves as a
    single, non-waiting lookup. When ``retry_limit`` is greater than one the
    method waits one second between attempts while the file is missing.

    Args:
        cache_name: Base name of the cache file (without extension).
        retry_limit: Maximum number of attempts to find the file. Values
            below one are treated as a single attempt.

    Returns:
        The decoded JSON content, or None when the file is still missing
        after the last attempt or when reading/decoding it fails.
    """
    # NOTE(review): the spacing inside this literal is kept exactly as found
    # in the source under review; confirm the on-disk file name before merging.
    cache_path = Path(cls.args.output_dir) / f"{ cache_name } .json"
    # Always try once: previously retry_limit=0 skipped the read entirely.
    max_attempts = max(1, retry_limit)
    results = None
    for retry_count in range(1, max_attempts + 1):
        try:
            with cache_path.open("r") as f:
                results = json.load(f)
        except FileNotFoundError:
            # The file is not there (yet); wait and retry if attempts remain.
            if retry_count < max_attempts:
                logger.debug(f"Cache file { cache_name } does not exist. Retry { retry_count } /{ retry_limit } .")
                time.sleep(1)
            else:
                logger.warning(f"No cache file was found: { cache_path } ")
        except Exception as e:
            # Any other failure (unreadable file, invalid JSON) invalidates
            # the cache instead of raising to the caller.
            logger.error(
                f"Invalidating cache: error loading { cache_name } from disk. { e } "
            )
            return None
        else:
            # A successful read ends the loop even if the payload is JSON
            # null, which previously caused an endless loop.
            break
    logger.debug(f"Returning: { type (results )} ")
    return results
101
114
102
115
@classmethod
def _save_to_disk(cls, cache_name, data):
    """Write ``data`` as JSON to a cache file in the output directory.

    Args:
        cache_name: Base name of the cache file (without extension).
        data: Object passed to ``json.dump``.
    """
    # NOTE(review): the spacing inside this literal is kept exactly as found
    # in the source under review; confirm the on-disk file name before merging.
    cache_path = Path(cls.args.output_dir).joinpath(f"{ cache_name } .json")
    logger.debug(f"(rank={ os .environ .get ('RANK' )} ) Saving { cache_name } to disk: { cache_path } ")
    with cache_path.open("w") as handle:
        json.dump(data, handle)
    logger.debug(f"(rank={ os .environ .get ('RANK' )} ) Save complete { cache_name } to disk: { cache_path } ")
107
122
108
123
@classmethod
109
124
def set_config_path (cls , config_path : str ):
@@ -180,11 +195,20 @@ def set_image_files(cls, raw_file_list: list, data_backend_id: str):
180
195
return cls .all_image_files [data_backend_id ]
181
196
182
197
@classmethod
def get_image_files(cls, data_backend_id: str, retry_limit: int = 0):
    """Return the image file list stored for ``data_backend_id``.

    An entry already held in ``cls.all_image_files`` is returned as is,
    unless it is None, in which case it is discarded and loaded again via
    ``cls._load_from_disk``.

    Args:
        data_backend_id: Key identifying the data backend.
        retry_limit: Forwarded to ``cls._load_from_disk``.

    Returns:
        Whatever is stored in ``cls.all_image_files`` for the backend after
        the optional reload.
    """
    already_known = data_backend_id in cls.all_image_files
    if already_known and cls.all_image_files[data_backend_id] is None:
        # A None entry is dropped so the branch below goes back to disk.
        logger.debug(f"(rank={ os .environ .get ('RANK' )} ) Clearing out invalid pre-loaded cache entry for { data_backend_id } ")
        cls.all_image_files.pop(data_backend_id)
        already_known = False

    if already_known:
        logger.debug(f"()")
    else:
        logger.debug(f"(rank={ os .environ .get ('RANK' )} ) Attempting to load from disk: { data_backend_id } ")
        cache_name = "all_image_files_{}".format(data_backend_id)
        cls.all_image_files[data_backend_id] = cls._load_from_disk(
            cache_name, retry_limit=retry_limit
        )
        logger.debug(f"(rank={ os .environ .get ('RANK' )} ) Completed load from disk: { data_backend_id } : { type (cls .all_image_files [data_backend_id ])} ")

    logger.debug(f"(rank={ os .environ .get ('RANK' )} ) Returning { type (cls .all_image_files [data_backend_id ])} for { data_backend_id } ")
    return cls.all_image_files[data_backend_id]
189
213
190
214
@classmethod
@@ -330,7 +354,7 @@ def set_vae_cache_files(cls, raw_file_list: list, data_backend_id: str):
330
354
)
331
355
332
356
@classmethod
333
- def get_vae_cache_files (cls : list , data_backend_id : str ):
357
+ def get_vae_cache_files (cls : list , data_backend_id : str , retry_limit : int = 0 ):
334
358
if (
335
359
data_backend_id not in cls .all_vae_cache_files
336
360
or cls .all_vae_cache_files .get (data_backend_id ) is None
@@ -359,10 +383,10 @@ def set_text_cache_files(cls, raw_file_list: list, data_backend_id: str):
359
383
)
360
384
361
385
@classmethod
def get_text_cache_files(cls, data_backend_id: str, retry_limit: int = 0):
    """Return the text cache file list stored for ``data_backend_id``.

    The list is loaded through ``cls._load_from_disk`` when the backend has
    no entry in ``cls.all_text_cache_files`` or when the stored entry is
    None, so a load that failed earlier is attempted again on the next call.

    Args:
        data_backend_id: Key identifying the data backend.
        retry_limit: Forwarded to ``cls._load_from_disk``.

    Returns:
        The entry stored for the backend after the optional load; this is
        None if the load returned None.
    """
    # .get() covers both "missing" and "present but None".
    if cls.all_text_cache_files.get(data_backend_id) is None:
        cls.all_text_cache_files[data_backend_id] = cls._load_from_disk(
            "all_text_cache_files_{}".format(data_backend_id),
            retry_limit=retry_limit,
        )
    return cls.all_text_cache_files[data_backend_id]
368
392
@@ -372,9 +396,9 @@ def set_caption_files(cls, caption_files):
372
396
cls ._save_to_disk ("all_caption_files" , cls .all_caption_files )
373
397
374
398
@classmethod
def get_caption_files(cls, retry_limit: int = 0):
    """Return the caption file mapping, loading it from disk when empty.

    Args:
        retry_limit: Forwarded to ``cls._load_from_disk``.

    Returns:
        The value of ``cls.all_caption_files`` after the optional load.
    """
    cached = cls.all_caption_files
    if cached:
        return cached
    # Nothing usable in memory: ask the disk cache and remember the answer.
    cls.all_caption_files = cls._load_from_disk(
        "all_caption_files", retry_limit=retry_limit
    )
    return cls.all_caption_files
379
403
380
404
@classmethod
@@ -560,12 +584,12 @@ def save_aspect_resolution_map(cls, dataloader_resolution: float):
560
584
)
561
585
562
586
@classmethod
563
- def load_aspect_resolution_map (cls , dataloader_resolution : float ):
587
+ def load_aspect_resolution_map (cls , dataloader_resolution : float , retry_limit : int = 0 ):
564
588
if dataloader_resolution not in cls .aspect_resolution_map :
565
589
cls .aspect_resolution_map = {dataloader_resolution : {}}
566
590
567
591
cls .aspect_resolution_map [dataloader_resolution ] = (
568
- cls ._load_from_disk (f"aspect_resolution_map-{ dataloader_resolution } " ) or {}
592
+ cls ._load_from_disk (f"aspect_resolution_map-{ dataloader_resolution } " , retry_limit = retry_limit ) or {}
569
593
)
570
594
logger .debug (
571
595
f"Aspect resolution map: { cls .aspect_resolution_map [dataloader_resolution ]} "
0 commit comments