aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py64
1 files changed, 47 insertions, 17 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 9ebdb549..f8225b83 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -171,6 +171,7 @@ class StableDiffusionProcessing:
self.prompts = None
self.negative_prompts = None
+ self.extra_network_data = None
self.seeds = None
self.subseeds = None
@@ -311,7 +312,7 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
- def get_conds_with_caching(self, function, required_prompts, steps, cache):
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before.
@@ -320,27 +321,31 @@ class StableDiffusionProcessing:
representing the previously used arguments, or None if no arguments
have been used before. The second element is where the previously
computed result is stored.
+
+ caches is a list with items described above.
"""
- if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
- return cache[1]
+
+ for cache in caches:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ return cache[1]
+
+ cache = caches[0]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
return cache[1]
def setup_conds(self):
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c)
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def parse_extra_network_prompts(self):
- self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
-
- return extra_network_data
+ self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
class Processed:
@@ -681,7 +686,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.job_count == -1:
state.job_count = p.n_iter
- extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -702,11 +706,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(p.prompts) == 0:
break
- extra_network_data = p.parse_extra_network_prompts()
+ p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
- extra_networks.activate(p, extra_network_data)
+ extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -741,7 +745,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -828,8 +832,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks and extra_network_data:
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks and p.extra_network_data:
+ extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc()
@@ -896,6 +900,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_negative_prompts = None
self.hr_extra_network_data = None
+ self.cached_hr_uc = [None, None]
+ self.cached_hr_c = [None, None]
self.hr_c = None
self.hr_uc = None
@@ -1059,6 +1065,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
+ with devices.autocast():
+ self.calculate_hr_conds()
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
@@ -1070,6 +1079,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
def close(self):
+ self.cached_hr_uc = [None, None]
+ self.cached_hr_c = [None, None]
self.hr_c = None
self.hr_uc = None
@@ -1098,12 +1109,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+ def calculate_hr_conds(self):
+ if self.hr_c is not None:
+ return
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+
def setup_conds(self):
super().setup_conds()
+ self.hr_uc = None
+ self.hr_c = None
+
if self.enable_hr:
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c)
+ if shared.opts.hires_fix_use_firstpass_conds:
+ self.calculate_hr_conds()
+
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ self.calculate_hr_conds()
+
+ with devices.autocast():
+ extra_networks.activate(self, self.extra_network_data)
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()