diff options
author | brkirch <brkirch@users.noreply.github.com> | 2022-10-04 01:04:19 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-04 01:04:19 -0400 |
commit | b88e4ea7d6536ddc814a9101a258f87b06230019 (patch) | |
tree | 913bdda300b7082aa61429689d0a730cecda1525 /modules/modelloader.py | |
parent | bdaa36c84470adbdce3e98c01a69af5e95adfb02 (diff) | |
parent | 2865ef4b9ab16d56326cc805541bebcf01d099bc (diff) |
Merge branch 'master' into master
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r-- | modules/modelloader.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index 8c862b42..b0f2f33d 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -5,7 +5,6 @@ import importlib from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url - from modules import shared from modules.upscaler import Upscaler from modules.paths import script_path, models_path @@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None for place in places: if os.path.exists(place): for file in glob.iglob(place + '**/**', recursive=True): - full_path = os.path.join(place, file) + full_path = file if os.path.isdir(full_path): continue if len(ext_filter) != 0: @@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): def load_upscalers(): + sd = shared.script_path + # We can only do this 'magic' method to dynamically load upscalers if they are referenced, + # so we'll try to import any _model.py files before looking in __subclasses__ + modules_dir = os.path.join(sd, "modules") + for file in os.listdir(modules_dir): + if "_model.py" in file: + model_name = file.replace("_model.py", "") + full_model = f"modules.{model_name}_model" + try: + importlib.import_module(full_model) + except: + pass datas = [] + c_o = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): name = cls.__name__ module_name = cls.__module__ module = importlib.import_module(module_name) class_ = getattr(module, name) - cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" opt_string = None try: - opt_string = shared.opts.__getattr__(cmd_name) + if cmd_name in c_o: + opt_string = c_o[cmd_name] except: pass scaler = class_(opt_string) |