From 95821f0132f5437ef30b0dbcac7c51e55818c18f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 18:11:13 +0300 Subject: split webui.py's initialization and utility functions into separate files --- modules/gradio_extensons.py | 4 +- modules/initialize.py | 168 +++++++++++++++++++++++++++++++++++++ modules/initialize_util.py | 195 +++++++++++++++++++++++++++++++++++++++++++ modules/shared_init.py | 3 - modules/ui_extra_networks.py | 3 +- modules/ui_tempdir.py | 5 +- 6 files changed, 371 insertions(+), 7 deletions(-) create mode 100644 modules/initialize.py create mode 100644 modules/initialize_util.py (limited to 'modules') diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py index 5af7fd8e..77c34c8b 100644 --- a/modules/gradio_extensons.py +++ b/modules/gradio_extensons.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import scripts +from modules import scripts, ui_tempdir def add_classes_to_gradio_component(comp): """ @@ -58,3 +58,5 @@ original_BlockContext_init = gr.blocks.BlockContext.__init__ gr.components.IOComponent.__init__ = IOComponent_init gr.blocks.Block.get_config = Block_get_config gr.blocks.BlockContext.__init__ = BlockContext_init + +ui_tempdir.install_ui_tempdir_override() diff --git a/modules/initialize.py b/modules/initialize.py new file mode 100644 index 00000000..f24f7637 --- /dev/null +++ b/modules/initialize.py @@ -0,0 +1,168 @@ +import importlib +import logging +import sys +import warnings +from threading import Thread + +from modules.timer import startup_timer + + +def imports(): + logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... + logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + + import torch # noqa: F401 + startup_timer.record("import torch") + import pytorch_lightning # noqa: F401 + startup_timer.record("import torch") + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + + import gradio # noqa: F401 + startup_timer.record("import gradio") + + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") + + import ldm.modules.encoders.modules # noqa: F401 + startup_timer.record("import ldm") + + import sgm.modules.encoders.modules # noqa: F401 + startup_timer.record("import sgm") + + from modules import shared_init + shared_init.initialize() + startup_timer.record("initialize shared") + + from modules import processing, gradio_extensons, ui # noqa: F401 + startup_timer.record("other imports") + + +def check_versions(): + from modules.shared_cmd_options import cmd_opts + + if not cmd_opts.skip_version_check: + from modules import errors + errors.check_versions() + + +def initialize(): + from modules import initialize_util + initialize_util.fix_torch_version() + initialize_util.fix_asyncio_event_loop_policy() + initialize_util.validate_tls_options() + initialize_util.configure_sigint_handler() + initialize_util.configure_opts_onchange() + + from modules import modelloader + modelloader.cleanup_models() + + from modules import sd_models + sd_models.setup_model() + startup_timer.record("setup SD model") + + from modules.shared_cmd_options import cmd_opts + + from modules import codeformer_model + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor") + codeformer_model.setup_model(cmd_opts.codeformer_models_path) + startup_timer.record("setup codeformer") + + from modules import gfpgan_model + gfpgan_model.setup_model(cmd_opts.gfpgan_models_path) + startup_timer.record("setup gfpgan") + + initialize_rest(reload_script_modules=False) + + +def initialize_rest(*, reload_script_modules=False): + """ + Called both from initialize() and when reloading the webui. + """ + from modules.shared_cmd_options import cmd_opts + + from modules import sd_samplers + sd_samplers.set_samplers() + startup_timer.record("set samplers") + + from modules import extensions + extensions.list_extensions() + startup_timer.record("list extensions") + + from modules import initialize_util + initialize_util.restore_config_state_file() + startup_timer.record("restore config state file") + + from modules import shared, upscaler, scripts + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + scripts.load_scripts() + return + + from modules import sd_models + sd_models.list_models() + startup_timer.record("list SD models") + + from modules import localization + localization.list_localizations(cmd_opts.localizations_dir) + startup_timer.record("list localizations") + + with startup_timer.subcategory("load scripts"): + scripts.load_scripts() + + if reload_script_modules: + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + startup_timer.record("reload script modules") + + from modules import modelloader + modelloader.load_upscalers() + startup_timer.record("load upscalers") + + from modules import sd_vae + sd_vae.refresh_vae_list() + startup_timer.record("refresh VAE") + + from modules import textual_inversion + textual_inversion.textual_inversion.list_textual_inversion_templates() + startup_timer.record("refresh textual inversion templates") + + from modules import script_callbacks, sd_hijack_optimizations, sd_hijack + script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) + sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + + from modules import sd_unet + sd_unet.list_unets() + startup_timer.record("scripts list_unets") + + def load_model(): + """ + Accesses shared.sd_model property to load model. + After it's available, if it has been loaded before this access by some extension, + its optimization may be None because the list of optimizaers has neet been filled + by that time, so we apply optimization again. + """ + + shared.sd_model # noqa: B018 + + if sd_hijack.current_optimizer is None: + sd_hijack.apply_optimizations() + + from modules import devices + devices.first_time_calculation() + + Thread(target=load_model).start() + + from modules import shared_items + shared_items.reload_hypernetworks() + startup_timer.record("reload hypernetworks") + + from modules import ui_extra_networks + ui_extra_networks.initialize() + ui_extra_networks.register_default_pages() + + from modules import extra_networks + extra_networks.initialize() + extra_networks.register_default_extra_networks() + startup_timer.record("initialize extra networks") diff --git a/modules/initialize_util.py b/modules/initialize_util.py new file mode 100644 index 00000000..e59bd3c4 --- /dev/null +++ b/modules/initialize_util.py @@ -0,0 +1,195 @@ +import json +import logging +import os +import signal +import sys +import re + +from modules.timer import startup_timer + +def setup_logging(): + # We can't use cmd_opts for this because it will not have been initialized at this point. + log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") + if log_level: + log_level = getattr(logging, log_level.upper(), None) or logging.INFO + logging.basicConfig( + level=log_level, + format='%(asctime)s %(levelname)s [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + + +def gradio_server_name(): + from modules.shared_cmd_options import cmd_opts + + if cmd_opts.server_name: + return cmd_opts.server_name + else: + return "0.0.0.0" if cmd_opts.listen else None + + +def fix_torch_version(): + import torch + + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors + if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__long_version__ = torch.__version__ + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + + +def fix_asyncio_event_loop_policy(): + """ + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + """ + + import asyncio + + if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore + else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + Usage:: + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + """ + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + + +def restore_config_state_file(): + from modules import shared, config_states + + config_state_file = shared.opts.restore_config_state_file + if config_state_file == "": + return + + shared.opts.restore_config_state_file = "" + shared.opts.save(shared.config_filename) + + if os.path.isfile(config_state_file): + print(f"*** About to restore extension state from file: {config_state_file}") + with open(config_state_file, "r", encoding="utf-8") as f: + config_state = json.load(f) + config_states.restore_extension_config(config_state) + startup_timer.record("restore extension config") + elif config_state_file: + print(f"!!! Config state backup not found: {config_state_file}") + + +def validate_tls_options(): + from modules.shared_cmd_options import cmd_opts + + if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): + return + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + startup_timer.record("TLS") + + +def get_gradio_auth_creds(): + """ + Convert the gradio_auth and gradio_auth_path commandline arguments into + an iterable of (username, password) tuples. + """ + from modules.shared_cmd_options import cmd_opts + + def process_credential_line(s): + s = s.strip() + if not s: + return None + return tuple(s.split(':', 1)) + + if cmd_opts.gradio_auth: + for cred in cmd_opts.gradio_auth.split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + if cmd_opts.gradio_auth_path: + with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + for cred in line.strip().split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + +def configure_sigint_handler(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + if not os.environ.get("COVERAGE_RUN"): + # Don't install the immediate-quit handler when running under coverage, + # as then the coverage report won't be generated. + signal.signal(signal.SIGINT, sigint_handler) + + +def configure_opts_onchange(): + from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack + from modules.call_queue import wrap_queued_call + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) + shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + startup_timer.record("opts onchange") + + +def setup_middleware(app): + from starlette.middleware.gzip import GZipMiddleware + + app.middleware_stack = None # reset current middleware to allow modifying user provided list + app.add_middleware(GZipMiddleware, minimum_size=1000) + configure_cors_middleware(app) + app.build_middleware_stack() # rebuild middleware stack on-the-fly + + +def configure_cors_middleware(app): + from starlette.middleware.cors import CORSMiddleware + from modules.shared_cmd_options import cmd_opts + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_credentials": True, + } + if cmd_opts.cors_allow_origins: + cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') + if cmd_opts.cors_allow_origins_regex: + cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex + app.add_middleware(CORSMiddleware, **cors_options) + diff --git a/modules/shared_init.py b/modules/shared_init.py index b88d1d8e..d3fb687e 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -5,9 +5,6 @@ import torch from modules import shared from modules.shared import cmd_opts -import sys -sys.setrecursionlimit(1000) - def initialize(): """Initializes fields inside the shared module in a controlled manner. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e0b932b9..16d76a45 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -4,7 +4,6 @@ from pathlib import Path from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo -from modules.ui import up_down_symbol import gradio as gr import json import html @@ -348,6 +347,8 @@ def pages_in_preferred_order(pages): def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): + from modules.ui import up_down_symbol + ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index fb75137e..506017e5 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): return file_obj.name -# override save to file function so that it also writes PNG info -gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file +def install_ui_tempdir_override(): + """override save to file function so that it also writes PNG info""" + gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file def on_tmpdir_changed(): -- cgit v1.2.1