aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-05-14 08:21:02 +0300
committerGitHub <noreply@github.com>2023-05-14 08:21:02 +0300
commit7f6ef764b945c9a63200007a8aed7595fbe8439d (patch)
tree86590841a27fbe411c8cd3a9c5ccca92fd15e009 /modules/sd_models.py
parent005849331e82cded96f6f3e5ff828037c672c38d (diff)
parentc2fdb44880e07f43aee2f7edc1dc36a9516501e8 (diff)
Merge pull request #9256 from papuSpartan/tomesd
Integrate optional speed and memory improvements by token merging (via dbolya/tomesd)
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3316d021..4c9a0a1f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -578,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )