From f299645aeeb65fcddde2d136fd550b6b01ffebb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Sep 2022 18:54:12 +0300 Subject: ESRGAN support --- modules/images.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) (limited to 'modules/images.py') diff --git a/modules/images.py b/modules/images.py index 4b9667d2..4226db00 100644 --- a/modules/images.py +++ b/modules/images.py @@ -6,6 +6,7 @@ import re import numpy as np from PIL import Image, ImageFont, ImageDraw, PngImagePlugin +import modules.shared from modules.shared import opts LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -45,20 +46,20 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): cols = math.ceil((w - overlap) / non_overlap_width) rows = math.ceil((h - overlap) / non_overlap_height) - dx = (w - tile_w) // (cols-1) if cols > 1 else 0 - dy = (h - tile_h) // (rows-1) if rows > 1 else 0 + dx = (w - tile_w) / (cols-1) if cols > 1 else 0 + dy = (h - tile_h) / (rows-1) if rows > 1 else 0 grid = Grid([], tile_w, tile_h, w, h, overlap) for row in range(rows): row_images = [] - y = row * dy + y = int(row * dy) if y + tile_h >= h: y = h - tile_h for col in range(cols): - x = col * dx + x = int(col * dx) if x+tile_w >= w: x = w - tile_w @@ -291,3 +292,32 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: file.write(info + "\n") + +class Upscaler: + name = "Lanczos" + + def do_upscale(self, img): + return img + + def upscale(self, img, w, h): + for i in range(3): + if img.width >= w and img.height >= h: + break + + img = self.do_upscale(img) + + if img.width != w or img.height != h: + img = img.resize((w, h), resample=LANCZOS) + + return img + + +class UpscalerNone(Upscaler): + name = "None" + + def upscale(self, img, w, h): + return img + + +modules.shared.sd_upscalers.append(UpscalerNone()) +modules.shared.sd_upscalers.append(Upscaler()) -- cgit v1.2.1