From 3262e825cc542ff634e6ba2e3a162eafdc6c1bba Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 21 Jan 2023 17:42:04 +0900 Subject: add --xformers-flash-attention option & impl --- modules/sd_hijack_optimizations.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 4fa54329..9967359b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + # print('xformers_attention_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v) + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): + # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + out = xformers.ops.memory_efficient_attention(q, k, v, op=op) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out -- cgit v1.2.1