author     AUTOMATIC1111 <16777216c@gmail.com>    2023-08-08 19:20:11 +0300
committer  AUTOMATIC1111 <16777216c@gmail.com>    2023-08-08 21:04:44 +0300
commit     8285a149d8c488ae6c7a566eb85fb5e825145464 (patch)
tree       89f204c4dc44f58ac2f15719f9bdbf28c4590cb5 /modules/sd_samplers_cfg_denoiser.py
parent     2d8e4a654480ea080fec62834331a3c632ed0330 (diff)
add CFG denoiser implementation for DDIM, PLMS and UniPC (as of this commit, both the old and new implementations can be run and compared)
Diffstat (limited to 'modules/sd_samplers_cfg_denoiser.py')
-rw-r--r--    modules/sd_samplers_cfg_denoiser.py    50
1 file changed, 18 insertions, 32 deletions
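
The key addition below is the get_pred_x0(x_in, x_out, sigma) hook: in the base CFGDenoiser it simply returns x_out, and forward() now routes sampler.last_latent and every live-preview latent through it. The idea is that timestep-based samplers such as DDIM, PLMS and UniPC can subclass CFGDenoiser and override the hook to recover the predicted x0 from an epsilon prediction. A minimal sketch of such an override, assuming sigma carries the discrete timestep and the model predicts epsilon (the subclass name and the alphas_cumprod attribute are illustrative assumptions, not code from this commit):

    import torch

    class TimestepCFGDenoiserSketch(CFGDenoiser):
        """Hypothetical subclass for DDIM/PLMS/UniPC-style samplers (illustration only)."""

        def __init__(self, model, sampler, alphas_cumprod):
            super().__init__(model, sampler)
            # assumption: cumulative alphas of the diffusion schedule, indexed by timestep
            self.alphas_cumprod = alphas_cumprod

        def get_pred_x0(self, x_in, x_out, sigma):
            # assumption: sigma holds integer timesteps and x_out is the eps prediction
            ts = sigma.to(dtype=torch.int64)
            a_t = self.alphas_cumprod[ts][:, None, None, None]
            # invert x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps
            return (x_in - (1.0 - a_t).sqrt() * x_out) / a_t.sqrt()
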
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 33a49783..166a00c7 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""
- def __init__(self, model):
+ def __init__(self, model, sampler):
super().__init__()
self.inner_model = model
self.mask = None
@@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
+ self.sampler = sampler
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module):
return denoised
+ def get_pred_x0(self, x_in, x_out, sigma):
+ return x_out
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ if self.mask is not None:
+ x = self.init_latent * self.mask + self.nmask * x
+
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module):
devices.test_for_nans(x_out, "unet")
- if opts.live_preview_content == "Prompt":
- sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
- elif opts.live_preview_content == "Negative prompt":
- sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
-
if is_edit_model:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
elif skip_uncond:
@@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+ if opts.live_preview_content == "Prompt":
+ preview = self.sampler.last_latent
+ elif opts.live_preview_content == "Negative prompt":
+ preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+ else:
+ preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+ sd_samplers_common.store_latent(preview)
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
@@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module):
self.step += 1
return denoised
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- return devices.randn_like(x)
-
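
For context on the new two-argument constructor, here is a rough sketch of how a sampler wrapper might pass itself to the denoiser and read back last_latent after each step (the wrapper class, its attributes and its step method are assumptions for illustration; only CFGDenoiser's constructor and forward signature come from this commit):

    class SamplerWrapperSketch:
        """Hypothetical sampler wrapper; illustration only."""

        def __init__(self, inner_model):
            # CFGDenoiser.forward() assigns its predicted x0 here on every step
            self.last_latent = None
            # new constructor: the denoiser keeps a reference back to its sampler
            self.model_wrap_cfg = CFGDenoiser(inner_model, self)

        def step(self, x, sigma, cond, uncond, cond_scale, s_min_uncond, image_cond):
            denoised = self.model_wrap_cfg(x, sigma, uncond, cond, cond_scale,
                                           s_min_uncond, image_cond)
            # a real sampler would feed `denoised` into its solver update; elided here.
            # self.last_latent now holds the denoiser's predicted x0 for the
            # positive-prompt outputs of this step, which store_latent() uses for previews.
            return denoised
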