From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 11 Jul 2023 21:16:43 +0300 Subject: getting SD2.1 to run on SDXL repo --- modules/sd_samplers_kdiffusion.py | 45 ++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 10 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b76..73289ce4 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,6 +53,28 @@ k_diffusion_scheduler = { } +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + class CFGDenoiser(torch.nn.Module): """ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) @@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: - tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) + tensor = pad_cond(tensor, -num_repeats, empty) self.padded_cond_uncond = True elif num_repeats > 0: - uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) + uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = catenate_conds([tensor, uncond, uncond]) elif skip_uncond: cond_in = tensor else: - cond_in = torch.cat([tensor, uncond]) + cond_in = catenate_conds([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size @@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): b = min(a + batch_size, tensor.shape[0]) if not is_edit_model: - c_crossattn = [tensor[a:b]] + c_crossattn = subscript_cond(tensor, a, b) else: c_crossattn = torch.cat([tensor[a:b]], uncond) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list] if skip_uncond: -- cgit v1.2.1 From da464a3fb39ecc6ea7b22fe87271194480d8501c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 12 Jul 2023 23:52:43 +0300 Subject: SDXL support --- modules/sd_samplers_kdiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 73289ce4..5552a8dc 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size -- cgit v1.2.1