diff options
author | Vespinian <vespinian@proton.me> | 2023-03-11 12:33:35 -0500 |
---|---|---|
committer | Vespinian <vespinian@proton.me> | 2023-03-11 12:33:35 -0500 |
commit | 46f9fe3cd6b7be712d85cdfc2cdff7ac3fe878b5 (patch) | |
tree | a134b22d856ab534c497b04c0d7ec87935cfcce7 /modules/sd_hijack.py | |
parent | 2174f58daee1e077eec1125e196d34cc93dbaf23 (diff) | |
parent | 94ffa9fc5386e51f20692ab46906135e8de33110 (diff) |
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diffusion-webui
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 79476783..f4bb0266 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,11 +37,23 @@ def apply_optimizations(): optimization_method = None
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
+ elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
+ optimization_method = 'sdp-no-mem'
+ elif cmd_opts.opt_sdp_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
+ optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|