diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/gfpgan_model.py | 10 | ||||
-rw-r--r-- | modules/modelloader.py | 25 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 4 | ||||
-rw-r--r-- | modules/upscaler_utils.py | 74 |
4 files changed, 94 insertions, 19 deletions
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 48f8ad5e..445b0409 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging
import os
+import torch
+
from modules import (
devices,
errors,
@@ -25,7 +27,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): def get_device(self):
return devices.device_gfpgan
- def load_net(self) -> None:
+ def load_net(self) -> torch.Module:
for model_path in modelloader.load_models(
model_path=self.model_path,
model_url=model_url,
@@ -34,13 +36,13 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): ext_filter=['.pth'],
):
if 'GFPGAN' in os.path.basename(model_path):
- net = modelloader.load_spandrel_model(
+ model = modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
expected_architecture='GFPGAN',
).model
- net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
- return net
+ model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
+ return model
raise ValueError("No GFPGAN model found")
def restore(self, np_image):
diff --git a/modules/modelloader.py b/modules/modelloader.py index 0b89d682..a7194137 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,8 +1,9 @@ from __future__ import annotations +import importlib import logging import os -import importlib +from typing import TYPE_CHECKING from urllib.parse import urlparse import torch @@ -10,6 +11,8 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +if TYPE_CHECKING: + import spandrel logger = logging.getLogger(__name__) @@ -140,19 +143,19 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, -): +) -> spandrel.ModelDescriptor: import spandrel - model = spandrel.ModelLoader(device=device).load_from_file(path) - if expected_architecture and model.architecture != expected_architecture: + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) if half: - model = model.model.half() + model_descriptor.model.half() if dtype: - model = model.model.to(dtype=dtype) - model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) - return model + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 65f2e880..4d35b695 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -36,14 +36,14 @@ class UpscalerRealESRGAN(Upscaler): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
- mod = modelloader.load_spandrel_model(
+ model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(
- mod,
+ model_descriptor,
img,
tile_size=opts.ESRGAN_tile,
tile_overlap=opts.ESRGAN_tile_overlap,
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b..8e413854 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images, shared logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.parameters())) + model_weight = next(iter(model.model.parameters())) img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) with torch.no_grad(): @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. + b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output |