From 740070ea9cdb254209f66417418f2a4af8b099d6 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 26 Sep 2022 09:29:50 -0500 Subject: Re-implement universal model loading --- modules/gfpgan_model.py | 60 ++++++++++++++++++++++++------------------------- 1 file changed, 29 insertions(+), 31 deletions(-) (limited to 'modules/gfpgan_model.py') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 44c5dc6c..ffb6960d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -7,33 +7,20 @@ from modules import shared, devices from modules.shared import cmd_opts from modules.paths import script_path import modules.face_restoration +from modules import shared, devices, modelloader +from modules.paths import models_path - -def gfpgan_model_path(): - from modules.shared import cmd_opts - - filemask = 'GFPGAN*.pth' - - if cmd_opts.gfpgan_model is not None: - return cmd_opts.gfpgan_model - - places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')] - - filename = None - for place in places: - filename = next(iter(glob(os.path.join(place, filemask))), None) - if filename is not None: - break - - return filename - +model_dir = "GFPGAN" +cmd_dir = None +model_path = os.path.join(models_path, model_dir) +model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" loaded_gfpgan_model = None def gfpgan(): global loaded_gfpgan_model - + global model_path if loaded_gfpgan_model is not None: loaded_gfpgan_model.gfpgan.to(shared.device) return loaded_gfpgan_model @@ -41,7 +28,15 @@ def gfpgan(): if gfpgan_constructor is None: return None - model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) + models = modelloader.load_models(model_path, model_url, cmd_dir) + if len(models) != 0: + latest_file = max(models, key=os.path.getctime) + model_file = latest_file + else: + print("Unable to load gfpgan model!") + return None + model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2, + bg_upsampler=None) model.gfpgan.to(shared.device) loaded_gfpgan_model = model @@ -50,7 +45,8 @@ def gfpgan(): def gfpgan_fix_faces(np_image): model = gfpgan() - + if model is None: + return np_image np_image_bgr = np_image[:, :, ::-1] cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] @@ -64,19 +60,21 @@ def gfpgan_fix_faces(np_image): have_gfpgan = False gfpgan_constructor = None -def setup_gfpgan(): - try: - gfpgan_model_path() - if os.path.exists(cmd_opts.gfpgan_dir): - sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir)) - from gfpgan import GFPGANer +def setup_model(dirname): + global model_path + if not os.path.exists(model_path): + os.makedirs(model_path) + try: + from modules.gfpgan_model_arch import GFPGANerr + global cmd_dir global have_gfpgan - have_gfpgan = True - global gfpgan_constructor - gfpgan_constructor = GFPGANer + + cmd_dir = dirname + have_gfpgan = True + gfpgan_constructor = GFPGANerr class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): -- cgit v1.2.1