Skip to content

Commit 34fb0df

Browse files
committed
Gradio theme cache
1 parent c6278c1 commit 34fb0df

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

modules/shared.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def list_samplers():
550550
options_templates.update(options_section(('ui', "User interface"), {
551551
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
552552
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
553-
"return_grid": OptionInfo(True, "Show grid in results for web"),
553+
"re_download_theme": OptionInfo(False, "Re-download the selected Gradio theme"),
554554
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
555555
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
556556
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -846,6 +846,38 @@ def sd_model(self, value):
846846

847847
progress_print_out = sys.stdout
848848

849+
850+
def from_hub_with_cache_wrapper(func):
851+
def wrapper(*args, **kwargs):
852+
import pickle
853+
repo_name = ''
854+
if 'key_name' in kwargs:
855+
repo_name = kwargs['repo_name']
856+
elif args and len(args) >= 1:
857+
repo_name = args[0]
858+
if repo_name:
859+
theme_cache_path = os.path.join(script_path, 'tmp', 'gradio_themes', repo_name.replace('/', '_'))
860+
# if theme is cached use cache and same gradio version
861+
if not opts.re_download_theme and os.path.exists(theme_cache_path):
862+
with open(theme_cache_path, 'rb') as cached_theme:
863+
theme_cache = pickle.load(cached_theme)
864+
if gr.__version__ == theme_cache.get('gradio_version'):
865+
return theme_cache.get('theme')
866+
# get theme from hub
867+
result = func(*args, **kwargs)
868+
# save theme to cache
869+
os.makedirs(os.path.dirname(theme_cache_path), exist_ok=True)
870+
with open(theme_cache_path, 'wb') as cached_theme:
871+
theme_cache = {'theme': result, 'gradio_version': gr.__version__}
872+
pickle.dump(theme_cache, cached_theme)
873+
874+
return result
875+
return wrapper
876+
877+
878+
gr.themes.ThemeClass.from_hub = from_hub_with_cache_wrapper(gr.themes.ThemeClass.from_hub) # decorates gr.themes.ThemeClass.from_hub with from_hub_with_cache_wrapper
879+
880+
849881
gradio_theme = gr.themes.Base()
850882

851883

@@ -869,7 +901,6 @@ def reload_gradio_theme(theme_name=None):
869901
gradio_theme = gr.themes.Default(**default_theme_args)
870902

871903

872-
873904
class TotalTQDM:
874905
def __init__(self):
875906
self._tqdm = None

0 commit comments

Comments
 (0)