aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rwxr-xr-xmodules/processing.py22
1 files changed, 15 insertions, 7 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 131c4c3c..5996cbac 100755
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -373,9 +373,10 @@ class StableDiffusionProcessing:
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
- self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
+ self.step_multiplier = total_steps // self.steps
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
def get_conds(self):
return self.c, self.uc
@@ -579,8 +580,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
- "VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None,
- "VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None,
+ "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
+ "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -669,6 +670,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None:
p.tiling = opts.tiling
+ p.loaded_vae_name = sd_vae.get_loaded_vae_name()
+ p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
+
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
@@ -1188,8 +1192,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
+ steps = self.hr_second_pass_steps or self.steps
+ total_steps = sampler_config.total_steps(steps) if sampler_config else steps
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self):
super().setup_conds()