diff options
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/LDSR/scripts/ldsr_model.py | 7 | ||||
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 5 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/scripts/swinir_model.py | 8 |
3 files changed, 10 insertions, 10 deletions
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index dbd6d331..bf9b6de2 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,7 +1,6 @@ import os -from basicsr.utils.download_util import load_file_from_url - +from modules.modelloader import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks, errors @@ -43,9 +42,9 @@ class UpscalerLDSR(Upscaler): if local_safetensors_path is not None and os.path.exists(local_safetensors_path): model = local_safetensors_path else: - model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) + model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt") - yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True) + yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") try: return LDSR(model, yaml) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 85b4505f..2785b551 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -6,12 +6,11 @@ import numpy as np import torch from tqdm import tqdm -from basicsr.utils.download_util import load_file_from_url - import modules.upscaler from modules import devices, modelloader, script_callbacks, errors from scunet_model_arch import SCUNet as net +from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,7 +119,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True) + filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 1c7bf325..a5b0e2eb 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -3,7 +3,6 @@ import os import numpy as np import torch from PIL import Image -from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared @@ -50,8 +49,11 @@ class UpscalerSwinIR(Upscaler): def load_model(self, path, scale=4): if "http" in path: - dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") - filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True) + filename = modelloader.load_file_from_url( + url=path, + model_dir=self.model_download_path, + file_name=f"{self.model_name.replace(' ', '_')}.pth", + ) else: filename = path if filename is None or not os.path.exists(filename): |