Skip to content

Commit 6dc71b7

Browse files
authored
prevent undoing refresh model load params (AUTOMATIC1111#2092)
Ensures `should_refresh_model_loading_params()` is called when needed. Improved code clarity.
1 parent 9efa4ea commit 6dc71b7

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

modules/sysinfo.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ def set_config(req: dict[str, Any], is_api=False, run_callbacks=True, save_confi
237237
main_entry.checkpoint_change(v, save=False, refresh=False)
238238
should_refresh_model_loading_params = True
239239
elif k == 'forge_additional_modules':
240-
should_refresh_model_loading_params = main_entry.modules_change(v, save=False, refresh=False)
240+
modules_changed = main_entry.modules_change(v, save=False, refresh=False)
241+
if modules_changed:
242+
should_refresh_model_loading_params = True
241243
elif k in memory_keys:
242244
mem_key = k[len('forge_'):] # remove 'forge_' prefix
243245
memory_changes[mem_key] = v

modules_forge/main_entry.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ def checkpoint_change(ckpt_name:str, save=True, refresh=True):
250250

251251

252252
def modules_change(module_values:list, save=True, refresh=True) -> bool:
253-
""" module values may be provided as file paths or as simply the module names """
253+
""" module values may be provided as file paths, or just the module names. Returns True if modules changed. """
254254
modules = []
255255
for v in module_values:
256256
module_name = os.path.basename(v) # If the input is a filepath, extract the file name

0 commit comments

Comments
 (0)