Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 71
1 file changed, 64 insertions, 7 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index df3a6fe8..1b3dc302 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -38,6 +38,17 @@ samplers = [
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
+def setup_img2img_steps(p):
+ if opts.img2img_fix_steps:
+ steps = int(p.steps / min(p.denoising_strength, 0.999))
+ t_enc = p.steps - 1
+ else:
+ steps = p.steps
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+ return steps, t_enc
+
+
def sample_to_image(samples):
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
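
As a rough, standalone illustration of the two branches of the new setup_img2img_steps above (the opts.img2img_fix_steps lookup is replaced by a plain argument and the processing object is stubbed, so this is a sketch rather than the webui's code as-is):

```python
# Standalone sketch of the step-count logic added above; `fix_steps` stands in
# for opts.img2img_fix_steps and SimpleNamespace stands in for the processing object.
from types import SimpleNamespace

def demo_setup_img2img_steps(p, fix_steps):
    if fix_steps:
        # enlarge the schedule so that roughly p.steps denoising steps actually run
        steps = int(p.steps / min(p.denoising_strength, 0.999))
        t_enc = p.steps - 1
    else:
        # original behaviour: only denoising_strength * steps of the requested steps run
        steps = p.steps
        t_enc = int(min(p.denoising_strength, 0.999) * steps)
    return steps, t_enc

p = SimpleNamespace(steps=20, denoising_strength=0.5)
print(demo_setup_img2img_steps(p, fix_steps=False))  # (20, 10): 10 of the 20 steps are denoised
print(demo_setup_img2img_steps(p, fix_steps=True))   # (40, 19): a 40-step schedule, ~20 steps denoised
```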
@@ -80,8 +91,12 @@ class VanillaStableDiffusionSampler:
self.mask = None
self.nmask = None
self.init_latent = None
+ self.sampler_noises = None
self.step = 0
+ def number_of_needed_noises(self, p):
+ return 0
+
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
@@ -101,13 +116,13 @@ class VanillaStableDiffusionSampler:
return res
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
- t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
+ steps, t_enc = setup_img2img_steps(p)
# existing code fails with certain step counts, like 9
try:
- self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps, verbose=False)
except Exception:
- self.sampler.make_schedule(ddim_num_steps=p.steps+1, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps+1, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
@@ -115,6 +130,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
+ self.step = 0
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
@@ -127,6 +143,7 @@ class VanillaStableDiffusionSampler:
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
# existing code fails with certain step counts, like 9
try:
@@ -183,42 +200,82 @@ def extended_trange(count, *args, **kwargs):
shared.total_tqdm.update()
+class TorchHijack:
+ def __init__(self, kdiff_sampler):
+ self.kdiff_sampler = kdiff_sampler
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.kdiff_sampler.randn_like
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+
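
A quick illustration of how the TorchHijack proxy above behaves, assuming the TorchHijack class from this hunk is in scope; DummySampler is a stand-in for KDiffusionSampler used only for this demo:

```python
# Attribute lookups fall through to torch, except randn_like, which is routed to the sampler.
import torch

class DummySampler:
    def randn_like(self, x):
        return torch.zeros_like(x)  # pretend "sampler noise" so the routing is visible

hijack = TorchHijack(DummySampler())
x = torch.ones(2, 3)
print(hijack.zeros(2, 3))     # delegated to torch.zeros
print(hijack.randn_like(x))   # routed to DummySampler.randn_like -> all zeros
# hijack.no_such_function     # would raise AttributeError, as in the class above
```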
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ self.sampler_noises = None
+ self.sampler_noise_index = 0
def callback_state(self, d):
store_latent(d["denoised"])
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def randn_like(self, x):
+ noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
+
+ if noise is not None and x.shape == noise.shape:
+ res = noise
+ else:
+ res = torch.randn_like(x)
+
+ self.sampler_noise_index += 1
+ return res
+
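
The noise-replay behaviour of randn_like can be exercised in isolation; the sketch below copies the same logic into a small class that needs no sd_model, so the names are illustrative rather than the webui's:

```python
# Standalone copy of the randn_like logic above: stored noises are consumed in order
# when their shape matches, otherwise torch.randn_like supplies fresh noise.
import torch

class NoiseReplayer:
    def __init__(self, sampler_noises=None):
        self.sampler_noises = sampler_noises
        self.sampler_noise_index = 0

    def randn_like(self, x):
        noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None

        if noise is not None and x.shape == noise.shape:
            res = noise                  # replay the stored tensor
        else:
            res = torch.randn_like(x)    # fall back to fresh noise

        self.sampler_noise_index += 1    # the index advances either way
        return res

x = torch.zeros(1, 4)
stored = [torch.full((1, 4), 0.5), torch.full((2, 2), 0.5)]
r = NoiseReplayer(stored)
print(torch.equal(r.randn_like(x), stored[0]))  # True: shape matches, stored noise replayed
print(torch.equal(r.randn_like(x), stored[1]))  # False: shape mismatch, fresh noise used
print(r.sampler_noise_index)                    # 2
```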
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
- t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
- sigmas = self.model_wrap.get_sigmas(p.steps)
+ steps, t_enc = setup_img2img_steps(p)
+
+ sigmas = self.model_wrap.get_sigmas(steps)
- noise = noise * sigmas[p.steps - t_enc - 1]
+ noise = noise * sigmas[steps - t_enc - 1]
xi = x + noise
- sigma_sched = sigmas[p.steps - t_enc - 1:]
+ sigma_sched = sigmas[steps - t_enc - 1:]
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
+ self.model_wrap_cfg.step = 0
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
+ if self.sampler_noises is not None:
+ k_diffusion.sampling.torch = TorchHijack(self)
+
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
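
Worked numbers for the sigma slicing above, assuming k-diffusion's get_sigmas(n) returns n + 1 values ending in 0 and that the sampler runs len(sigmas) - 1 steps (as in upstream k-diffusion):

```python
# Index arithmetic for the truncated img2img schedule (no model needed).
steps, t_enc = 40, 19                  # e.g. 20 requested steps, denoising_strength 0.5, fix-steps on
num_sigmas = steps + 1                 # get_sigmas(steps) yields steps + 1 sigmas, the last one 0
start = steps - t_enc - 1              # 20: the init noise is scaled by sigmas[20]
sigma_sched_len = num_sigmas - start   # 21 sigmas remain in sigma_sched
denoising_steps = sigma_sched_len - 1  # 20 sampler steps actually run
print(start, sigma_sched_len, denoising_steps)
```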
def sample(self, p, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
+ self.model_wrap_cfg.step = 0
+
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
+ if self.sampler_noises is not None:
+ k_diffusion.sampling.torch = TorchHijack(self)
+
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
return samples_ddim
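
For context, a hypothetical caller sketch (not part of this diff) showing how code outside the sampler could use number_of_needed_noises and sampler_noises together; prepare_sampler_noises, shape, and seed are made-up names, and the seeding here is purely illustrative, not the webui's seed handling:

```python
# Feed per-step noises to a sampler so that samplers calling randn_like become reproducible.
import torch

def prepare_sampler_noises(sampler, p, shape, seed):
    # number_of_needed_noises() is 0 for the DDIM/PLMS wrapper and p.steps for KDiffusionSampler
    count = sampler.number_of_needed_noises(p)
    generator = torch.Generator().manual_seed(seed)
    sampler.sampler_noises = [torch.randn(shape, generator=generator) for _ in range(count)]
```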