From 56e557c6ff8a6480887c9c585bf908045ee8e791 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 24 Dec 2022 22:39:00 +0300 Subject: added cheap NN approximation for VAE --- modules/sd_samplers.py | 29 +++++++++++++----------- modules/sd_vae_approx.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 6 ++--- 3 files changed, 77 insertions(+), 16 deletions(-) create mode 100644 modules/sd_vae_approx.py (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 27ef4ff8..177b5338 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -9,7 +9,7 @@ import k_diffusion.sampling import torchsde._brownian.brownian_interval import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing, images +from modules import prompt_parser, devices, processing, images, sd_vae_approx from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -106,28 +106,31 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -def single_sample_to_image(sample, approximation=False): - if approximation: - # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - coefs = torch.tensor( - [[ 0.298, 0.207, 0.208], - [ 0.187, 0.286, 0.173], - [-0.158, 0.189, 0.264], - [-0.184, -0.271, -0.473]]).to(sample.device) - x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) +approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2} + + +def single_sample_to_image(sample, approximation=None): + if approximation is None: + approximation = approximation_indexes.get(opts.show_progress_type, 0) + + if approximation == 2: + x_sample = sd_vae_approx.cheap_approximation(sample) + elif approximation == 1: + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() else: x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) -def sample_to_image(samples, index=0, approximation=False): +def sample_to_image(samples, index=0, approximation=None): return single_sample_to_image(samples[index], approximation) -def samples_to_image_grid(samples, approximation=False): +def samples_to_image_grid(samples, approximation=None): return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) @@ -136,7 +139,7 @@ def store_latent(decoded): if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: - shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate) + shared.state.current_image = sample_to_image(decoded) class InterruptedException(BaseException): diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py new file mode 100644 index 00000000..0a58542d --- /dev/null +++ b/modules/sd_vae_approx.py @@ -0,0 +1,58 @@ +import os + +import torch +from torch import nn +from modules import devices, paths + +sd_vae_approx_model = None + + +class VAEApprox(nn.Module): + def __init__(self): + super(VAEApprox, self).__init__() + self.conv1 = nn.Conv2d(4, 8, (7, 7)) + self.conv2 = nn.Conv2d(8, 16, (5, 5)) + self.conv3 = nn.Conv2d(16, 32, (3, 3)) + self.conv4 = nn.Conv2d(32, 64, (3, 3)) + self.conv5 = nn.Conv2d(64, 32, (3, 3)) + self.conv6 = nn.Conv2d(32, 16, (3, 3)) + self.conv7 = nn.Conv2d(16, 8, (3, 3)) + self.conv8 = nn.Conv2d(8, 3, (3, 3)) + + def forward(self, x): + extra = 11 + x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2)) + x = nn.functional.pad(x, (extra, extra, extra, extra)) + + for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]: + x = layer(x) + x = nn.functional.leaky_relu(x, 0.1) + + return x + + +def model(): + global sd_vae_approx_model + + if sd_vae_approx_model is None: + sd_vae_approx_model = VAEApprox() + sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"))) + sd_vae_approx_model.eval() + sd_vae_approx_model.to(devices.device, devices.dtype) + + return sd_vae_approx_model + + +def cheap_approximation(sample): + # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 + + coefs = torch.tensor([ + [0.298, 0.207, 0.208], + [0.187, 0.286, 0.173], + [-0.158, 0.189, 0.264], + [-0.184, -0.271, -0.473], + ]).to(sample.device) + + x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) + + return x_sample diff --git a/modules/shared.py b/modules/shared.py index eb3e5aec..3cc3c724 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -212,9 +212,9 @@ class State: import modules.sd_samplers if opts.show_progress_grid: - self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent, approximation=opts.show_progress_approximate) + self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) else: - self.current_image = modules.sd_samplers.sample_to_image(self.current_latent, approximation=opts.show_progress_approximate) + self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) self.current_image_sampling_step = self.sampling_step @@ -392,7 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_approximate": OptionInfo(False, "Calculate small previews using fast linear approximation instead of VAE"), + "show_progress_type": OptionInfo("Full", "Image creation progress mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), -- cgit v1.2.1