Diffstat (limited to 'modules/sd_samplers_cfg_denoiser.py')
-rw-r--r-- | modules/sd_samplers_cfg_denoiser.py | 37 |
1 file changed, 32 insertions(+), 5 deletions(-)
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index d826222c..b8101d38 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -38,16 +38,29 @@ class CFGDenoiser(torch.nn.Module):
negative prompt.
"""
- def __init__(self, model, sampler):
+ def __init__(self, sampler):
super().__init__()
- self.inner_model = model
+ self.model_wrap = None
self.mask = None
self.nmask = None
self.init_latent = None
+ self.steps = None
+ """number of steps as specified by user in UI"""
+
+ self.total_steps = None
+ """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
+
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.sampler = sampler
+ self.model_wrap = None
+ self.p = None
+ self.mask_before_denoising = False
+
+ @property
+ def inner_model(self):
+ raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
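The constructor no longer receives the wrapped model: inner_model is now an abstract property that concrete samplers must provide, typically by wrapping shared.sd_model lazily so that a checkpoint switch is picked up on the next access. A minimal sketch of an implementing subclass, assuming the k-diffusion CompVisDenoiser wrapper (the class name here is illustrative and not part of this diff):

import k_diffusion.external
from modules import shared
from modules.sd_samplers_cfg_denoiser import CFGDenoiser

class KDiffusionCFGDenoiser(CFGDenoiser):
    @property
    def inner_model(self):
        # build the wrapper lazily; update_inner_model() sets self.model_wrap back
        # to None, so a refiner/model switch forces a re-wrap on the next access
        if self.model_wrap is None:
            self.model_wrap = k_diffusion.external.CompVisDenoiser(shared.sd_model)
        return self.model_wrap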
@@ -68,10 +81,21 @@ class CFGDenoiser(torch.nn.Module):
def get_pred_x0(self, x_in, x_out, sigma):
return x_out
+ def update_inner_model(self):
+ self.model_wrap = None
+
+ c, uc = self.p.get_conds()
+ self.sampler.sampler_extra_args['cond'] = c
+ self.sampler.sampler_extra_args['uncond'] = uc
+
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
+ if sd_samplers_common.apply_refiner(self):
+ cond = self.sampler.sampler_extra_args['cond']
+ uncond = self.sampler.sampler_extra_args['uncond']
+
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
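forward() now gives the refiner logic a chance to run at the top of every call: sd_samplers_common.apply_refiner(self) may switch the loaded checkpoint partway through sampling, and when it does, update_inner_model() discards the cached wrapper and refreshes cond/uncond in sampler_extra_args, which forward() then re-reads. A rough sketch of the calling side, with the option name and helper marked as assumptions since the actual implementation lives in sd_samplers_common, outside this diff:

from modules import shared

def apply_refiner(cfg_denoiser):
    # illustrative sketch, not the upstream implementation
    completed = cfg_denoiser.step / cfg_denoiser.total_steps
    if completed < shared.opts.sd_refiner_switch_at:   # assumed option name
        return False

    # ... load the refiner checkpoint into shared.sd_model here (details omitted) ...

    cfg_denoiser.p.setup_conds()       # assumed helper: re-encode prompts for the new model
    cfg_denoiser.update_inner_model()  # drop cached wrap, push fresh cond/uncond into sampler_extra_args
    return True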
@@ -81,7 +105,7 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
- if self.mask is not None:
+ if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x
batch_size = len(conds_list)
@@ -141,7 +165,7 @@ class CFGDenoiser(torch.nn.Module):
else:
cond_in = catenate_conds([tensor, uncond])
- if shared.batch_cond_uncond:
+ if shared.opts.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else:
x_out = torch.zeros_like(x_in)
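batch_cond_uncond moves from a module-level attribute set at startup to a live setting on shared.opts, so it can be toggled from the UI. Extension code that still reads the old attribute can bridge both layouts with a small shim (illustrative only; which attribute exists depends on the webui version):

from modules import shared

# prefer the new option, fall back to the legacy module attribute on older versions
batch_cond_uncond = getattr(shared.opts, "batch_cond_uncond", getattr(shared, "batch_cond_uncond", True))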
@@ -151,7 +175,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
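In this fallback branch the conditioning tensor is evaluated in slices; with shared.opts.batch_cond_uncond enabled the slice size is doubled, since that setting already implies the GPU is expected to fit cond and uncond in one call. A toy illustration of how the doubled slice size changes the number of model calls (numbers chosen arbitrarily):

tensor_len = 6            # conditioning rows queued for this step
batch_size = 2            # latents per image batch
batch_cond_uncond = True  # mirrors shared.opts.batch_cond_uncond

step = batch_size * 2 if batch_cond_uncond else batch_size
slices = [(a, min(a + step, tensor_len)) for a in range(0, tensor_len, step)]
print(slices)             # [(0, 4), (4, 6)] -> two model calls instead of three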
@@ -183,6 +207,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ if not self.mask_before_denoising and self.mask is not None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
if opts.live_preview_content == "Prompt":
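The new mask_before_denoising flag decides where the inpainting blend with init_latent is applied: on the noisy latent before the model call, as the code previously did unconditionally, or on the denoised prediction afterwards, which the default of False selects and which the hunk above implements. A condensed sketch of the two paths, with model() standing in for the full CFG forward pass (schematic, not the actual method):

def blend_and_denoise(model, x, sigma, init_latent, mask, nmask, mask_before_denoising):
    # blend the noisy latent first, then denoise
    if mask_before_denoising and mask is not None:
        x = init_latent * mask + nmask * x

    denoised = model(x, sigma)

    # or denoise first, then blend the prediction with the original latent
    if not mask_before_denoising and mask is not None:
        denoised = init_latent * mask + nmask * denoised

    return denoised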