Skip to content

Commit 6647a41

Browse files
authored
Merge pull request #1414 from bghira/debug/cache-file-load-failure
fix NoneType error when all images list hits race condition during file save/load
2 parents 22ca734 + 4c55135 commit 6647a41

File tree

2 files changed

+45
-21
lines changed

2 files changed

+45
-21
lines changed

helpers/data_backend/factory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ def configure_multi_databackend(
13171317
init_backend["vaecache"].discover_all_files()
13181318
accelerator.wait_for_everyone()
13191319
all_image_files = StateTracker.get_image_files(
1320-
data_backend_id=init_backend["id"]
1320+
data_backend_id=init_backend["id"], retry_limit=3 # some filesystems maybe take longer to make it available.
13211321
)
13221322
init_backend["vaecache"].build_vae_cache_filename_map(
13231323
all_image_files=all_image_files

helpers/training/state_tracker.py

+44-20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from os import environ
22
from pathlib import Path
33
import json
4+
import time, os
45
import logging
56
from helpers.models.all import model_families
67

@@ -82,28 +83,42 @@ def delete_cache_files(
8283
if cache_path.exists():
8384
try:
8485
cache_path.unlink()
86+
logger.warning(f"(rank={os.environ.get('RANK')}) Deleted cache file: {cache_path}")
8587
except:
8688
pass
8789

8890
@classmethod
def _load_from_disk(cls, cache_name, retry_limit: int = 0):
    """Load a JSON cache file from the configured output directory.

    The file read is ``<cls.args.output_dir>/<cache_name>.json``.

    Args:
        cache_name: Base name of the cache file, without the ``.json`` suffix.
        retry_limit: Total number of times to look for the file, sleeping one
            second between looks while it is missing. Values below 1 still
            perform a single attempt, so the default of 0 reads the file once
            without waiting.

    Returns:
        The decoded JSON content, or ``None`` when the file never appeared
        or could not be opened/parsed.
    """
    cache_path = Path(cls.args.output_dir) / f"{cache_name}.json"
    # Always look at least once: gating the loop purely on retry_limit would
    # make the default (0) skip the read entirely and return None every time.
    max_attempts = max(1, retry_limit)
    results = None
    for attempt in range(1, max_attempts + 1):
        if cache_path.exists():
            try:
                with cache_path.open("r") as f:
                    results = json.load(f)
            except Exception as e:
                logger.error(
                    f"Invalidating cache: error loading {cache_name} from disk. {e}"
                )
                return None
            # A successful parse ends the search even when the payload is a
            # JSON null; looping on "results is None" would never terminate.
            break
        if attempt < max_attempts:
            logger.debug(f"Cache file {cache_name} does not exist. Retry {attempt}/{retry_limit}.")
            time.sleep(1)
        elif retry_limit > 0:
            # Only warn when the caller explicitly asked us to wait for the file.
            logger.warning(f"No cache file was found: {cache_path}")
    logger.debug(f"Returning: {type(results)}")
    return results
101114

102115
@classmethod
def _save_to_disk(cls, cache_name, data):
    """Serialise ``data`` as JSON to ``<cls.args.output_dir>/<cache_name>.json``.

    The JSON is written to a sibling temporary file first and then moved over
    the destination with ``os.replace``. A failure while serialising therefore
    leaves any previously saved cache file untouched, and (on filesystems where
    the rename is atomic) a concurrent reader does not open a half-written file.

    Args:
        cache_name: Base name of the cache file, without the ``.json`` suffix.
        data: JSON-serialisable object to store.

    Raises:
        Whatever ``open``, ``json.dump`` or ``os.replace`` raise; the temporary
        file is removed before the exception propagates.
    """
    cache_path = Path(cls.args.output_dir) / f"{cache_name}.json"
    rank = os.environ.get('RANK')
    logger.debug(f"(rank={rank}) Saving {cache_name} to disk: {cache_path}")
    # Rank and pid keep the temporary name distinct between writer processes.
    tmp_path = cache_path.with_name(f"{cache_path.name}.{rank}.{os.getpid()}.tmp")
    try:
        with tmp_path.open("w") as f:
            json.dump(data, f)
        os.replace(tmp_path, cache_path)
    except Exception:
        # Best-effort cleanup so a failed save does not leave stray temp files.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    logger.debug(f"(rank={rank}) Save complete {cache_name} to disk: {cache_path}")
107122

108123
@classmethod
109124
def set_config_path(cls, config_path: str):
@@ -180,11 +195,20 @@ def set_image_files(cls, raw_file_list: list, data_backend_id: str):
180195
return cls.all_image_files[data_backend_id]
181196

182197
@classmethod
def get_image_files(cls, data_backend_id: str, retry_limit: int = 0):
    """Return the image file list cached for ``data_backend_id``.

    When no entry is held in ``cls.all_image_files``, or the held entry is
    ``None`` (an earlier load produced nothing), the list is (re)loaded from
    the ``all_image_files_<data_backend_id>`` cache file via
    ``cls._load_from_disk``.

    Args:
        data_backend_id: Identifier of the data backend whose list is wanted.
        retry_limit: Forwarded to ``cls._load_from_disk``.

    Returns:
        The cached value for the backend; ``None`` if loading from disk
        produced nothing.
    """
    rank = os.environ.get('RANK')
    # NOTE(review): assumes cls.all_image_files is a plain dict — confirm.
    needs_load = data_backend_id not in cls.all_image_files
    if not needs_load and cls.all_image_files[data_backend_id] is None:
        # A previous load stored None; treat it as missing and try disk again.
        logger.debug(f"(rank={rank}) Clearing out invalid pre-loaded cache entry for {data_backend_id}")
        needs_load = True
    if needs_load:
        logger.debug(f"(rank={rank}) Attempting to load from disk: {data_backend_id}")
        cls.all_image_files[data_backend_id] = cls._load_from_disk(
            "all_image_files_{}".format(data_backend_id), retry_limit=retry_limit
        )
        logger.debug(f"(rank={rank}) Completed load from disk: {data_backend_id}: {type(cls.all_image_files[data_backend_id])}")
    logger.debug(f"(rank={rank}) Returning {type(cls.all_image_files[data_backend_id])} for {data_backend_id}")
    return cls.all_image_files[data_backend_id]
189213

190214
@classmethod
@@ -330,7 +354,7 @@ def set_vae_cache_files(cls, raw_file_list: list, data_backend_id: str):
330354
)
331355

332356
@classmethod
333-
def get_vae_cache_files(cls: list, data_backend_id: str):
357+
def get_vae_cache_files(cls: list, data_backend_id: str, retry_limit: int = 0):
334358
if (
335359
data_backend_id not in cls.all_vae_cache_files
336360
or cls.all_vae_cache_files.get(data_backend_id) is None
@@ -359,10 +383,10 @@ def set_text_cache_files(cls, raw_file_list: list, data_backend_id: str):
359383
)
360384

361385
@classmethod
def get_text_cache_files(cls, data_backend_id: str, retry_limit: int = 0):
    """Return the text cache file list held for ``data_backend_id``.

    The list is loaded from the ``all_text_cache_files_<data_backend_id>``
    cache file via ``cls._load_from_disk`` when the backend has no entry in
    ``cls.all_text_cache_files`` or when the stored entry is ``None``, so a
    load that failed earlier is retried instead of being returned forever.

    Args:
        data_backend_id: Identifier of the data backend whose list is wanted.
        retry_limit: Forwarded to ``cls._load_from_disk``.

    Returns:
        The cached value for the backend; ``None`` if loading from disk
        produced nothing.
    """
    if (
        data_backend_id not in cls.all_text_cache_files
        or cls.all_text_cache_files[data_backend_id] is None
    ):
        cls.all_text_cache_files[data_backend_id] = cls._load_from_disk(
            "all_text_cache_files_{}".format(data_backend_id), retry_limit=retry_limit
        )
    return cls.all_text_cache_files[data_backend_id]
368392

@@ -372,9 +396,9 @@ def set_caption_files(cls, caption_files):
372396
cls._save_to_disk("all_caption_files", cls.all_caption_files)
373397

374398
@classmethod
def get_caption_files(cls, retry_limit: int = 0):
    """Return the caption file collection, loading it from disk when unset.

    A truthy ``cls.all_caption_files`` is returned as-is. Otherwise the
    ``all_caption_files`` cache file is read through ``cls._load_from_disk``
    (with ``retry_limit`` forwarded), stored on the class, and returned.
    """
    if cls.all_caption_files:
        return cls.all_caption_files
    loaded = cls._load_from_disk("all_caption_files", retry_limit=retry_limit)
    cls.all_caption_files = loaded
    return cls.all_caption_files
379403

380404
@classmethod
@@ -560,12 +584,12 @@ def save_aspect_resolution_map(cls, dataloader_resolution: float):
560584
)
561585

562586
@classmethod
563-
def load_aspect_resolution_map(cls, dataloader_resolution: float):
587+
def load_aspect_resolution_map(cls, dataloader_resolution: float, retry_limit: int = 0):
564588
if dataloader_resolution not in cls.aspect_resolution_map:
565589
cls.aspect_resolution_map = {dataloader_resolution: {}}
566590

567591
cls.aspect_resolution_map[dataloader_resolution] = (
568-
cls._load_from_disk(f"aspect_resolution_map-{dataloader_resolution}") or {}
592+
cls._load_from_disk(f"aspect_resolution_map-{dataloader_resolution}", retry_limit=retry_limit) or {}
569593
)
570594
logger.debug(
571595
f"Aspect resolution map: {cls.aspect_resolution_map[dataloader_resolution]}"

0 commit comments

Comments
 (0)