From 740070ea9cdb254209f66417418f2a4af8b099d6 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 26 Sep 2022 09:29:50 -0500 Subject: Re-implement universal model loading --- modules/esrgan_model.py | 56 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 15 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 7f3baf31..dd0ee629 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,15 +5,35 @@ import traceback import numpy as np import torch from PIL import Image +from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch +import modules.images from modules import shared -from modules.shared import opts +from modules import shared, modelloader from modules.devices import has_mps -import modules.images - +from modules.paths import models_path +from modules.shared import opts -def load_model(filename): +model_dir = "ESRGAN" +model_path = os.path.join(models_path, model_dir) +model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" +model_name = "ESRGAN_x4.pth" + + +def load_model(path: str, name: str): + global model_path + global model_url + global model_dir + global model_name + if "http" in path: + filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True) + else: + filename = path + if not os.path.exists(filename) or filename is None: + print("Unable to load %s from %s" % (model_dir, filename)) + return None + print("Loading %s from %s" % (model_dir, filename)) # this code is adapted from https://github.com/xinntao/ESRGAN pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) @@ -118,24 +138,30 @@ def esrgan_upscale(model, img): class UpscalerESRGAN(modules.images.Upscaler): def __init__(self, filename, title): self.name = title - self.model = load_model(filename) + self.filename = filename def do_upscale(self, img): - model = self.model.to(shared.device) + model = load_model(self.filename, self.name) + if model is None: + return img + model.to(shared.device) img = esrgan_upscale(model, img) return img -def load_models(dirname): - for file in os.listdir(dirname): - path = os.path.join(dirname, file) - model_name, extension = os.path.splitext(file) - - if extension != '.pt' and extension != '.pth': - continue +def setup_model(dirname): + global model_path + global model_name + if not os.path.exists(model_path): + os.makedirs(model_path) + model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"]) + if len(model_paths) == 0: + modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name)) + for file in model_paths: + name = modelloader.friendly_name(file) try: - modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name)) + modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name)) except Exception: - print(f"Error loading ESRGAN model: {path}", file=sys.stderr) + print(f"Error loading ESRGAN model: {file}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) -- cgit v1.2.1