diff options
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 07fc4434..60fa161c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -35,22 +35,27 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": def samples_to_images_tensor(sample, approximation=None, model=None):
- '''latents -> images [-1, 1]'''
- if approximation is None:
+ """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
+
+ if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
+ from modules import lowvram
+ if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
+ approximation = 1
+
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
elif approximation == 3:
- x_sample = sample * 1.5
- x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach()
+ x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
x_sample = x_sample * 2 - 1
else:
if model is None:
model = shared.sd_model
- x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+ with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+ x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample
@@ -217,6 +222,7 @@ class Sampler: self.eta_option_field = 'eta_ancestral'
self.eta_infotext_field = 'Eta'
+ self.eta_default = 1.0
self.conditioning_key = shared.sd_model.model.conditioning_key
@@ -273,7 +279,7 @@ class Sampler: extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
- if self.eta != 1.0:
+ if self.eta != self.eta_default:
p.extra_generation_params[self.eta_infotext_field] = self.eta
extra_params_kwargs['eta'] = self.eta
|