diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/generation_parameters_copypaste.py | 2 | ||||
-rw-r--r-- | modules/processing.py | 14 | ||||
-rw-r--r-- | modules/sd_models.py | 18 | ||||
-rw-r--r-- | modules/sd_samplers_common.py | 28 | ||||
-rw-r--r-- | modules/sd_samplers_compvis.py | 0 | ||||
-rw-r--r-- | modules/sd_samplers_kdiffusion.py | 9 | ||||
-rw-r--r-- | modules/shared.py | 2 |
7 files changed, 67 insertions, 6 deletions
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 5758e6f3..20e30b53 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [ ('Pad conds', 'pad_cond_uncond'),
('VAE Encoder', 'sd_vae_encode_method'),
('VAE Decoder', 'sd_vae_decode_method'),
+ ('Refiner', 'sd_refiner_checkpoint'),
+ ('Refiner switch at', 'sd_refiner_switch_at'),
]
diff --git a/modules/processing.py b/modules/processing.py index 61ba5f11..b635cc74 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -370,6 +370,9 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ def get_conds(self):
+ return self.c, self.uc
+
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -666,6 +669,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
+ # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
+ if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+ sd_models.reload_model_weights()
+
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
@@ -1244,6 +1251,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast():
extra_networks.activate(self, self.extra_network_data)
+ def get_conds(self):
+ if self.is_hr_pass:
+ return self.hr_c, self.hr_uc
+
+ return super().get_conds()
+
+
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
diff --git a/modules/sd_models.py b/modules/sd_models.py index 53c1df54..a97af215 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return res
+class SkipWritingToConfig:
+ """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
+
+ skip = False
+ previous = None
+
+ def __enter__(self):
+ self.previous = SkipWritingToConfig.skip
+ SkipWritingToConfig.skip = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ SkipWritingToConfig.skip = self.previous
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index adda963b..15f27970 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -3,7 +3,7 @@ from collections import namedtuple, deque import numpy as np
import torch
from PIL import Image
-from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling
@@ -131,6 +131,32 @@ def replace_torchsde_browinan(): replace_torchsde_browinan()
+def apply_refiner(sampler):
+ completed_ratio = sampler.step / sampler.steps
+
+ if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ return False
+
+ if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
+ return False
+
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+ sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+ devices.torch_gc()
+ sampler.p.setup_conds()
+ sampler.update_inner_model()
+
+ return True
+
+
class TorchHijack:
def __init__(self, sampler_noises):
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/modules/sd_samplers_compvis.py diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index f47431af..3ff4b634 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -143,7 +143,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args = {
+ self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
@@ -151,7 +151,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler): 's_min_uncond': self.s_min_uncond
}
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
@@ -183,13 +183,14 @@ class KDiffusionSampler(sd_samplers_common.Sampler): extra_params_kwargs['noise_sampler'] = noise_sampler
self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ self.sampler_extra_args = {
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ }
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
diff --git a/modules/shared.py b/modules/shared.py index 97f1eab5..2fd29904 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -462,6 +462,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
+ "sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"),
+ "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|