From 8285a149d8c488ae6c7a566eb85fb5e825145464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 19:20:11 +0300 Subject: add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them) --- modules/sd_samplers_common.py | 140 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 39586b40..adda963b 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,9 +1,11 @@ -from collections import namedtuple +import inspect +from collections import namedtuple, deque import numpy as np import torch from PIL import Image from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state +import k_diffusion.sampling SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -127,3 +129,139 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() + + +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. + self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + return devices.randn_like(x) + + +class Sampler: + def __init__(self, funcname): + self.funcname = funcname + self.func = funcname + self.extra_params = [] + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + + self.eta_option_field = 'eta_ancestral' + self.eta_infotext_field = 'Eta' + + self.conditioning_key = shared.sd_model.model.conditioning_key + + self.model_wrap = None + self.model_wrap_cfg = None + + def callback_state(self, d): + step = d['i'] + + if self.stop_at is not None and step > self.stop_at: + raise InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' + ) + return self.last_latent + except InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p) -> dict: + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params[self.eta_infotext_field] = self.eta + + extra_params_kwargs['eta'] = self.eta + + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + + return extra_params_kwargs + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + + -- cgit v1.2.1