aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r--modules/sd_samplers_common.py102
1 files changed, 86 insertions, 16 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 97bc0804..58efcad2 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -3,11 +3,20 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+class SamplerData(SamplerDataTuple):
+ def total_steps(self, steps):
+ if self.options.get("second_order", False):
+ steps = steps * 2
+
+ return steps
def setup_img2img_steps(p, steps=None):
@@ -26,22 +35,27 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD":
def samples_to_images_tensor(sample, approximation=None, model=None):
- '''latents -> images [-1, 1]'''
- if approximation is None:
+ """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
+
+ if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
+ from modules import lowvram
+ if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
+ approximation = 1
+
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
- x_sample = sample * 1.5
- x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+ x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else:
if model is None:
model = shared.sd_model
- x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+ with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+ x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
@@ -81,9 +95,19 @@ def images_tensor_to_samples(image, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
+ model.first_stage_model.to(devices.dtype_vae)
+
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
- x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+ if len(image) > 1:
+ x_latent = torch.stack([
+ model.get_first_stage_encoding(
+ model.encode_first_stage(torch.unsqueeze(img, 0))
+ )[0]
+ for img in image
+ ])
+ else:
+ x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent
@@ -131,6 +155,42 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
+def apply_refiner(cfg_denoiser):
+ completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
+ refiner_switch_at = cfg_denoiser.p.refiner_switch_at
+ refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
+
+ if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
+ return False
+
+ if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
+ return False
+
+ if getattr(cfg_denoiser.p, "enable_hr", False):
+ is_second_pass = cfg_denoiser.p.is_hr_pass
+
+ if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
+ return False
+
+ if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
+ return False
+
+ if opts.hires_fix_refiner_pass != "second pass":
+ cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
+
+ cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
+
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+ devices.torch_gc()
+ cfg_denoiser.p.setup_conds()
+ cfg_denoiser.update_inner_model()
+
+ return True
+
+
class TorchHijack:
"""This is here to replace torch.randn_like of k-diffusion.
@@ -163,7 +223,7 @@ class Sampler:
self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None # set by the function calling the constructor
+ self.config: SamplerData = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
self.s_churn = 0.0
@@ -173,11 +233,14 @@ class Sampler:
self.eta_option_field = 'eta_ancestral'
self.eta_infotext_field = 'Eta'
+ self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key
- self.model_wrap = None
+ self.p = None
self.model_wrap_cfg = None
+ self.sampler_extra_args = None
+ self.options = {}
def callback_state(self, d):
step = d['i']
@@ -189,6 +252,8 @@ class Sampler:
shared.total_tqdm.update()
def launch_sampling(self, steps, func):
+ self.model_wrap_cfg.steps = steps
+ self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
state.sampling_steps = steps
state.sampling_step = 0
@@ -208,6 +273,8 @@ class Sampler:
return p.steps
def initialize(self, p) -> dict:
+ self.p = p
+ self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
@@ -223,7 +290,7 @@ class Sampler:
extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
+ if self.eta != self.eta_default:
p.extra_generation_params[self.eta_infotext_field] = self.eta
extra_params_kwargs['eta'] = self.eta
@@ -234,19 +301,19 @@ class Sampler:
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise)
- if s_churn != self.s_churn:
+ if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn
- if s_tmin != self.s_tmin:
+ if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin
- if s_tmax != self.s_tmax:
+ if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax
- if s_noise != self.s_noise:
+ if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise
@@ -263,5 +330,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()
-
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ raise NotImplementedError()