Skip to content

Commit 4c55135

Browse files
author
bghira
committed
state tracker should have a more robust mechanism for loading cache failures from disk
1 parent df233a0 commit 4c55135

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

helpers/training/state_tracker.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from os import environ
22
from pathlib import Path
33
import json
4+
import time, os
45
import logging
56
from helpers.models.all import model_families
67

@@ -82,34 +83,42 @@ def delete_cache_files(
8283
if cache_path.exists():
8384
try:
8485
cache_path.unlink()
86+
logger.warning(f"(rank={os.environ.get('RANK')}) Deleted cache file: {cache_path}")
8587
except:
8688
pass
8789

8890
@classmethod
def _load_from_disk(cls, cache_name, retry_limit: int = 0):
    """Load ``<output_dir>/<cache_name>.json`` and return the parsed JSON.

    The file is always read at least once, even when ``retry_limit`` is 0.
    If it does not exist, the read is attempted up to ``retry_limit`` times
    in total, sleeping one second between attempts.

    Args:
        cache_name: Base name of the cache file, without the ``.json`` suffix.
        retry_limit: Maximum number of attempts made while the file is
            missing. Values below 1 behave like 1.

    Returns:
        The decoded JSON document, or ``None`` when the file never appeared
        or could not be read/parsed. A file containing JSON ``null`` also
        yields ``None``.
    """
    cache_path = Path(cls.args.output_dir) / f"{cache_name}.json"
    retry_count = 0
    while True:
        try:
            # Open directly instead of checking exists() first: the file can
            # be deleted between the check and the open.
            with cache_path.open("r") as f:
                results = json.load(f)
        except FileNotFoundError:
            # Not there (yet); fall through to the retry bookkeeping below.
            pass
        except Exception as e:
            logger.error(
                f"Invalidating cache: error loading {cache_name} from disk. {e}"
            )
            return None
        else:
            # Return whatever was decoded, including a legitimate ``null``,
            # so a successful read can never loop again.
            logger.debug(f"Returning: {type(results)}")
            return results

        retry_count += 1
        if retry_count >= retry_limit:
            break
        logger.debug(
            f"Cache file {cache_name} does not exist. Retry {retry_count}/{retry_limit}."
        )
        time.sleep(1)

    logger.warning(f"No cache file was found: {cache_path}")
    logger.debug(f"Returning: {type(None)}")
    return None
107114

108115
@classmethod
def _save_to_disk(cls, cache_name, data):
    """Serialize ``data`` as JSON to ``<output_dir>/<cache_name>.json``.

    The document is written to a temporary sibling file and then moved over
    the final path with ``os.replace``. A reader that opens the cache file
    while a save is in progress therefore sees either the previous complete
    file or the new complete file, never a partially written one.

    Args:
        cache_name: Base name of the cache file, without the ``.json`` suffix.
        data: JSON-serializable object to store.

    Raises:
        Whatever ``open``, ``json.dump`` or ``os.replace`` raise; the
        temporary file is removed before the exception propagates.
    """
    cache_path = Path(cls.args.output_dir) / f"{cache_name}.json"
    rank = os.environ.get('RANK')
    logger.debug(f"(rank={rank}) Saving {cache_name} to disk: {cache_path}")
    # Include the PID so concurrent writers never share a temporary file.
    temp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.tmp")
    try:
        with temp_path.open("w") as f:
            json.dump(data, f)
        # Same directory, hence same filesystem: the rename is atomic.
        os.replace(temp_path, cache_path)
    except Exception:
        # Do not leave a stray temporary file behind on failure.
        try:
            temp_path.unlink()
        except FileNotFoundError:
            pass
        raise
    logger.debug(f"(rank={rank}) Save complete {cache_name} to disk: {cache_path}")
113122

114123
@classmethod
115124
def set_config_path(cls, config_path: str):
@@ -187,10 +196,19 @@ def set_image_files(cls, raw_file_list: list, data_backend_id: str):
187196

188197
@classmethod
def get_image_files(cls, data_backend_id: str, retry_limit: int = 0):
    """Return the image file list cached for ``data_backend_id``.

    When no in-memory entry exists, or the entry is ``None`` (left behind by
    an earlier failed load), the list is (re)loaded from the
    ``all_image_files_<data_backend_id>`` disk cache and stored in
    ``cls.all_image_files``.

    Args:
        data_backend_id: Identifier of the data backend whose list is wanted.
        retry_limit: Forwarded to ``_load_from_disk``.

    Returns:
        The value stored in ``cls.all_image_files[data_backend_id]``; this is
        ``None`` when the disk load failed.
    """
    rank = os.environ.get('RANK')
    if cls.all_image_files.get(data_backend_id) is None:
        if data_backend_id in cls.all_image_files:
            # A previous load stored None; try the disk again rather than
            # keep handing back the stale failure.
            logger.debug(f"(rank={rank}) Clearing out invalid pre-loaded cache entry for {data_backend_id}")
        logger.debug(f"(rank={rank}) Attempting to load from disk: {data_backend_id}")
        cls.all_image_files[data_backend_id] = cls._load_from_disk(
            "all_image_files_{}".format(data_backend_id), retry_limit=retry_limit
        )
        logger.debug(f"(rank={rank}) Completed load from disk: {data_backend_id}: {type(cls.all_image_files[data_backend_id])}")
    logger.debug(f"(rank={rank}) Returning {type(cls.all_image_files[data_backend_id])} for {data_backend_id}")
    return cls.all_image_files[data_backend_id]
195213

196214
@classmethod
@@ -571,7 +589,7 @@ def load_aspect_resolution_map(cls, dataloader_resolution: float, retry_limit: i
571589
cls.aspect_resolution_map = {dataloader_resolution: {}}
572590

573591
cls.aspect_resolution_map[dataloader_resolution] = (
574-
cls._load_from_disk(f"aspect_resolution_map-{dataloader_resolution}") or {}, , retry_limit=retry_limit
592+
cls._load_from_disk(f"aspect_resolution_map-{dataloader_resolution}", retry_limit=retry_limit) or {}
575593
)
576594
logger.debug(
577595
f"Aspect resolution map: {cls.aspect_resolution_map[dataloader_resolution]}"

0 commit comments

Comments
 (0)