aboutsummaryrefslogtreecommitdiff
path: root/modules/processing.py
diff options
context:
space:
mode:
authorpapuSpartan <macabeg@icloud.com>2023-04-04 02:26:44 -0500
committerpapuSpartan <macabeg@icloud.com>2023-04-04 02:26:44 -0500
commit5c8e53d5e98da0eabf384318955c57842d612c07 (patch)
tree62686f3d064381bb606624f0fc53ea97b5f4e9b4 /modules/processing.py
parentc707b7df95a61b66a05be94e805e1be9a432e294 (diff)
Allow different merge ratios to be used for each pass. Make toggle cmd flag work again. Remove ratio flag. Remove warning about controlnet being incompatible
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py44
1 files changed, 15 insertions, 29 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 55735572..670a7a28 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
- if opts.token_merging and not opts.token_merging_hr_only:
- print("applying token merging to all passes")
- tomesd.apply_patch(
- p.sd_model,
- ratio=opts.token_merging_ratio,
- max_downsample=opts.token_merging_maximum_down_sampling,
- sx=opts.token_merging_stride_x,
- sy=opts.token_merging_stride_y,
- use_rand=opts.token_merging_random,
- merge_attn=opts.token_merging_merge_attention,
- merge_crossattn=opts.token_merging_merge_cross_attention,
- merge_mlp=opts.token_merging_merge_mlp
- )
+ if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
+ print("\nApplying token merging\n")
+ sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
res = process_images_inner(p)
finally:
# undo model optimizations made by tomesd
- if opts.token_merging:
- print('removing token merging model optimizations')
+ if opts.token_merging or cmd_opts.token_merging:
+ print('\nRemoving token merging model optimizations\n')
tomesd.remove_patch(p.sd_model)
# restore opts to original state
@@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc()
# apply token merging optimizations from tomesd for high-res pass
- # check if hr_only so we don't redundantly apply patch
- if opts.token_merging and opts.token_merging_hr_only:
- print("applying token merging for high-res pass")
- tomesd.apply_patch(
- self.sd_model,
- ratio=opts.token_merging_ratio,
- max_downsample=opts.token_merging_maximum_down_sampling,
- sx=opts.token_merging_stride_x,
- sy=opts.token_merging_stride_y,
- use_rand=opts.token_merging_random,
- merge_attn=opts.token_merging_merge_attention,
- merge_crossattn=opts.token_merging_merge_cross_attention,
- merge_mlp=opts.token_merging_merge_mlp
- )
+ # check if hr_only so we are not redundantly patching
+ if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
+ # case where user wants to use separate merge ratios
+ if not opts.token_merging_hr_only:
+ # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
+ print('Temporarily reverting token merging optimizations in preparation for next pass')
+ tomesd.remove_patch(self.sd_model)
+
+ print("\nApplying token merging for high-res pass\n")
+ sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)