diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-12 12:39:59 +0300 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-12 12:39:59 +0300 |
commit | 64311faa6848d641cc452115e4e1eb47d2a7b519 (patch) | |
tree | 3849ad4ca8ad7f44c8f20e3ab359578f3a4021ec /modules/sd_samplers_common.py | |
parent | 26c92f056acc795af5066779f1b8aedb8dfa983d (diff) |
put refiner into main UI, into the new accordions section
add VAE from main model into infotext, not from refiner model
option to make scripts UI without gr.Group
fix inconsistencies with refiner when usings samplers that do more denoising than steps
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 35c4d657..85f3c7e0 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s from modules.shared import opts, state
import k_diffusion.sampling
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+class SamplerData(SamplerDataTuple):
+ def total_steps(self, steps):
+ if self.options.get("second_order", False):
+ steps = steps * 2
+
+ return steps
def setup_img2img_steps(p, steps=None):
@@ -131,31 +140,26 @@ def replace_torchsde_browinan(): replace_torchsde_browinan()
-def apply_refiner(sampler):
- completed_ratio = sampler.step / sampler.steps
+def apply_refiner(cfg_denoiser):
+ completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
+ refiner_switch_at = cfg_denoiser.p.refiner_switch_at
+ refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
- if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
return False
- if shared.opts.sd_refiner_checkpoint == "None":
+ if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
- if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
- return False
-
- refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
- if refiner_checkpoint_info is None:
- raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
-
- sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
- sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+ cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
- sampler.p.setup_conds()
- sampler.update_inner_model()
+ cfg_denoiser.p.setup_conds()
+ cfg_denoiser.update_inner_model()
return True
@@ -192,7 +196,7 @@ class Sampler: self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None # set by the function calling the constructor
+ self.config: SamplerData = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
self.s_churn = 0.0
@@ -208,6 +212,7 @@ class Sampler: self.p = None
self.model_wrap_cfg = None
self.sampler_extra_args = None
+ self.options = {}
def callback_state(self, d):
step = d['i']
@@ -220,6 +225,7 @@ class Sampler: def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps
+ self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
state.sampling_steps = steps
state.sampling_step = 0
|