From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: Refactor esrgan_upscale to more generic upscale_with_model --- modules/upscaler_utils.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 modules/upscaler_utils.py (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 00000000..8bdda51c --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import devices, images + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(devices.device_esrgan) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) -- cgit v1.2.1 From 8100e901ab0c5b04d289eebb722c8a653b8beef1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 22:41:53 +0300 Subject: fix error with RealESRGAN model failing to upscale fp32 image --- modules/upscaler_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8bdda51c..39f78a0b 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -16,9 +16,13 @@ def upscale_without_tiling(model, img: Image.Image): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) + + model_weight = next(iter(model.parameters())) + img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + with torch.no_grad(): output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = 255. * np.moveaxis(output, 0, 2) output = output.astype(np.uint8) -- cgit v1.2.1 From 3be90740316f8fbb950b31d440458a5e8ed4beb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 00:43:41 +0300 Subject: fix for the previous fix. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b..dde5d7ad 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -17,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.parameters())) + model_weight = next(iter(model.model.parameters())) img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) with torch.no_grad(): -- cgit v1.2.1 From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: Be more clear about Spandrel model nomenclature --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad..174c9bc3 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) -- cgit v1.2.1 From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: Deduplicate tiled inference code from SwinIR/ScuNET --- modules/upscaler_utils.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 174c9bc3..8e413854 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images +from modules import images, shared logger = logging.getLogger(__name__) @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. + b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output -- cgit v1.2.1 From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: Add utility to inspect a model's parameters (to get dtype/device) --- modules/upscaler_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e413854..c60e3beb 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ import tqdm from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) -- cgit v1.2.1 From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: change import statements for #14478 --- modules/upscaler_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb..f5cb92d5 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): -- cgit v1.2.1