aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorpapuSpartan <macabeg@icloud.com>2023-04-04 02:26:44 -0500
committerpapuSpartan <macabeg@icloud.com>2023-04-04 02:26:44 -0500
commit5c8e53d5e98da0eabf384318955c57842d612c07 (patch)
tree62686f3d064381bb606624f0fc53ea97b5f4e9b4 /modules/sd_models.py
parentc707b7df95a61b66a05be94e805e1be9a432e294 (diff)
Allow different merge ratios to be used for each pass. Make toggle cmd flag work again. Remove ratio flag. Remove warning about controlnet being incompatible
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py29
1 files changed, 28 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 87c49b83..696a2333 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -545,4 +546,30 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
- return sd_model \ No newline at end of file
+ return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+ print("effective hr pass merge ratio is "+str(ratio))
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ max_downsample=shared.opts.token_merging_maximum_down_sampling,
+ sx=shared.opts.token_merging_stride_x,
+ sy=shared.opts.token_merging_stride_y,
+ use_rand=shared.opts.token_merging_random,
+ merge_attn=shared.opts.token_merging_merge_attention,
+ merge_crossattn=shared.opts.token_merging_merge_cross_attention,
+ merge_mlp=shared.opts.token_merging_merge_mlp
+ )