path: root/modules/sd_hijack.py
author     brkirch <brkirch@users.noreply.github.com>  2022-12-27 08:50:55 -0500
committer  brkirch <brkirch@users.noreply.github.com>  2023-01-06 00:14:13 -0500
commit     d782a95967c9eea753df3333cd1954b6ec73eba0 (patch)
tree       00e368f428916688518c3171bca8c01b92f4e549 /modules/sd_hijack.py
parent     4af3ca5393151d61363c30eef4965e694eeac15e (diff)
Add Birch-san's sub-quadratic attention implementation
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--  modules/sd_hijack.py | 15
1 file changed, 6 insertions(+), 9 deletions(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 690a9ec2..019a6f3f 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
 
-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 import ldm.modules.diffusionmodules.openaimodel
@@ -40,17 +38,16 @@ def apply_optimizations():
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ elif cmd_opts.opt_sub_quad_attention:
+ print("Applying sub-quadratic cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
- if not invokeAI_mps_available and shared.device.type == 'mps':
- print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- else:
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
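
For context on what the new branch wires in: sub_quad_attention_forward is Birch-san's port of memory-efficient ("sub-quadratic") attention in the style of Rabe & Staats, "Self-attention Does Not Need O(n^2) Memory". Rather than materialising the full tokens-by-tokens attention matrix, queries and keys/values are processed in chunks and the softmax is renormalised incrementally, so peak memory grows with the chunk size rather than with the square of the token count. The sketch below only illustrates that chunking idea; it is not the code added to sd_hijack_optimizations.py. The name chunked_attention, the chunk sizes, and the assumption that q, k and v share their last dimension are hypothetical, and the real hook additionally handles multi-head reshaping, dtype upcasting and chunk-size heuristics.

import torch

def chunked_attention(q, k, v, q_chunk=1024, kv_chunk=4096):
    # q, k, v: (batch, tokens, channels); assumes q, k and v share the same last dim.
    # Process queries and keys/values in chunks so the (tokens x tokens) score
    # matrix is never materialised in full; the softmax is renormalised online.
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for i in range(0, q.shape[1], q_chunk):
        qc = q[:, i:i + q_chunk] * scale
        acc = torch.zeros_like(qc)                         # running sum of softmax-weighted values
        denom = qc.new_zeros(qc.shape[:-1] + (1,))         # running softmax denominator
        running_max = qc.new_full(qc.shape[:-1] + (1,), float("-inf"))
        for j in range(0, k.shape[1], kv_chunk):
            kc = k[:, j:j + kv_chunk]
            vc = v[:, j:j + kv_chunk]
            scores = qc @ kc.transpose(-2, -1)             # (batch, q_chunk, kv_chunk)
            chunk_max = scores.amax(dim=-1, keepdim=True)
            new_max = torch.maximum(running_max, chunk_max)
            correction = torch.exp(running_max - new_max)  # rescale earlier chunks to the new max
            weights = torch.exp(scores - new_max)
            acc = acc * correction + weights @ vc
            denom = denom * correction + weights.sum(dim=-1, keepdim=True)
            running_max = new_max
        out[:, i:i + q_chunk] = acc / denom
    return out

if __name__ == "__main__":
    # Sanity check against ordinary quadratic-memory attention.
    q, k, v = torch.randn(2, 64, 40), torch.randn(2, 77, 40), torch.randn(2, 77, 40)
    full = torch.softmax((q * q.shape[-1] ** -0.5) @ k.transpose(-2, -1), dim=-1) @ v
    assert torch.allclose(chunked_attention(q, k, v, q_chunk=16, kv_chunk=32), full, atol=1e-5)

In the elif chain above, the new branch sits after xformers but ahead of the v1, InvokeAI and Doggettx paths, so it is only taken when cmd_opts.opt_sub_quad_attention is set explicitly.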