From 40a18d38a8fcb88d1c2947a2653b52cd2085536f Mon Sep 17 00:00:00 2001
From: lambertae
Date: Tue, 18 Jul 2023 00:32:01 -0400
Subject: add restart sampler

---
 modules/sd_samplers_kdiffusion.py | 70 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 68 insertions(+), 2 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 71581b76..c63b677c 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,3 +1,5 @@
+# export PIP_CACHE_DIR=/scratch/dengm/cache
+# export XDG_CACHE_HOME=/scratch/dengm/cache
 from collections import deque
 import torch
 import inspect
@@ -30,12 +32,76 @@ samplers_k_diffusion = [
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
     ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
     ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+    ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}),
 ]
+
+
+@torch.no_grad()
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = {0.1: [10, 2, 2]}):
+    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)"""
+    '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}'''
+
+    from tqdm.auto import trange, tqdm
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    step_id = 0
+
+    from k_diffusion.sampling import to_d, append_zero
+
+    def heun_step(x, old_sigma, new_sigma):
+        nonlocal step_id
+        denoised = model(x, old_sigma * s_in, **extra_args)
+        d = to_d(x, old_sigma, denoised)
+        if callback is not None:
+            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
+        dt = new_sigma - old_sigma
+        if new_sigma == 0:
+            # Euler method
+            x = x + d * dt
+        else:
+            # Heun's method
+            x_2 = x + d * dt
+            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
+            d_2 = to_d(x_2, new_sigma, denoised_2)
+            d_prime = (d + d_2) / 2
+            x = x + d_prime * dt
+        step_id += 1
+        return x
+    # print(sigmas)
+    temp_list = dict()
+    for key, value in restart_list.items():
+        temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
+    restart_list = temp_list
+
+
+    def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+        ramp = torch.linspace(0, 1, n).to(device)
+        min_inv_rho = (sigma_min ** (1 / rho))
+        max_inv_rho = (sigma_max ** (1 / rho))
+        if isinstance(min_inv_rho, torch.Tensor):
+            min_inv_rho = min_inv_rho.to(device)
+        if isinstance(max_inv_rho, torch.Tensor):
+            max_inv_rho = max_inv_rho.to(device)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return append_zero(sigmas).to(device)
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        x = heun_step(x, sigmas[i], sigmas[i+1])
+        if i + 1 in restart_list:
+            restart_steps, restart_times, restart_max = restart_list[i + 1]
+            min_idx = i + 1
+            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
+            sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
+            for times in range(restart_times):
+                x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
+                for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
+                    x = heun_step(x, old_sigma, new_sigma)
+    return x
 
 samplers_data_k_diffusion = [
     sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
     for label, funcname, aliases, options in samplers_k_diffusion
-    if hasattr(k_diffusion.sampling, funcname)
+    if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler')
 ]
 
 sampler_extra_params = {
@@ -245,7 +311,7 @@ class KDiffusionSampler:
 
         self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
         self.funcname = funcname
-        self.func = getattr(k_diffusion.sampling, self.funcname)
+        self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler
        self.extra_params = sampler_extra_params.get(funcname, [])
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
--
cgit v1.2.1

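The patch above combines the two ingredients of restart sampling (Xu et al., 2023): deterministic Heun steps down a descending sigma schedule, and occasional noise re-injections that jump the state back up from a low sigma to a higher one before re-denoising that span. The jump adds Gaussian noise scaled by sqrt(sigma_max^2 - sigma_min^2), which is exactly the extra variance needed to turn a sample at noise level sigma_min into one at sigma_max. A minimal standalone sketch of that control flow, with a toy denoiser standing in for the CFG-wrapped model the real sampler receives (`toy_model` and the hard-coded schedules are illustrative only):

    import torch

    def toy_model(x, sigma):
        return x / (1.0 + sigma ** 2)  # stand-in denoiser: shrinks x as sigma drops

    def to_d(x, sigma, denoised):
        return (x - denoised) / sigma  # same convention as k_diffusion.sampling.to_d

    def heun_step(x, old_sigma, new_sigma):
        d = to_d(x, old_sigma, toy_model(x, old_sigma))
        dt = new_sigma - old_sigma
        if new_sigma == 0:
            return x + d * dt  # plain Euler for the final step to sigma=0
        x_2 = x + d * dt
        d_2 = to_d(x_2, new_sigma, toy_model(x_2, new_sigma))
        return x + (d + d_2) / 2 * dt  # Heun: average the slopes at both endpoints

    sigmas = [10.0, 2.0, 0.5, 0.1, 0.0]
    restart = {0.1: (2.0, [2.0, 0.5, 0.1], 2)}  # min_sigma: (max_sigma, segment, times)
    x = torch.randn(4) * sigmas[0]
    for old, new in zip(sigmas[:-1], sigmas[1:]):
        x = heun_step(x, old, new)
        if new in restart:
            sigma_max, segment, times = restart[new]
            for _ in range(times):
                # re-noise from sigma=new back up to sigma_max with matched variance
                x = x + torch.randn_like(x) * (sigma_max ** 2 - new ** 2) ** 0.5
                for o, n in zip(segment[:-1], segment[1:]):
                    x = heun_step(x, o, n)

Each restart re-denoises the same sigma span more than once, trading extra model evaluations for error contraction, which is the paper's core claim.
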
From 15a94d6cf7fa075c09362e73c1239692d021c559 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Tue, 18 Jul 2023 00:39:26 -0400
Subject: remove useless header

---
 modules/sd_samplers_kdiffusion.py | 2 --
 1 file changed, 2 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index c63b677c..7888d864 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,5 +1,3 @@
-# export PIP_CACHE_DIR=/scratch/dengm/cache
-# export XDG_CACHE_HOME=/scratch/dengm/cache
 from collections import deque
 import torch
 import inspect
--
cgit v1.2.1

From 37e048a7e2356f4caebfd976351112f03856f082 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Tue, 18 Jul 2023 00:55:02 -0400
Subject: fix floating error

---
 modules/sd_samplers_kdiffusion.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 7888d864..1bb25adf 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -89,11 +89,12 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             restart_steps, restart_times, restart_max = restart_list[i + 1]
             min_idx = i + 1
             max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
-            sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
-            for times in range(restart_times):
-                x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
-                for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
-                    x = heun_step(x, old_sigma, new_sigma)
+            if max_idx < min_idx:
+                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
+                for times in range(restart_times):
+                    x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
+                    for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
+                        x = heun_step(x, old_sigma, new_sigma)
     return x
 
 samplers_data_k_diffusion = [
--
cgit v1.2.1

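The `if max_idx < min_idx:` guard reflects how the schedule is indexed: sigmas are stored in descending order, so a larger sigma lives at a smaller index, and `torch.argmin(abs(sigmas - restart_max))` must resolve strictly before the current position. When it does not (the requested max_sigma lands at or below the current sigma), `sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2` is non-positive and the `** 0.5` yields NaN, which is plausibly the floating error the subject line refers to. A small demonstration with made-up schedule values:

    import torch

    sigmas = torch.tensor([14.6, 7.0, 2.1, 0.9, 0.1, 0.0])  # descending, ends at 0
    min_idx = 4                                     # current position: sigma 0.1
    max_idx = int(torch.argmin(abs(sigmas - 2.0)))  # nearest to max_sigma=2 -> index 2
    print(max_idx < min_idx)                                     # True: valid jump
    print((sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5)  # tensor(2.0976)
    print((sigmas[5] ** 2 - sigmas[min_idx] ** 2) ** 0.5)        # tensor(nan)
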
From 7bb0fbed136c6a345b211e09102659fd89362576 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Tue, 18 Jul 2023 01:02:04 -0400
Subject: code styling

---
 modules/sd_samplers_kdiffusion.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 1bb25adf..db7013f2 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -35,17 +35,15 @@ samplers_k_diffusion = [
 
 
 @torch.no_grad()
-def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = {0.1: [10, 2, 2]}):
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.):
     """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)"""
     '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}'''
-
-    from tqdm.auto import trange, tqdm
+    restart_list = {0.1: [10, 2, 2]}
+    from tqdm.auto import trange
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     step_id = 0
-
     from k_diffusion.sampling import to_d, append_zero
-
     def heun_step(x, old_sigma, new_sigma):
         nonlocal step_id
         denoised = model(x, old_sigma * s_in, **extra_args)
@@ -70,8 +68,6 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
     for key, value in restart_list.items():
         temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
     restart_list = temp_list
-
-
     def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
         ramp = torch.linspace(0, 1, n).to(device)
         min_inv_rho = (sigma_min ** (1 / rho))
@@ -82,7 +78,6 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             max_inv_rho = max_inv_rho.to(device)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return append_zero(sigmas).to(device)
-
     for i in trange(len(sigmas) - 1, disable=disable):
         x = heun_step(x, sigmas[i], sigmas[i+1])
         if i + 1 in restart_list:
@@ -91,7 +86,8 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
             if max_idx < min_idx:
                 sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
-                for times in range(restart_times):
+                while restart_times > 0:
+                    restart_times -= 1
                     x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
                     for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
                         x = heun_step(x, old_sigma, new_sigma)
--
cgit v1.2.1

From ddbf4a73f5c0cfe63ca0988b8e642d3b977a3fa9 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Thu, 20 Jul 2023 02:24:18 -0400
Subject: restart-sampler with correct steps

---
 modules/sd_samplers_kdiffusion.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index db7013f2..ed5e6c79 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -38,20 +38,19 @@ samplers_k_diffusion = [
 def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.):
     """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)"""
     '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}'''
-    restart_list = {0.1: [10, 2, 2]}
     from tqdm.auto import trange
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     step_id = 0
     from k_diffusion.sampling import to_d, append_zero
-    def heun_step(x, old_sigma, new_sigma):
+    def heun_step(x, old_sigma, new_sigma, second_order = True):
         nonlocal step_id
         denoised = model(x, old_sigma * s_in, **extra_args)
         d = to_d(x, old_sigma, denoised)
         if callback is not None:
             callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
         dt = new_sigma - old_sigma
-        if new_sigma == 0:
+        if new_sigma == 0 or not second_order:
             # Euler method
             x = x + d * dt
         else:
@@ -63,11 +62,6 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             x = x + d_prime * dt
         step_id += 1
         return x
-    # print(sigmas)
-    temp_list = dict()
-    for key, value in restart_list.items():
-        temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
-    restart_list = temp_list
     def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
         ramp = torch.linspace(0, 1, n).to(device)
         min_inv_rho = (sigma_min ** (1 / rho))
@@ -78,6 +72,18 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             max_inv_rho = max_inv_rho.to(device)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return append_zero(sigmas).to(device)
+    steps = sigmas.shape[0] - 1
+    if steps >= 20:
+        restart_steps = 9
+        restart_times = 2 if steps >= 36 else 1
+        sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2], sigmas[0], device=sigmas.device)
+        restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
+    else:
+        restart_list = dict()
+    temp_list = dict()
+    for key, value in restart_list.items():
+        temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
+    restart_list = temp_list
     for i in trange(len(sigmas) - 1, disable=disable):
         x = heun_step(x, sigmas[i], sigmas[i+1])
         if i + 1 in restart_list:
--
cgit v1.2.1

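The step accounting in this commit keeps the user-visible step count honest. The Karras schedule helper (the local clone here, like k-diffusion's own) returns n + 1 values for `get_sigmas_karras(n, ...)` — n sigmas plus an appended zero — i.e. n transitions. Rebuilding the main schedule with `steps - restart_steps * restart_times` and giving each restart pass `restart_steps + 1` sigmas (with the trailing zero sliced off) therefore makes the total number of Heun transitions come out to exactly `steps`. A worked check of that arithmetic:

    # Step budget check, using the auto-configured values from this patch.
    steps = 36                               # what the user asked for
    restart_steps, restart_times = 9, 2      # the steps >= 36 branch
    main_transitions = steps - restart_steps * restart_times   # 18 on the main schedule
    restart_transitions = restart_steps * restart_times        # 9 per pass, 2 passes
    assert main_transitions + restart_transitions == steps

Each non-final transition still costs two model evaluations (Heun is second order), so "steps" counts transitions, not model calls.
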
From 2f57a559ac3381c1ef2516655c3a3d1088191c54 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Thu, 20 Jul 2023 20:34:41 -0400
Subject: allow choise of restart_list & use karras from kdiffusion

---
 modules/sd_samplers_kdiffusion.py | 32 ++++++++++++--------------------
 1 file changed, 12 insertions(+), 20 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index ed5e6c79..c72d01c8 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -35,14 +35,15 @@ samplers_k_diffusion = [
 
 
 @torch.no_grad()
-def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.):
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = None):
     """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)"""
     '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}'''
+    '''If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list'''
     from tqdm.auto import trange
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     step_id = 0
-    from k_diffusion.sampling import to_d, append_zero
+    from k_diffusion.sampling import to_d, append_zero, get_sigmas_karras
     def heun_step(x, old_sigma, new_sigma, second_order = True):
         nonlocal step_id
         denoised = model(x, old_sigma * s_in, **extra_args)
@@ -63,24 +64,15 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             x = x + d_prime * dt
         step_id += 1
         return x
-    def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
-        ramp = torch.linspace(0, 1, n).to(device)
-        min_inv_rho = (sigma_min ** (1 / rho))
-        max_inv_rho = (sigma_max ** (1 / rho))
-        if isinstance(min_inv_rho, torch.Tensor):
-            min_inv_rho = min_inv_rho.to(device)
-        if isinstance(max_inv_rho, torch.Tensor):
-            max_inv_rho = max_inv_rho.to(device)
-        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-        return append_zero(sigmas).to(device)
     steps = sigmas.shape[0] - 1
-    if steps >= 20:
-        restart_steps = 9
-        restart_times = 2 if steps >= 36 else 1
-        sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2], sigmas[0], device=sigmas.device)
-        restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
-    else:
-        restart_list = dict()
+    if restart_list is None:
+        if steps >= 20:
+            restart_steps = 9
+            restart_times = 2 if steps >= 36 else 1
+            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
+            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
+        else:
+            restart_list = dict()
     temp_list = dict()
     for key, value in restart_list.items():
         temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
@@ -91,7 +83,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
             min_idx = i + 1
             max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
             if max_idx < min_idx:
-                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
+                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] # remove the zero at the end
                 while restart_times > 0:
                     restart_times -= 1
                     x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
--
cgit v1.2.1

From 128d59c9ccfbc9c7fccd6f1b2fe58bbbb18459f9 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Thu, 20 Jul 2023 20:36:40 -0400
Subject: fix ruff

---
 modules/sd_samplers_kdiffusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index c72d01c8..21b347ed 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -43,7 +43,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
     extra_args = {} if extra_args is None else extra_args
     s_in = x.new_ones([x.shape[0]])
     step_id = 0
-    from k_diffusion.sampling import to_d, append_zero, get_sigmas_karras
+    from k_diffusion.sampling import to_d, get_sigmas_karras
     def heun_step(x, old_sigma, new_sigma, second_order = True):
         nonlocal step_id
         denoised = model(x, old_sigma * s_in, **extra_args)
--
cgit v1.2.1

From f87389029839a27464a18846815339e81787b882 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Thu, 20 Jul 2023 21:27:43 -0400
Subject: new restart scheme

---
 modules/sd_samplers_kdiffusion.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 21b347ed..ed60670c 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -67,7 +67,10 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
     if restart_list is None:
         if steps >= 20:
             restart_steps = 9
-            restart_times = 2 if steps >= 36 else 1
+            restart_times = 1
+            if steps >= 36:
+                restart_steps = steps // 4
+                restart_times = 2
             sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
             restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
         else:
--
cgit v1.2.1

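Two things change across these three commits: the hand-rolled Karras schedule helper is dropped in favor of `k_diffusion.sampling.get_sigmas_karras` (with `.item()` on the sigma endpoints, since that function builds its ramp from plain Python floats rather than tensors), and the automatic configuration now scales the restart segment with the step count instead of always using 9 steps. How the budget is allocated for a few requested step counts under the `restart_list is None` default (a small re-statement of the patch's branching, not webui API):

    def plan(steps):
        # Mirrors the auto-configuration: below 20 steps the restart machinery
        # is skipped entirely and the sampler degenerates to plain Heun.
        if steps < 20:
            return None
        restart_steps, restart_times = 9, 1
        if steps >= 36:
            restart_steps, restart_times = steps // 4, 2
        return steps - restart_steps * restart_times, restart_steps, restart_times

    for s in (15, 20, 36, 60):
        print(s, plan(s))
    # 15 None | 20 (11, 9, 1) | 36 (18, 9, 2) | 60 (30, 15, 2)
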
From 8de6d3ff77e841a5fd9d5f1b16bdd22737c8d657 Mon Sep 17 00:00:00 2001
From: lambertae
Date: Tue, 25 Jul 2023 22:35:43 -0400
Subject: fix progress bar & torchHijack

---
 modules/sd_samplers_kdiffusion.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index ed60670c..7a2427b5 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -79,19 +79,26 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
     for key, value in restart_list.items():
         temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
     restart_list = temp_list
-    for i in trange(len(sigmas) - 1, disable=disable):
-        x = heun_step(x, sigmas[i], sigmas[i+1])
+    step_list = []
+    for i in range(len(sigmas) - 1):
+        step_list.append((sigmas[i], sigmas[i + 1]))
         if i + 1 in restart_list:
             restart_steps, restart_times, restart_max = restart_list[i + 1]
             min_idx = i + 1
             max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
             if max_idx < min_idx:
-                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] # remove the zero at the end
+                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
                 while restart_times > 0:
                     restart_times -= 1
-                    x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
-                    for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:]):
-                        x = heun_step(x, old_sigma, new_sigma)
+                    step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+    last_sigma = None
+    for i in trange(len(step_list), disable=disable):
+        if last_sigma is None:
+            last_sigma = step_list[i][0]
+        elif last_sigma < step_list[i][0]:
+            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (step_list[i][0] ** 2 - last_sigma ** 2) ** 0.5
+        x = heun_step(x, step_list[i][0], step_list[i][1])
+        last_sigma = step_list[i][1]
     return x
 
 samplers_data_k_diffusion = [
--
cgit v1.2.1

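Flattening the schedule into `step_list` before sampling fixes two couplings at once: a single `trange` over the flattened list gives the progress bar the true step count including restart passes, and re-noising can be expressed uniformly as "sigma went back up between consecutive transitions". Using `k_diffusion.sampling.torch.randn_like` instead of `torch.randn_like` matters because the webui replaces that module attribute with its `TorchHijack` object (visible in the file listing below), whose `randn_like` serves user-seeded sampler noise. A sketch of the flattened control flow with plain floats:

    step_list = [(10.0, 2.0), (2.0, 0.1), (2.0, 0.9), (0.9, 0.1)]  # restart after 2nd
    last_sigma = None
    for old_sigma, new_sigma in step_list:
        if last_sigma is not None and last_sigma < old_sigma:
            print(f"re-noise: add sqrt({old_sigma}**2 - {last_sigma}**2) * eps")
        print(f"heun step: {old_sigma} -> {new_sigma}")
        last_sigma = new_sigma
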
From e1323fc1b70438165bbfd90d80dec00b8b2ab884 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 29 Jul 2023 08:06:03 +0300
Subject: Split history: mv modules/sd_samplers_kdiffusion.py temp

---
 modules/sd_samplers_kdiffusion.py | 545 --------------------------------------
 1 file changed, 545 deletions(-)
 delete mode 100644 modules/sd_samplers_kdiffusion.py

(limited to 'modules/sd_samplers_kdiffusion.py')

diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
deleted file mode 100644
index a54673eb..00000000
--- a/modules/sd_samplers_kdiffusion.py
+++ /dev/null
@@ -1,545 +0,0 @@
-from collections import deque
-import torch
-import inspect
-import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common
-
-from modules.shared import opts, state
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
-from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
-
-samplers_k_diffusion = [
-    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
-    ('Euler', 'sample_euler', ['k_euler'], {}),
-    ('LMS', 'sample_lms', ['k_lms'], {}),
-    ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
-    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
-    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
-    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
-    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
-    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
-    ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
-    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
-    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
-    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
-    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
-    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
-    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
-    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
-    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
-    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
-    ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}),
-]
-
-
-@torch.no_grad()
-def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = None):
-    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)"""
-    '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}'''
-    '''If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list'''
-    from tqdm.auto import trange
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    step_id = 0
-    from k_diffusion.sampling import to_d, get_sigmas_karras
-    def heun_step(x, old_sigma, new_sigma, second_order = True):
-        nonlocal step_id
-        denoised = model(x, old_sigma * s_in, **extra_args)
-        d = to_d(x, old_sigma, denoised)
-        if callback is not None:
-            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
-        dt = new_sigma - old_sigma
-        if new_sigma == 0 or not second_order:
-            # Euler method
-            x = x + d * dt
-        else:
-            # Heun's method
-            x_2 = x + d * dt
-            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
-            d_2 = to_d(x_2, new_sigma, denoised_2)
-            d_prime = (d + d_2) / 2
-            x = x + d_prime * dt
-        step_id += 1
-        return x
-    steps = sigmas.shape[0] - 1
-    if restart_list is None:
-        if steps >= 20:
-            restart_steps = 9
-            restart_times = 1
-            if steps >= 36:
-                restart_steps = steps // 4
-                restart_times = 2
-            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
-            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
-        else:
-            restart_list = dict()
-    temp_list = dict()
-    for key, value in restart_list.items():
-        temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value
-    restart_list = temp_list
-    step_list = []
-    for i in range(len(sigmas) - 1):
-        step_list.append((sigmas[i], sigmas[i + 1]))
-        if i + 1 in restart_list:
-            restart_steps, restart_times, restart_max = restart_list[i + 1]
-            min_idx = i + 1
-            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
-            if max_idx < min_idx:
-                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
-                while restart_times > 0:
-                    restart_times -= 1
-                    step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
-    last_sigma = None
-    for i in trange(len(step_list), disable=disable):
-        if last_sigma is None:
-            last_sigma = step_list[i][0]
-        elif last_sigma < step_list[i][0]:
-            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (step_list[i][0] ** 2 - last_sigma ** 2) ** 0.5
-        x = heun_step(x, step_list[i][0], step_list[i][1])
-        last_sigma = step_list[i][1]
-    return x
-
-samplers_data_k_diffusion = [
-    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
-    for label, funcname, aliases, options in samplers_k_diffusion
-    if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler')
-]
-
-sampler_extra_params = {
-    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-}
-
-k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
-k_diffusion_scheduler = {
-    'Automatic': None,
-    'karras': k_diffusion.sampling.get_sigmas_karras,
-    'exponential': k_diffusion.sampling.get_sigmas_exponential,
-    'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
-}
-
-
-def catenate_conds(conds):
-    if not isinstance(conds[0], dict):
-        return torch.cat(conds)
-
-    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
-
-
-def subscript_cond(cond, a, b):
-    if not isinstance(cond, dict):
-        return cond[a:b]
-
-    return {key: vec[a:b] for key, vec in cond.items()}
-
-
-def pad_cond(tensor, repeats, empty):
-    if not isinstance(tensor, dict):
-        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
-
-    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
-    return tensor
-
-
-class CFGDenoiser(torch.nn.Module):
-    """
-    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
-    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
-    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
-    negative prompt.
-    """
-
-    def __init__(self, model):
-        super().__init__()
-        self.inner_model = model
-        self.mask = None
-        self.nmask = None
-        self.init_latent = None
-        self.step = 0
-        self.image_cfg_scale = None
-        self.padded_cond_uncond = False
-
-    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
-        denoised_uncond = x_out[-uncond.shape[0]:]
-        denoised = torch.clone(denoised_uncond)
-
-        for i, conds in enumerate(conds_list):
-            for cond_index, weight in conds:
-                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
-        return denoised
-
-    def combine_denoised_for_edit_model(self, x_out, cond_scale):
-        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
-        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
-
-        return denoised
-
-    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
-        if state.interrupted or state.skipped:
-            raise sd_samplers_common.InterruptedException
-
-        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
-        # so is_edit_model is set to False to support AND composition.
-        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
-
-        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
-        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
-        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
-
-        batch_size = len(conds_list)
-        repeats = [len(conds_list[i]) for i in range(batch_size)]
-
-        if shared.sd_model.model.conditioning_key == "crossattn-adm":
-            image_uncond = torch.zeros_like(image_cond)
-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
-        else:
-            image_uncond = image_cond
-            if isinstance(uncond, dict):
-                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
-            else:
-                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
-
-        if not is_edit_model:
-            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
-            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
-            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
-        else:
-            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
-            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
-            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
-
-        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
-        cfg_denoiser_callback(denoiser_params)
-        x_in = denoiser_params.x
-        image_cond_in = denoiser_params.image_cond
-        sigma_in = denoiser_params.sigma
-        tensor = denoiser_params.text_cond
-        uncond = denoiser_params.text_uncond
-        skip_uncond = False
-
-        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
-        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
-            skip_uncond = True
-            x_in = x_in[:-batch_size]
-            sigma_in = sigma_in[:-batch_size]
-
-        self.padded_cond_uncond = False
-        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
-            empty = shared.sd_model.cond_stage_model_empty_prompt
-            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
-            if num_repeats < 0:
-                tensor = pad_cond(tensor, -num_repeats, empty)
-                self.padded_cond_uncond = True
-            elif num_repeats > 0:
-                uncond = pad_cond(uncond, num_repeats, empty)
-                self.padded_cond_uncond = True
-
-        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
-            if is_edit_model:
-                cond_in = catenate_conds([tensor, uncond, uncond])
-            elif skip_uncond:
-                cond_in = tensor
-            else:
-                cond_in = catenate_conds([tensor, uncond])
-
-            if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
-            else:
-                x_out = torch.zeros_like(x_in)
-                for batch_offset in range(0, x_out.shape[0], batch_size):
-                    a = batch_offset
-                    b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
-        else:
-            x_out = torch.zeros_like(x_in)
-            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
-            for batch_offset in range(0, tensor.shape[0], batch_size):
-                a = batch_offset
-                b = min(a + batch_size, tensor.shape[0])
-
-                if not is_edit_model:
-                    c_crossattn = subscript_cond(tensor, a, b)
-                else:
-                    c_crossattn = torch.cat([tensor[a:b]], uncond)
-
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
-
-            if not skip_uncond:
-                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
-
-        denoised_image_indexes = [x[0][0] for x in conds_list]
-        if skip_uncond:
-            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
-            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
-
-        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
-        cfg_denoised_callback(denoised_params)
-
-        devices.test_for_nans(x_out, "unet")
-
-        if opts.live_preview_content == "Prompt":
-            sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
-        elif opts.live_preview_content == "Negative prompt":
-            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
-        if is_edit_model:
-            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
-        elif skip_uncond:
-            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
-        else:
-            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
-        if self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
-
-        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
-        cfg_after_cfg_callback(after_cfg_callback_params)
-        denoised = after_cfg_callback_params.x
-
-        self.step += 1
-        return denoised
-
-
-class TorchHijack:
-    def __init__(self, sampler_noises):
-        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
-        # implementation.
-        self.sampler_noises = deque(sampler_noises)
-
-    def __getattr__(self, item):
-        if item == 'randn_like':
-            return self.randn_like
-
-        if hasattr(torch, item):
-            return getattr(torch, item)
-
-        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
-    def randn_like(self, x):
-        if self.sampler_noises:
-            noise = self.sampler_noises.popleft()
-            if noise.shape == x.shape:
-                return noise
-
-        if opts.randn_source == "CPU" or x.device.type == 'mps':
-            return torch.randn_like(x, device=devices.cpu).to(x.device)
-        else:
-            return torch.randn_like(x)
-
-
-class KDiffusionSampler:
-    def __init__(self, funcname, sd_model):
-        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-
-        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
-        self.funcname = funcname
-        self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler
-        self.extra_params = sampler_extra_params.get(funcname, [])
-        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
-        self.sampler_noises = None
-        self.stop_at = None
-        self.eta = None
-        self.config = None  # set by the function calling the constructor
-        self.last_latent = None
-        self.s_min_uncond = None
-
-        self.conditioning_key = sd_model.model.conditioning_key
-
-    def callback_state(self, d):
-        step = d['i']
-        latent = d["denoised"]
-        if opts.live_preview_content == "Combined":
-            sd_samplers_common.store_latent(latent)
-        self.last_latent = latent
-
-        if self.stop_at is not None and step > self.stop_at:
-            raise sd_samplers_common.InterruptedException
-
-        state.sampling_step = step
-        shared.total_tqdm.update()
-
-    def launch_sampling(self, steps, func):
-        state.sampling_steps = steps
-        state.sampling_step = 0
-
-        try:
-            return func()
-        except RecursionError:
-            print(
-                'Encountered RecursionError during sampling, returning last latent. '
-                'rho >5 with a polyexponential scheduler may cause this error. '
-                'You should try to use a smaller rho value instead.'
-            )
-            return self.last_latent
-        except sd_samplers_common.InterruptedException:
-            return self.last_latent
-
-    def number_of_needed_noises(self, p):
-        return p.steps
-
-    def initialize(self, p):
-        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
-        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
-        self.model_wrap_cfg.step = 0
-        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
-        self.eta = p.eta if p.eta is not None else opts.eta_ancestral
-        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
-
-        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
-        extra_params_kwargs = {}
-        for param_name in self.extra_params:
-            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
-                extra_params_kwargs[param_name] = getattr(p, param_name)
-
-        if 'eta' in inspect.signature(self.func).parameters:
-            if self.eta != 1.0:
-                p.extra_generation_params["Eta"] = self.eta
-
-            extra_params_kwargs['eta'] = self.eta
-
-        return extra_params_kwargs
-
-    def get_sigmas(self, p, steps):
-        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
-        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
-            discard_next_to_last_sigma = True
-            p.extra_generation_params["Discard penultimate sigma"] = True
-
-        steps += 1 if discard_next_to_last_sigma else 0
-
-        if p.sampler_noise_scheduler_override:
-            sigmas = p.sampler_noise_scheduler_override(steps)
-        elif opts.k_sched_type != "Automatic":
-            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
-            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
-            sigmas_kwargs = {
-                'sigma_min': sigma_min,
-                'sigma_max': sigma_max,
-            }
-
-            sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
-            p.extra_generation_params["Schedule type"] = opts.k_sched_type
-
-            if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
-                sigmas_kwargs['sigma_min'] = opts.sigma_min
-                p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
-            if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
-                sigmas_kwargs['sigma_max'] = opts.sigma_max
-                p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
-
-            default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
-
-            if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
-                sigmas_kwargs['rho'] = opts.rho
-                p.extra_generation_params["Schedule rho"] = opts.rho
-
-            sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
-        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
-            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
-
-            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
-        else:
-            sigmas = self.model_wrap.get_sigmas(steps)
-
-        if discard_next_to_last_sigma:
-            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
-
-        return sigmas
-
-    def create_noise_sampler(self, x, sigmas, p):
-        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
-        if shared.opts.no_dpmpp_sde_batch_determinism:
-            return None
-
-        from k_diffusion.sampling import BrownianTreeNoiseSampler
-        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
-        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
-        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
-
-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
-
-        sigmas = self.get_sigmas(p, steps)
-
-        sigma_sched = sigmas[steps - t_enc - 1:]
-        xi = x + noise * sigma_sched[0]
-
-        extra_params_kwargs = self.initialize(p)
-        parameters = inspect.signature(self.func).parameters
-
-        if 'sigma_min' in parameters:
-            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
-            extra_params_kwargs['sigma_min'] = sigma_sched[-2]
-        if 'sigma_max' in parameters:
-            extra_params_kwargs['sigma_max'] = sigma_sched[0]
-        if 'n' in parameters:
-            extra_params_kwargs['n'] = len(sigma_sched) - 1
-        if 'sigma_sched' in parameters:
-            extra_params_kwargs['sigma_sched'] = sigma_sched
-        if 'sigmas' in parameters:
-            extra_params_kwargs['sigmas'] = sigma_sched
-
-        if self.config.options.get('brownian_noise', False):
-            noise_sampler = self.create_noise_sampler(x, sigmas, p)
-            extra_params_kwargs['noise_sampler'] = noise_sampler
-
-        self.model_wrap_cfg.init_latent = x
-        self.last_latent = x
-        extra_args = {
-            'cond': conditioning,
-            'image_cond': image_conditioning,
-            'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale,
-            's_min_uncond': self.s_min_uncond
-        }
-
-        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
-        if self.model_wrap_cfg.padded_cond_uncond:
-            p.extra_generation_params["Pad conds"] = True
-
-        return samples
-
-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
-        steps = steps or p.steps
-
-        sigmas = self.get_sigmas(p, steps)
-
-        x = x * sigmas[0]
-
-        extra_params_kwargs = self.initialize(p)
-        parameters = inspect.signature(self.func).parameters
-
-        if 'sigma_min' in parameters:
-            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
-            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
-        if 'n' in parameters:
-            extra_params_kwargs['n'] = steps
-        else:
-            extra_params_kwargs['sigmas'] = sigmas
-
-        if self.config.options.get('brownian_noise', False):
-            noise_sampler = self.create_noise_sampler(x, sigmas, p)
-            extra_params_kwargs['noise_sampler'] = noise_sampler
-
-        self.last_latent = x
-        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
-            'cond': conditioning,
-            'image_cond': image_conditioning,
-            'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale,
-            's_min_uncond': self.s_min_uncond
-        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
-        if self.model_wrap_cfg.padded_cond_uncond:
-            p.extra_generation_params["Pad conds"] = True
-
-        return samples
--
cgit v1.2.1

"brownian_noise": True}), + ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}), +] + + +@torch.no_grad() +def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = None): + """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)""" + '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}''' + '''If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list''' + from tqdm.auto import trange + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + step_id = 0 + from k_diffusion.sampling import to_d, get_sigmas_karras + def heun_step(x, old_sigma, new_sigma, second_order = True): + nonlocal step_id + denoised = model(x, old_sigma * s_in, **extra_args) + d = to_d(x, old_sigma, denoised) + if callback is not None: + callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) + dt = new_sigma - old_sigma + if new_sigma == 0 or not second_order: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, new_sigma * s_in, **extra_args) + d_2 = to_d(x_2, new_sigma, denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + step_id += 1 + return x + steps = sigmas.shape[0] - 1 + if restart_list is None: + if steps >= 20: + restart_steps = 9 + restart_times = 1 + if steps >= 36: + restart_steps = steps // 4 + restart_times = 2 + sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device) + restart_list = {0.1: [restart_steps + 1, restart_times, 2]} + else: + restart_list = dict() + temp_list = dict() + for key, value in restart_list.items(): + temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value + restart_list = temp_list + step_list = [] + for i in range(len(sigmas) - 1): + step_list.append((sigmas[i], sigmas[i + 1])) + if i + 1 in restart_list: + restart_steps, restart_times, restart_max = restart_list[i + 1] + min_idx = i + 1 + max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) + if max_idx < min_idx: + sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] + while restart_times > 0: + restart_times -= 1 + step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) + last_sigma = None + for i in trange(len(step_list), disable=disable): + if last_sigma is None: + last_sigma = step_list[i][0] + elif last_sigma < step_list[i][0]: + x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (step_list[i][0] ** 2 - last_sigma ** 2) ** 0.5 + x = heun_step(x, step_list[i][0], step_list[i][1]) + last_sigma = step_list[i][1] + return x + +samplers_data_k_diffusion = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_k_diffusion + if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler') +] + +sampler_extra_params = { + 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], +} + +k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} +k_diffusion_scheduler = { + 'Automatic': None, + 
'karras': k_diffusion.sampling.get_sigmas_karras, + 'exponential': k_diffusion.sampling.get_sigmas_exponential, + 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential +} + + +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + +class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, model): + super().__init__() + self.inner_model = model + self.mask = None + self.nmask = None + self.init_latent = None + self.step = 0 + self.image_cfg_scale = None + self.padded_cond_uncond = False + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) + + return denoised + + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. 
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + if shared.sd_model.model.conditioning_key == "crossattn-adm": + image_uncond = torch.zeros_like(image_cond) + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} + else: + image_uncond = image_cond + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} + + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + tensor = denoiser_params.text_cond + uncond = denoiser_params.text_uncond + skip_uncond = False + + # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it + if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: + skip_uncond = True + x_in = x_in[:-batch_size] + sigma_in = sigma_in[:-batch_size] + + self.padded_cond_uncond = False + if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: + empty = shared.sd_model.cond_stage_model_empty_prompt + num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] + + if num_repeats < 0: + tensor = pad_cond(tensor, -num_repeats, empty) + self.padded_cond_uncond = True + elif num_repeats > 0: + uncond = pad_cond(uncond, num_repeats, empty) + self.padded_cond_uncond = True + + if tensor.shape[1] == uncond.shape[1] or skip_uncond: + if is_edit_model: + cond_in = catenate_conds([tensor, uncond, uncond]) + elif skip_uncond: + cond_in = tensor + else: + cond_in = catenate_conds([tensor, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) + else: + x_out = 
torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + + if not is_edit_model: + c_crossattn = subscript_cond(tensor, a, b) + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) + + if not skip_uncond: + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) + + denoised_image_indexes = [x[0][0] for x in conds_list] + if skip_uncond: + fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) + x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be + + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) + cfg_denoised_callback(denoised_params) + + devices.test_for_nans(x_out, "unet") + + if opts.live_preview_content == "Prompt": + sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) + elif opts.live_preview_content == "Negative prompt": + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) + + if is_edit_model: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) + elif skip_uncond: + denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) + else: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + + if self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) + cfg_after_cfg_callback(after_cfg_callback_params) + denoised = after_cfg_callback_params.x + + self.step += 1 + return denoised + + +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. 
+ self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + if opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + else: + return torch.randn_like(x) + + +class KDiffusionSampler: + def __init__(self, funcname, sd_model): + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.funcname = funcname + self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler + self.extra_params = sampler_extra_params.get(funcname, []) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + + self.conditioning_key = sd_model.model.conditioning_key + + def callback_state(self, d): + step = d['i'] + latent = d["denoised"] + if opts.live_preview_content == "Combined": + sd_samplers_common.store_latent(latent) + self.last_latent = latent + + if self.stop_at is not None and step > self.stop_at: + raise sd_samplers_common.InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' 
+ ) + return self.last_latent + except sd_samplers_common.InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p): + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else opts.eta_ancestral + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params["Eta"] = self.eta + + extra_params_kwargs['eta'] = self.eta + + return extra_params_kwargs + + def get_sigmas(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + if p.sampler_noise_scheduler_override: + sigmas = p.sampler_noise_scheduler_override(steps) + elif opts.k_sched_type != "Automatic": + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) + sigmas_kwargs = { + 'sigma_min': sigma_min, + 'sigma_max': sigma_max, + } + + sigmas_func = k_diffusion_scheduler[opts.k_sched_type] + p.extra_generation_params["Schedule type"] = opts.k_sched_type + + if opts.sigma_min != m_sigma_min and opts.sigma_min != 0: + sigmas_kwargs['sigma_min'] = opts.sigma_min + p.extra_generation_params["Schedule min sigma"] = opts.sigma_min + if opts.sigma_max != m_sigma_max and opts.sigma_max != 0: + sigmas_kwargs['sigma_max'] = opts.sigma_max + p.extra_generation_params["Schedule max sigma"] = opts.sigma_max + + default_rho = 1. if opts.k_sched_type == "polyexponential" else 7. 
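        # [editor's note] rho controls how sharply the schedule bends toward the
        # low-sigma end. In k-diffusion's Karras schedule the i-th of n sigmas is
        #   sigma_i = (sigma_max**(1/rho) + (i/(n-1)) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho
        # so the paper's default rho=7 concentrates steps at small sigmas, while the
        # polyexponential schedule raises its 0..1 ramp to the power rho before
        # interpolating in log-sigma, hence its different default of 1. The
        # exponential schedule takes no rho at all, which is why it is excluded
        # from the override check below.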
+ + if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho: + sigmas_kwargs['rho'] = opts.rho + p.extra_generation_params["Schedule rho"] = opts.rho + + sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + else: + sigmas = self.model_wrap.get_sigmas(steps) + + if discard_next_to_last_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + return sigmas + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + sigmas = self.get_sigmas(p, steps) + + sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'sigma_min' in parameters: + ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last + extra_params_kwargs['sigma_min'] = sigma_sched[-2] + if 'sigma_max' in parameters: + extra_params_kwargs['sigma_max'] = sigma_sched[0] + if 'n' in parameters: + extra_params_kwargs['n'] = len(sigma_sched) - 1 + if 'sigma_sched' in parameters: + extra_params_kwargs['sigma_sched'] = sigma_sched + if 'sigmas' in parameters: + extra_params_kwargs['sigmas'] = sigma_sched + + if self.config.options.get('brownian_noise', False): + noise_sampler = self.create_noise_sampler(x, sigmas, p) + extra_params_kwargs['noise_sampler'] = noise_sampler + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + + sigmas = self.get_sigmas(p, steps) + + x = x * sigmas[0] + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'sigma_min' in parameters: + extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() + extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() + if 'n' in parameters: + extra_params_kwargs['n'] = steps + else: + extra_params_kwargs['sigmas'] = sigmas + + if self.config.options.get('brownian_noise', 
False): + noise_sampler = self.create_noise_sampler(x, sigmas, p) + extra_params_kwargs['noise_sampler'] = noise_sampler + + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + -- cgit v1.2.1 From aefe1325df60a925b3a75a2cb58bf74e8ca86df4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 29 Jul 2023 08:11:59 +0300 Subject: split the new sampler into a different file --- modules/sd_samplers_kdiffusion.py | 75 +++------------------------------------ 1 file changed, 4 insertions(+), 71 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a54673eb..e0da3425 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra from modules.shared import opts, state import modules.shared as shared @@ -30,81 +30,14 @@ samplers_k_diffusion = [ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}), + ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), ] -@torch.no_grad() -def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = None): - """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)""" - '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}''' - '''If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list''' - from tqdm.auto import trange - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - step_id = 0 - from k_diffusion.sampling import to_d, get_sigmas_karras - def heun_step(x, old_sigma, new_sigma, second_order = True): - nonlocal step_id - denoised = model(x, old_sigma * s_in, **extra_args) - d = to_d(x, old_sigma, denoised) - if callback is not None: - callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) - dt = new_sigma - old_sigma - if new_sigma == 0 or not second_order: - # Euler method - x = x + d * dt - else: - # Heun's method - x_2 = x + d * dt - denoised_2 = model(x_2, new_sigma * s_in, **extra_args) - d_2 = to_d(x_2, new_sigma, denoised_2) - d_prime = (d + d_2) / 2 - x = x + d_prime * dt - step_id += 1 - return x - steps = sigmas.shape[0] - 1 - if restart_list is None: - if steps >= 20: - restart_steps = 9 - restart_times = 1 - if steps >= 36: - restart_steps = steps // 4 - restart_times = 2 - sigmas = get_sigmas_karras(steps - 
restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device) - restart_list = {0.1: [restart_steps + 1, restart_times, 2]} - else: - restart_list = dict() - temp_list = dict() - for key, value in restart_list.items(): - temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value - restart_list = temp_list - step_list = [] - for i in range(len(sigmas) - 1): - step_list.append((sigmas[i], sigmas[i + 1])) - if i + 1 in restart_list: - restart_steps, restart_times, restart_max = restart_list[i + 1] - min_idx = i + 1 - max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) - if max_idx < min_idx: - sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] - while restart_times > 0: - restart_times -= 1 - step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) - last_sigma = None - for i in trange(len(step_list), disable=disable): - if last_sigma is None: - last_sigma = step_list[i][0] - elif last_sigma < step_list[i][0]: - x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (step_list[i][0] ** 2 - last_sigma ** 2) ** 0.5 - x = heun_step(x, step_list[i][0], step_list[i][1]) - last_sigma = step_list[i][1] - return x - samplers_data_k_diffusion = [ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) for label, funcname, aliases, options in samplers_k_diffusion - if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler') + if callable(funcname) or hasattr(k_diffusion.sampling, funcname) ] sampler_extra_params = { @@ -339,7 +272,7 @@ class KDiffusionSampler: self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None -- cgit v1.2.1 From 84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 3 Aug 2023 00:00:23 +0300 Subject: add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards. --- modules/sd_samplers_kdiffusion.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e0da3425..d72c1b5f 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -260,10 +260,7 @@ class TorchHijack: if noise.shape == x.shape: return noise - if opts.randn_source == "CPU" or x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) + return devices.randn_like(x) class KDiffusionSampler: -- cgit v1.2.1 From 3bd2c68eb4df7630b88b3b05e1e11b7a4ad40185 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 00:51:49 -0400 Subject: Add exponential scheduler for DPM-Solver++(2M) SDE Better quality results than Karras. 
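[editor's note] Only the sigma schedule differs here; the sampler function is the
same sample_dpmpp_2m_sde. A minimal sketch of the comparison, assuming the
k-diffusion package is importable and using an illustrative SD1.x-like sigma
range (the real bounds come from the loaded model's denoiser):

    import torch
    from k_diffusion.sampling import get_sigmas_karras, get_sigmas_exponential

    n, sigma_min, sigma_max = 20, 0.0292, 14.6146  # illustrative values only

    # Both schedules run from sigma_max down to sigma_min and append a final zero.
    karras = get_sigmas_karras(n=n, sigma_min=sigma_min, sigma_max=sigma_max)
    exponential = get_sigmas_exponential(n=n, sigma_min=sigma_min, sigma_max=sigma_max)

    # The exponential schedule is a straight line in log-sigma; Karras (rho=7)
    # spends more of its steps at low sigmas.
    print(torch.stack([karras, exponential]))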
Related discussion: https://gist.github.com/crowsonkb/3ed16fba35c73ece7cf4b9a2095f2b78 --- modules/sd_samplers_kdiffusion.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index d72c1b5f..8bb639f5 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -30,6 +30,7 @@ samplers_k_diffusion = [ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), ] @@ -375,6 +376,9 @@ class KDiffusionSampler: sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) -- cgit v1.2.1 From 31506f07718803190e67cbbd8180af313d9e2a08 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 5 Aug 2023 22:37:25 -0400 Subject: Add sigma params to infotext --- modules/sd_samplers_kdiffusion.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 8bb639f5..6f46da7c 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -4,6 +4,7 @@ import inspect import k_diffusion.sampling from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules.processing import StableDiffusionProcessing from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback @@ -280,6 +281,14 @@ class KDiffusionSampler: self.last_latent = None self.s_min_uncond = None + # NOTE: These are also defined in the StableDiffusionProcessing class. + # They should have been here to begin with but we're going to + # leave that class __init__ signature alone. 
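        # [editor's note] These defaults mirror the stochastic "churn" samplers in
        # k-diffusion (sample_euler, sample_heun, sample_dpm_2): whenever
        # s_tmin <= sigma <= s_tmax, a step first inflates sigma and adds matching
        # noise, roughly
        #   gamma     = min(s_churn / (n - 1), 2**0.5 - 1)
        #   sigma_hat = sigma * (gamma + 1)
        #   x         = x + randn_like(x) * s_noise * (sigma_hat**2 - sigma**2)**0.5
        # With s_churn = 0.0 those samplers stay fully deterministic, so these are
        # safe values to carry on the sampler object.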
+ self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + self.conditioning_key = sd_model.model.conditioning_key def callback_state(self, d): @@ -314,7 +323,7 @@ class KDiffusionSampler: def number_of_needed_noises(self, p): return p.steps - def initialize(self, p): + def initialize(self, p: StableDiffusionProcessing): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 @@ -335,6 +344,29 @@ class KDiffusionSampler: extra_params_kwargs['eta'] = self.eta + if len(self.extra_params) > 0: + s_churn = p.override_settings.get('s_churn', p.s_churn) + s_tmin = p.override_settings.get('s_tmin', p.s_tmin) + s_tmax = p.override_settings.get('s_tmax', p.s_tmax) + s_noise = p.override_settings.get('s_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_churn + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_churn + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_churn + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + return extra_params_kwargs def get_sigmas(self, p, steps): -- cgit v1.2.1 From ce4be668fe1d2192378a7b4d3173cfd18031a017 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 5 Aug 2023 23:42:20 -0400 Subject: Read kdiffusion sigma params from opts --- modules/sd_samplers_kdiffusion.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 6f46da7c..0a0a3474 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -345,25 +345,25 @@ class KDiffusionSampler: extra_params_kwargs['eta'] = self.eta if len(self.extra_params) > 0: - s_churn = p.override_settings.get('s_churn', p.s_churn) - s_tmin = p.override_settings.get('s_tmin', p.s_tmin) - s_tmax = p.override_settings.get('s_tmax', p.s_tmax) - s_noise = p.override_settings.get('s_noise', p.s_noise) + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) + s_noise = getattr(opts, 's_noise', p.s_noise) if s_churn != self.s_churn: extra_params_kwargs['s_churn'] = s_churn p.s_churn = s_churn p.extra_generation_params['Sigma churn'] = s_churn if s_tmin != self.s_tmin: - extra_params_kwargs['s_tmin'] = s_churn + extra_params_kwargs['s_tmin'] = s_tmin p.s_tmin = s_tmin p.extra_generation_params['Sigma tmin'] = s_tmin if s_tmax != self.s_tmax: - extra_params_kwargs['s_tmax'] = s_churn + extra_params_kwargs['s_tmax'] = s_tmax p.s_tmax = s_tmax p.extra_generation_params['Sigma tmax'] = s_tmax if s_noise != self.s_noise: - extra_params_kwargs['s_noise'] = s_churn + extra_params_kwargs['s_noise'] = s_noise p.s_noise = s_noise p.extra_generation_params['Sigma noise'] = s_noise -- cgit v1.2.1 From 8f31b139b8898921853f686c313ba54c57ad7cb8 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 5 Aug 2023 23:50:33 -0400 Subject: Assume 0 = inf for s_tmax --- 
modules/sd_samplers_kdiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 0a0a3474..db71a549 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -347,7 +347,7 @@ class KDiffusionSampler: if len(self.extra_params) > 0: s_churn = getattr(opts, 's_churn', p.s_churn) s_tmin = getattr(opts, 's_tmin', p.s_tmin) - s_tmax = getattr(opts, 's_tmax', p.s_tmax) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf s_noise = getattr(opts, 's_noise', p.s_noise) if s_churn != self.s_churn: -- cgit v1.2.1 From f1975b0213f5be400889ec04b3891d1cb571fe20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:01:07 +0300 Subject: initial refiner support --- modules/sd_samplers_kdiffusion.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index db71a549..be1bd35e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models from modules.processing import StableDiffusionProcessing from modules.shared import opts, state @@ -87,15 +87,25 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model): + def __init__(self): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False + self.p = None + + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -113,10 +123,15 @@ class CFGDenoiser(torch.nn.Module): return denoised + def update_inner_model(self): + self.model_wrap = None + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + sd_samplers_common.apply_refiner(self) + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. 
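        # [editor's note] For an InstructPix2Pix ("edit") checkpoint the batch built
        # below is ordered [cond, image-only cond, uncond], and
        # combine_denoised_for_edit_model() blends the three predictions as
        #   denoised = uncond + cond_scale * (cond - img_cond)
        #                     + image_cfg_scale * (img_cond - uncond)
        # At image_cfg_scale == 1.0 this reduces to plain CFG, matching normal
        # sampling, which is why the edit path is disabled in that case.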
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 @@ -267,13 +282,13 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser() + self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None self.eta = None @@ -305,6 +320,7 @@ class KDiffusionSampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -324,6 +340,8 @@ class KDiffusionSampler: return p.steps def initialize(self, p: StableDiffusionProcessing): + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 -- cgit v1.2.1 From 956e69bf3a9387b6a6cda6823d729e3d3f13c3e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:07:08 +0300 Subject: lint! --- modules/sd_samplers_kdiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index be1bd35e..3aee4e3a 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra from modules.processing import StableDiffusionProcessing from modules.shared import opts, state -- cgit v1.2.1 From 5a0db84b6c7322082c7532df11a29a95a59a612b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:53:33 +0300 Subject: add infotext add proper support for recalculating conds in k-diffusion samplers remove support for compvis samplers --- modules/sd_samplers_kdiffusion.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3aee4e3a..46da0a97 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module): negative prompt. 
""" - def __init__(self): + def __init__(self, sampler): super().__init__() + self.sampler = sampler self.model_wrap = None self.mask = None self.nmask = None @@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module): def update_inner_model(self): self.model_wrap = None + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. @@ -282,12 +289,12 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser() + self.sampler_extra_args = {} + self.model_wrap_cfg = CFGDenoiser(self) self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None @@ -476,7 +483,7 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -484,7 +491,7 @@ class KDiffusionSampler: 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -514,13 +521,14 @@ class KDiffusionSampler: extra_params_kwargs['noise_sampler'] = noise_sampler self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True -- cgit v1.2.1 From a3e27019e44e8f357181992e510f989ce59b992f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 18:32:17 +0300 Subject: Split history: mv modules/sd_samplers_kdiffusion.py temp --- modules/sd_samplers_kdiffusion.py | 511 -------------------------------------- 1 file changed, 511 deletions(-) delete mode 100644 modules/sd_samplers_kdiffusion.py (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py deleted file mode 100644 index db71a549..00000000 --- 
a/modules/sd_samplers_kdiffusion.py +++ /dev/null @@ -1,511 +0,0 @@ -from collections import deque -import torch -import inspect -import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra - -from modules.processing import StableDiffusionProcessing -from modules.shared import opts, state -import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback - -samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), - ('Euler', 'sample_euler', ['k_euler'], {}), - ('LMS', 'sample_lms', ['k_lms'], {}), - ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), - ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), - ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), - ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), -] - - -samplers_data_k_diffusion = [ - sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) - for label, funcname, aliases, options in samplers_k_diffusion - if callable(funcname) or hasattr(k_diffusion.sampling, funcname) -] - -sampler_extra_params = { - 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], -} - -k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} -k_diffusion_scheduler = { - 'Automatic': None, - 'karras': k_diffusion.sampling.get_sigmas_karras, - 'exponential': k_diffusion.sampling.get_sigmas_exponential, - 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential -} - - -def catenate_conds(conds): - if not 
isinstance(conds[0], dict): - return torch.cat(conds) - - return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} - - -def subscript_cond(cond, a, b): - if not isinstance(cond, dict): - return cond[a:b] - - return {key: vec[a:b] for key, vec in cond.items()} - - -def pad_cond(tensor, repeats, empty): - if not isinstance(tensor, dict): - return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) - - tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) - return tensor - - -class CFGDenoiser(torch.nn.Module): - """ - Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) - that can take a noisy picture and produce a noise-free picture using two guidances (prompts) - instead of one. Originally, the second prompt is just an empty string, but we use non-empty - negative prompt. - """ - - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - self.image_cfg_scale = None - self.padded_cond_uncond = False - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def combine_denoised_for_edit_model(self, x_out, cond_scale): - out_cond, out_img_cond, out_uncond = x_out.chunk(3) - denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, - # so is_edit_model is set to False to support AND composition. 
- is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - if shared.sd_model.model.conditioning_key == "crossattn-adm": - image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} - else: - image_uncond = image_cond - if isinstance(uncond, dict): - make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} - else: - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} - - if not is_edit_model: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) - else: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - tensor = denoiser_params.text_cond - uncond = denoiser_params.text_uncond - skip_uncond = False - - # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it - if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: - skip_uncond = True - x_in = x_in[:-batch_size] - sigma_in = sigma_in[:-batch_size] - - self.padded_cond_uncond = False - if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: - empty = shared.sd_model.cond_stage_model_empty_prompt - num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] - - if num_repeats < 0: - tensor = pad_cond(tensor, -num_repeats, empty) - self.padded_cond_uncond = True - elif num_repeats > 0: - uncond = pad_cond(uncond, num_repeats, empty) - self.padded_cond_uncond = True - - if tensor.shape[1] == uncond.shape[1] or skip_uncond: - if is_edit_model: - cond_in = catenate_conds([tensor, uncond, uncond]) - elif skip_uncond: - cond_in = tensor - else: - cond_in = catenate_conds([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) - else: - x_out = 
torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - - if not is_edit_model: - c_crossattn = subscript_cond(tensor, a, b) - else: - c_crossattn = torch.cat([tensor[a:b]], uncond) - - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - - if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) - - denoised_image_indexes = [x[0][0] for x in conds_list] - if skip_uncond: - fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) - x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) - cfg_denoised_callback(denoised_params) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - if is_edit_model: - denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) - elif skip_uncond: - denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) - else: - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) - cfg_after_cfg_callback(after_cfg_callback_params) - denoised = after_cfg_callback_params.x - - self.step += 1 - return denoised - - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) - - -class KDiffusionSampler: - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.config = None # set by the function calling the constructor - self.last_latent = None - self.s_min_uncond = None - - # NOTE: These are also defined in the StableDiffusionProcessing class. - # They should have been here to begin with but we're going to - # leave that class __init__ signature alone. 
- self.s_churn = 0.0 - self.s_tmin = 0.0 - self.s_tmax = float('inf') - self.s_noise = 1.0 - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except RecursionError: - print( - 'Encountered RecursionError during sampling, returning last latent. ' - 'rho >5 with a polyexponential scheduler may cause this error. ' - 'You should try to use a smaller rho value instead.' - ) - return self.last_latent - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p: StableDiffusionProcessing): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) - self.eta = p.eta if p.eta is not None else opts.eta_ancestral - self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: - p.extra_generation_params["Eta"] = self.eta - - extra_params_kwargs['eta'] = self.eta - - if len(self.extra_params) > 0: - s_churn = getattr(opts, 's_churn', p.s_churn) - s_tmin = getattr(opts, 's_tmin', p.s_tmin) - s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf - s_noise = getattr(opts, 's_noise', p.s_noise) - - if s_churn != self.s_churn: - extra_params_kwargs['s_churn'] = s_churn - p.s_churn = s_churn - p.extra_generation_params['Sigma churn'] = s_churn - if s_tmin != self.s_tmin: - extra_params_kwargs['s_tmin'] = s_tmin - p.s_tmin = s_tmin - p.extra_generation_params['Sigma tmin'] = s_tmin - if s_tmax != self.s_tmax: - extra_params_kwargs['s_tmax'] = s_tmax - p.s_tmax = s_tmax - p.extra_generation_params['Sigma tmax'] = s_tmax - if s_noise != self.s_noise: - extra_params_kwargs['s_noise'] = s_noise - p.s_noise = s_noise - p.extra_generation_params['Sigma noise'] = s_noise - - return extra_params_kwargs - - def get_sigmas(self, p, steps): - discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) - if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: - discard_next_to_last_sigma = True - p.extra_generation_params["Discard penultimate sigma"] = True - - steps += 1 if discard_next_to_last_sigma else 0 - - if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif opts.k_sched_type != "Automatic": - m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) - sigmas_kwargs = { - 
'sigma_min': sigma_min, - 'sigma_max': sigma_max, - } - - sigmas_func = k_diffusion_scheduler[opts.k_sched_type] - p.extra_generation_params["Schedule type"] = opts.k_sched_type - - if opts.sigma_min != m_sigma_min and opts.sigma_min != 0: - sigmas_kwargs['sigma_min'] = opts.sigma_min - p.extra_generation_params["Schedule min sigma"] = opts.sigma_min - if opts.sigma_max != m_sigma_max and opts.sigma_max != 0: - sigmas_kwargs['sigma_max'] = opts.sigma_max - p.extra_generation_params["Schedule max sigma"] = opts.sigma_max - - default_rho = 1. if opts.k_sched_type == "polyexponential" else 7. - - if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho: - sigmas_kwargs['rho'] = opts.rho - p.extra_generation_params["Schedule rho"] = opts.rho - - sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) - elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) - elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': - m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) - sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) - else: - sigmas = self.model_wrap.get_sigmas(steps) - - if discard_next_to_last_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - - return sigmas - - def create_noise_sampler(self, x, sigmas, p): - """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" - if shared.opts.no_dpmpp_sde_batch_determinism: - return None - - from k_diffusion.sampling import BrownianTreeNoiseSampler - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] - return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - - sigmas = self.get_sigmas(p, steps) - - sigma_sched = sigmas[steps - t_enc - 1:] - xi = x + noise * sigma_sched[0] - - extra_params_kwargs = self.initialize(p) - parameters = inspect.signature(self.func).parameters - - if 'sigma_min' in parameters: - ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last - extra_params_kwargs['sigma_min'] = sigma_sched[-2] - if 'sigma_max' in parameters: - extra_params_kwargs['sigma_max'] = sigma_sched[0] - if 'n' in parameters: - extra_params_kwargs['n'] = len(sigma_sched) - 1 - if 'sigma_sched' in parameters: - extra_params_kwargs['sigma_sched'] = sigma_sched - if 'sigmas' in parameters: - extra_params_kwargs['sigmas'] = sigma_sched - - if self.config.options.get('brownian_noise', False): - noise_sampler = self.create_noise_sampler(x, sigmas, p) - extra_params_kwargs['noise_sampler'] = noise_sampler - - self.model_wrap_cfg.init_latent = x - self.last_latent = x - extra_args = { - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale, - 's_min_uncond': self.s_min_uncond - } 
- - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - if self.model_wrap_cfg.padded_cond_uncond: - p.extra_generation_params["Pad conds"] = True - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps = steps or p.steps - - sigmas = self.get_sigmas(p, steps) - - x = x * sigmas[0] - - extra_params_kwargs = self.initialize(p) - parameters = inspect.signature(self.func).parameters - - if 'sigma_min' in parameters: - extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() - extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in parameters: - extra_params_kwargs['n'] = steps - else: - extra_params_kwargs['sigmas'] = sigmas - - if self.config.options.get('brownian_noise', False): - noise_sampler = self.create_noise_sampler(x, sigmas, p) - extra_params_kwargs['noise_sampler'] = noise_sampler - - self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ - 'cond': conditioning, - 'image_cond': image_conditioning, - 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale, - 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) - - if self.model_wrap_cfg.padded_cond_uncond: - p.extra_generation_params["Pad conds"] = True - - return samples - -- cgit v1.2.1 From c721884cf5b9692c32461ffdecfc9121ca0d47b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 18:32:18 +0300 Subject: Split history: mv temp modules/sd_samplers_kdiffusion.py --- modules/sd_samplers_kdiffusion.py | 511 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 511 insertions(+) create mode 100644 modules/sd_samplers_kdiffusion.py (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py new file mode 100644 index 00000000..db71a549 --- /dev/null +++ b/modules/sd_samplers_kdiffusion.py @@ -0,0 +1,511 @@ +from collections import deque +import torch +import inspect +import k_diffusion.sampling +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra + +from modules.processing import StableDiffusionProcessing +from modules.shared import opts, state +import modules.shared as shared +from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback + +samplers_k_diffusion = [ + ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), + ('Euler', 'sample_euler', ['k_euler'], {}), + ('LMS', 'sample_lms', ['k_lms'], {}), + ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), + ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), + ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), + ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), + ('DPM 
fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), + ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), + ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), + ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), +] + + +samplers_data_k_diffusion = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_k_diffusion + if callable(funcname) or hasattr(k_diffusion.sampling, funcname) +] + +sampler_extra_params = { + 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], +} + +k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} +k_diffusion_scheduler = { + 'Automatic': None, + 'karras': k_diffusion.sampling.get_sigmas_karras, + 'exponential': k_diffusion.sampling.get_sigmas_exponential, + 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential +} + + +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + +class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. 
+ """ + + def __init__(self, model): + super().__init__() + self.inner_model = model + self.mask = None + self.nmask = None + self.init_latent = None + self.step = 0 + self.image_cfg_scale = None + self.padded_cond_uncond = False + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) + + return denoised + + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. + is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + if shared.sd_model.model.conditioning_key == "crossattn-adm": + image_uncond = torch.zeros_like(image_cond) + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} + else: + image_uncond = image_cond + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} + + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + tensor = denoiser_params.text_cond + uncond = denoiser_params.text_uncond + skip_uncond = False + + # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it + if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: + skip_uncond = True 
+ x_in = x_in[:-batch_size] + sigma_in = sigma_in[:-batch_size] + + self.padded_cond_uncond = False + if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: + empty = shared.sd_model.cond_stage_model_empty_prompt + num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] + + if num_repeats < 0: + tensor = pad_cond(tensor, -num_repeats, empty) + self.padded_cond_uncond = True + elif num_repeats > 0: + uncond = pad_cond(uncond, num_repeats, empty) + self.padded_cond_uncond = True + + if tensor.shape[1] == uncond.shape[1] or skip_uncond: + if is_edit_model: + cond_in = catenate_conds([tensor, uncond, uncond]) + elif skip_uncond: + cond_in = tensor + else: + cond_in = catenate_conds([tensor, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) + else: + x_out = torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + + if not is_edit_model: + c_crossattn = subscript_cond(tensor, a, b) + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) + + if not skip_uncond: + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) + + denoised_image_indexes = [x[0][0] for x in conds_list] + if skip_uncond: + fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) + x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be + + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) + cfg_denoised_callback(denoised_params) + + devices.test_for_nans(x_out, "unet") + + if opts.live_preview_content == "Prompt": + sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) + elif opts.live_preview_content == "Negative prompt": + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) + + if is_edit_model: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) + elif skip_uncond: + denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) + else: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + + if self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) + cfg_after_cfg_callback(after_cfg_callback_params) + denoised = after_cfg_callback_params.x + + self.step += 1 + return denoised + + +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. 
+ self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + return devices.randn_like(x) + + +class KDiffusionSampler: + def __init__(self, funcname, sd_model): + denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + + self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.funcname = funcname + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) + self.extra_params = sampler_extra_params.get(funcname, []) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + + # NOTE: These are also defined in the StableDiffusionProcessing class. + # They should have been here to begin with but we're going to + # leave that class __init__ signature alone. + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + + self.conditioning_key = sd_model.model.conditioning_key + + def callback_state(self, d): + step = d['i'] + latent = d["denoised"] + if opts.live_preview_content == "Combined": + sd_samplers_common.store_latent(latent) + self.last_latent = latent + + if self.stop_at is not None and step > self.stop_at: + raise sd_samplers_common.InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' 
+ ) + return self.last_latent + except sd_samplers_common.InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p: StableDiffusionProcessing): + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else opts.eta_ancestral + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params["Eta"] = self.eta + + extra_params_kwargs['eta'] = self.eta + + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + + return extra_params_kwargs + + def get_sigmas(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + if p.sampler_noise_scheduler_override: + sigmas = p.sampler_noise_scheduler_override(steps) + elif opts.k_sched_type != "Automatic": + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) + sigmas_kwargs = { + 'sigma_min': sigma_min, + 'sigma_max': sigma_max, + } + + sigmas_func = k_diffusion_scheduler[opts.k_sched_type] + p.extra_generation_params["Schedule type"] = opts.k_sched_type + + if opts.sigma_min != m_sigma_min and opts.sigma_min != 0: + sigmas_kwargs['sigma_min'] = opts.sigma_min + p.extra_generation_params["Schedule min sigma"] = opts.sigma_min + if opts.sigma_max != m_sigma_max and opts.sigma_max != 0: + sigmas_kwargs['sigma_max'] = opts.sigma_max + p.extra_generation_params["Schedule max sigma"] = opts.sigma_max + + default_rho = 1. if opts.k_sched_type == "polyexponential" else 7. 
+ + if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho: + sigmas_kwargs['rho'] = opts.rho + p.extra_generation_params["Schedule rho"] = opts.rho + + sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) + else: + sigmas = self.model_wrap.get_sigmas(steps) + + if discard_next_to_last_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + return sigmas + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + sigmas = self.get_sigmas(p, steps) + + sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'sigma_min' in parameters: + ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last + extra_params_kwargs['sigma_min'] = sigma_sched[-2] + if 'sigma_max' in parameters: + extra_params_kwargs['sigma_max'] = sigma_sched[0] + if 'n' in parameters: + extra_params_kwargs['n'] = len(sigma_sched) - 1 + if 'sigma_sched' in parameters: + extra_params_kwargs['sigma_sched'] = sigma_sched + if 'sigmas' in parameters: + extra_params_kwargs['sigmas'] = sigma_sched + + if self.config.options.get('brownian_noise', False): + noise_sampler = self.create_noise_sampler(x, sigmas, p) + extra_params_kwargs['noise_sampler'] = noise_sampler + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + + sigmas = self.get_sigmas(p, steps) + + x = x * sigmas[0] + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 
'sigma_min' in parameters: + extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() + extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() + if 'n' in parameters: + extra_params_kwargs['n'] = steps + else: + extra_params_kwargs['sigmas'] = sigmas + + if self.config.options.get('brownian_noise', False): + noise_sampler = self.create_noise_sampler(x, sigmas, p) + extra_params_kwargs['noise_sampler'] = noise_sampler + + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + -- cgit v1.2.1 From 2d8e4a654480ea080fec62834331a3c632ed0330 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 18:35:31 +0300 Subject: split sd_samplers_kdiffusion into two --- modules/sd_samplers_kdiffusion.py | 191 +------------------------------------- 1 file changed, 2 insertions(+), 189 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index db71a549..9c9b46d1 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,14 +2,11 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.processing import StableDiffusionProcessing from modules.shared import opts, state import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback samplers_k_diffusion = [ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), @@ -57,190 +54,6 @@ k_diffusion_scheduler = { } -def catenate_conds(conds): - if not isinstance(conds[0], dict): - return torch.cat(conds) - - return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} - - -def subscript_cond(cond, a, b): - if not isinstance(cond, dict): - return cond[a:b] - - return {key: vec[a:b] for key, vec in cond.items()} - - -def pad_cond(tensor, repeats, empty): - if not isinstance(tensor, dict): - return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) - - tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) - return tensor - - -class CFGDenoiser(torch.nn.Module): - """ - Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) - that can take a noisy picture and produce a noise-free picture using two guidances (prompts) - instead of one. Originally, the second prompt is just an empty string, but we use non-empty - negative prompt. 
- """ - - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - self.image_cfg_scale = None - self.padded_cond_uncond = False - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised - - def combine_denoised_for_edit_model(self, x_out, cond_scale): - out_cond, out_img_cond, out_uncond = x_out.chunk(3) - denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, - # so is_edit_model is set to False to support AND composition. - is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - if shared.sd_model.model.conditioning_key == "crossattn-adm": - image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} - else: - image_uncond = image_cond - if isinstance(uncond, dict): - make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} - else: - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} - - if not is_edit_model: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) - else: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - tensor = denoiser_params.text_cond - uncond = denoiser_params.text_uncond - skip_uncond = False - - # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it - if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: - skip_uncond = True 
- x_in = x_in[:-batch_size] - sigma_in = sigma_in[:-batch_size] - - self.padded_cond_uncond = False - if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: - empty = shared.sd_model.cond_stage_model_empty_prompt - num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] - - if num_repeats < 0: - tensor = pad_cond(tensor, -num_repeats, empty) - self.padded_cond_uncond = True - elif num_repeats > 0: - uncond = pad_cond(uncond, num_repeats, empty) - self.padded_cond_uncond = True - - if tensor.shape[1] == uncond.shape[1] or skip_uncond: - if is_edit_model: - cond_in = catenate_conds([tensor, uncond, uncond]) - elif skip_uncond: - cond_in = tensor - else: - cond_in = catenate_conds([tensor, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - - if not is_edit_model: - c_crossattn = subscript_cond(tensor, a, b) - else: - c_crossattn = torch.cat([tensor[a:b]], uncond) - - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - - if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) - - denoised_image_indexes = [x[0][0] for x in conds_list] - if skip_uncond: - fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) - x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) - cfg_denoised_callback(denoised_params) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - if is_edit_model: - denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) - elif skip_uncond: - denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) - else: - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) - cfg_after_cfg_callback(after_cfg_callback_params) - denoised = after_cfg_callback_params.x - - self.step += 1 - return denoised - - class TorchHijack: def __init__(self, sampler_noises): # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based @@ -273,7 +86,7 @@ class KDiffusionSampler: self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = 
CFGDenoiser(self.model_wrap) + self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap) self.sampler_noises = None self.stop_at = None self.eta = None -- cgit v1.2.1 From 8285a149d8c488ae6c7a566eb85fb5e825145464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 19:20:11 +0300 Subject: add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them) --- modules/sd_samplers_kdiffusion.py | 152 ++++---------------------------------- 1 file changed, 14 insertions(+), 138 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 9c9b46d1..3a2e01b7 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -4,8 +4,7 @@ import inspect import k_diffusion.sampling from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser -from modules.processing import StableDiffusionProcessing -from modules.shared import opts, state +from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ @@ -54,133 +53,17 @@ k_diffusion_scheduler = { } -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise +class KDiffusionSampler(sd_samplers_common.Sampler): + def __init__(self, funcname, sd_model): - return devices.randn_like(x) + super().__init__(funcname) + self.extra_params = sampler_extra_params.get(funcname, []) + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) -class KDiffusionSampler: - def __init__(self, funcname, sd_model): denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.config = None # set by the function calling the constructor - self.last_latent = None - self.s_min_uncond = None - - # NOTE: These are also defined in the StableDiffusionProcessing class. - # They should have been here to begin with but we're going to - # leave that class __init__ signature alone. 
- self.s_churn = 0.0 - self.s_tmin = 0.0 - self.s_tmax = float('inf') - self.s_noise = 1.0 - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except RecursionError: - print( - 'Encountered RecursionError during sampling, returning last latent. ' - 'rho >5 with a polyexponential scheduler may cause this error. ' - 'You should try to use a smaller rho value instead.' - ) - return self.last_latent - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p: StableDiffusionProcessing): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) - self.eta = p.eta if p.eta is not None else opts.eta_ancestral - self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: - p.extra_generation_params["Eta"] = self.eta - - extra_params_kwargs['eta'] = self.eta - - if len(self.extra_params) > 0: - s_churn = getattr(opts, 's_churn', p.s_churn) - s_tmin = getattr(opts, 's_tmin', p.s_tmin) - s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf - s_noise = getattr(opts, 's_noise', p.s_noise) - - if s_churn != self.s_churn: - extra_params_kwargs['s_churn'] = s_churn - p.s_churn = s_churn - p.extra_generation_params['Sigma churn'] = s_churn - if s_tmin != self.s_tmin: - extra_params_kwargs['s_tmin'] = s_tmin - p.s_tmin = s_tmin - p.extra_generation_params['Sigma tmin'] = s_tmin - if s_tmax != self.s_tmax: - extra_params_kwargs['s_tmax'] = s_tmax - p.s_tmax = s_tmax - p.extra_generation_params['Sigma tmax'] = s_tmax - if s_noise != self.s_noise: - extra_params_kwargs['s_noise'] = s_noise - p.s_noise = s_noise - p.extra_generation_params['Sigma noise'] = s_noise - - return extra_params_kwargs + self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -232,22 +115,12 @@ class KDiffusionSampler: return sigmas - def create_noise_sampler(self, x, sigmas, p): - """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" - if shared.opts.no_dpmpp_sde_batch_determinism: - return None - - from k_diffusion.sampling import BrownianTreeNoiseSampler - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] - return 
BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) sigmas = self.get_sigmas(p, steps) - sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] extra_params_kwargs = self.initialize(p) @@ -296,12 +169,14 @@ class KDiffusionSampler: extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters + if 'n' in parameters: + extra_params_kwargs['n'] = steps + if 'sigma_min' in parameters: extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in parameters: - extra_params_kwargs['n'] = steps - else: + + if 'sigmas' in parameters: extra_params_kwargs['sigmas'] = sigmas if self.config.options.get('brownian_noise', False): @@ -322,3 +197,4 @@ class KDiffusionSampler: return samples + -- cgit v1.2.1 From a8a256f9b5b445206818bfc8a363ed5a1ba50c86 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 21:07:18 +0300 Subject: REMOVE --- modules/sd_samplers_kdiffusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3a2e01b7..27a73486 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ -from collections import deque import torch import inspect import k_diffusion.sampling -from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared -- cgit v1.2.1 From ae1bde1aa1a987cd233fccb2caaec3abf8012178 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 21:10:12 +0300 Subject: put commonly used samplers on top, make DPM++ 2M Karras the default choice --- modules/sd_samplers_kdiffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 27a73486..f47431af 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -7,6 +7,10 @@ from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), @@ -23,10 +27,6 @@ samplers_k_diffusion = [ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 
'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), ] -- cgit v1.2.1 From f8ff8c0638997fd0aef217db1505598846f14782 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:09:40 +0300 Subject: merge errors --- modules/sd_samplers_kdiffusion.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3ff4b634..95a43cef 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -52,17 +52,24 @@ k_diffusion_scheduler = { } +class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap + + class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - super().__init__(funcname) - self.extra_params = sampler_extra_params.get(funcname, []) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserKDiffusion(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) -- cgit v1.2.1 From 0e83c675257f473e024511845e7940802333fd5f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:27:32 +0300 Subject: by request: fix tiled vae extension --- modules/sd_samplers_kdiffusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index f47431af..5613b8c1 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,7 +1,8 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules import sd_samplers_common, sd_samplers_extra +from modules.sd_samplers_cfg_denoiser import CFGDenoiser from modules.shared import opts import modules.shared as shared @@ -62,7 +63,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): denoiser = 
k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap, self) def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) -- cgit v1.2.1 From ac8a5d18d3ede6bcb8fa5a3da1c7c28e064cd65d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 10 Aug 2023 17:04:59 +0300 Subject: resolve merge issues --- modules/sd_samplers_kdiffusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e1854980..95a43cef 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ import torch import inspect import k_diffusion.sampling -from modules import sd_samplers_common, sd_samplers_extra -from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared -- cgit v1.2.1 From f4979422dd937f8a14a9eb143b837c958e8557b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 10 Aug 2023 17:18:33 +0300 Subject: return the line lost during the merge --- modules/sd_samplers_kdiffusion.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 95a43cef..60534b38 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,6 +2,7 @@ import torch import inspect import k_diffusion.sampling from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules.sd_samplers_cfg_denoiser import CFGDenoiser from modules.shared import opts import modules.shared as shared -- cgit v1.2.1 From 4549f2a9cc0ed9045661192e74d146fa21f300d6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 10 Aug 2023 17:21:01 +0300 Subject: lint --- modules/sd_samplers_kdiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 60534b38..d10fe12e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ import torch import inspect import k_diffusion.sampling from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser -from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 from modules.shared import opts import modules.shared as shared -- cgit v1.2.1 From 64311faa6848d641cc452115e4e1eb47d2a7b519 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 12 Aug 2023 12:39:59 +0300 Subject: put refiner into main UI, into the new accordions section add VAE from main model into infotext, not from refiner model option to make scripts UI without gr.Group fix inconsistencies with refiner when usings samplers that do more denoising than steps --- modules/sd_samplers_kdiffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) 
(limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index d10fe12e..1f8e9c4b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -64,9 +64,10 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): class KDiffusionSampler(sd_samplers_common.Sampler): - def __init__(self, funcname, sd_model): + def __init__(self, funcname, sd_model, options=None): super().__init__(funcname) + self.options = options or {} self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.model_wrap_cfg = CFGDenoiserKDiffusion(self) -- cgit v1.2.1 From 1ae9dacb4b036a6cb4b5fb9b9ff030962f43908e Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 13 Aug 2023 07:57:29 -0400 Subject: Add DPM-Solver++(3M) SDE --- modules/sd_samplers_kdiffusion.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 1f8e9c4b..a48a563f 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -22,6 +22,9 @@ samplers_k_diffusion = [ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), + ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {"brownian_noise": True}), + ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), -- cgit v1.2.1 From d1a70c3f0534b88665d55bdac1e6b48a63f7f035 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 13 Aug 2023 08:22:24 -0400 Subject: Add s_noise param to more samplers --- modules/sd_samplers_kdiffusion.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a48a563f..9f5dfd6d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -45,6 +45,12 @@ sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_fast': ['s_noise'], + 'sample_dpm_2_ancestral': ['s_noise'], + 'sample_dpmpp_2s_ancestral': ['s_noise'], + 'sample_dpmpp_sde': ['s_noise'], + 'sample_dpmpp_2m_sde': ['s_noise'], + 'sample_dpmpp_3m_sde': ['s_noise'], } k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} -- cgit v1.2.1 From ac790fc49b314761a7ad711989d7cc88bba64578 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 13 Aug 2023 08:46:07 -0400 Subject: Discard penultimate sigma for DPM-Solver++(3M) SDE --- modules/sd_samplers_kdiffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') 
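[Note on the patch below: it extends the "Discard penultimate sigma" treatment to the three DPM++ 3M SDE entries. The mechanism itself already appears in get_sigmas earlier in this history: when the option is set, the sampler requests one extra step and then drops the next-to-last sigma from the schedule. A minimal sketch of that manipulation — an editorial illustration only, with made-up sigma values rather than ones taken from a model:

    import torch

    def discard_penultimate_sigma(sigmas: torch.Tensor) -> torch.Tensor:
        # [s0, ..., s_{n-2}, s_{n-1}, 0.0] -> [s0, ..., s_{n-2}, 0.0]
        return torch.cat([sigmas[:-2], sigmas[-1:]])

    sigmas = torch.tensor([14.6, 7.0, 3.1, 1.2, 0.4, 0.0])
    print(discard_penultimate_sigma(sigmas))
    # tensor([14.6000,  7.0000,  3.1000,  1.2000,  0.0000])

Because get_sigmas increments steps by one beforehand when the option is active, the schedule still contains the requested number of sampling steps after the drop.]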
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a48a563f..c8b81fa1 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -22,9 +22,9 @@ samplers_k_diffusion = [ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), - ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {"brownian_noise": True}), - ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}), + ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}), + ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), -- cgit v1.2.1 From 525b55b1e9dc589a1133a8031ed3b1646e5cccff Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 13 Aug 2023 09:08:34 -0400 Subject: Restore extra_params that was lost in merge --- modules/sd_samplers_kdiffusion.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 1f8e9c4b..9a89e7fa 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -67,6 +67,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model, options=None): super().__init__(funcname) + self.extra_params = sampler_extra_params.get(funcname, []) + self.options = options or {} self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) -- cgit v1.2.1 From 0ea61a74be5e9666a16de8b92b1d55ed0b678a16 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 14 Aug 2023 11:46:36 +0800 Subject: add res(dpmdd 2m sde heun) and reorder the sampler list --- modules/sd_samplers_kdiffusion.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 0bacfe8d..c140d2d0 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -8,10 +8,6 @@ from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 
'karras', "brownian_noise": True}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), @@ -20,8 +16,15 @@ samplers_k_diffusion = [ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), + ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}), + ('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}), + ('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}), ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}), ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}), @@ -161,6 +164,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler): noise_sampler = self.create_noise_sampler(x, sigmas, p) extra_params_kwargs['noise_sampler'] = noise_sampler + if self.config.options.get('solver_type', None) == 'heun': + extra_params_kwargs['solver_type'] = 'heun' + self.model_wrap_cfg.init_latent = x self.last_latent = x self.sampler_extra_args = { @@ -202,6 +208,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler): noise_sampler = self.create_noise_sampler(x, sigmas, p) extra_params_kwargs['noise_sampler'] = noise_sampler + if self.config.options.get('solver_type', None) == 'heun': + extra_params_kwargs['solver_type'] = 'heun' + self.last_latent = x self.sampler_extra_args = { 'cond': conditioning, @@ -210,6 +219,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: -- cgit v1.2.1 From aa26f8eb40c2f99adb6752668b1fad1b5ae0158f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 14 Aug 2023 13:50:53 +0800 Subject: Put frequently used sampler back --- modules/sd_samplers_kdiffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 
'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index c140d2d0..67853ff1 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -8,6 +8,10 @@ from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ + ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), + ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}), ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), @@ -16,12 +20,8 @@ samplers_k_diffusion = [ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), - ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}), ('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}), ('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}), -- cgit v1.2.1 From 6bfd4dfecfd03bb74e157d2b81484276875cc0ad Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 14 Aug 2023 12:07:38 +0300 Subject: add second_order to samplers that mistakenly didn't have it --- modules/sd_samplers_kdiffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 67853ff1..08b9e740 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -16,8 +16,8 @@ samplers_k_diffusion = [ ('Euler', 'sample_euler', ['k_euler'], {}), ('LMS', 'sample_lms', ['k_lms'], {}), ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, 
"uses_ensd": True, "second_order": True}), ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), @@ -34,7 +34,7 @@ samplers_k_diffusion = [ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), - ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), + ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}), ] -- cgit v1.2.1 From 371b24b17c1cf98c9068a4b585b93cc1610702dc Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 15 Aug 2023 02:19:19 -0400 Subject: Add extra img2img noise --- modules/sd_samplers_kdiffusion.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 08b9e740..08866e41 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -145,6 +145,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler): xi = x + noise * sigma_sched[0] + if opts.img2img_extra_noise > 0: + p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise + xi += noise * opts.img2img_extra_noise + extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters -- cgit v1.2.1 From 254be4eeb259fd3bd2452250fca1bd278fa248ff Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 16 Aug 2023 21:45:19 -0400 Subject: Add extra noise callback --- modules/sd_samplers_kdiffusion.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 08866e41..b9e0d577 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -3,6 +3,7 @@ import inspect import k_diffusion.sampling from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 +from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback from modules.shared import opts import modules.shared as shared @@ -147,6 +148,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler): if opts.img2img_extra_noise > 0: p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise + extra_noise_params = ExtraNoiseParams(noise, x) + extra_noise_callback(extra_noise_params) + noise = extra_noise_params.noise xi += noise * opts.img2img_extra_noise extra_params_kwargs = self.initialize(p) -- cgit v1.2.1