aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 59adc7cc..abe57294 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -18,6 +18,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -467,6 +468,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
+
except Exception as e:
pass
@@ -580,3 +582,28 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+ """
+ Applies speed and memory optimizations from tomesd.
+
+ Args:
+ hr (bool): True if called in the context of a high-res pass
+ """
+
+ ratio = shared.opts.token_merging_ratio
+ if hr:
+ ratio = shared.opts.token_merging_ratio_hr
+
+ tomesd.apply_patch(
+ sd_model,
+ ratio=ratio,
+ max_downsample=shared.opts.token_merging_maximum_down_sampling,
+ sx=shared.opts.token_merging_stride_x,
+ sy=shared.opts.token_merging_stride_y,
+ use_rand=shared.opts.token_merging_random,
+ merge_attn=shared.opts.token_merging_merge_attention,
+ merge_crossattn=shared.opts.token_merging_merge_cross_attention,
+ merge_mlp=shared.opts.token_merging_merge_mlp
+ )