From c0484f1b986ce7acb0e3596f6089a191279f5442 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 22:48:54 -0400 Subject: Add cross-attention optimization from InvokeAI * Add cross-attention optimization from InvokeAI (~30% speed improvement on MPS) * Add command line option for it * Make it default when CUDA is unavailable --- modules/sd_hijack_optimizations.py | 79 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) (limited to 'modules/sd_hijack_optimizations.py') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 3349b9c3..870226c5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ import math import sys import traceback +import psutil import torch from torch import einsum @@ -116,6 +117,84 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +# -- From https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py (with hypernetworks support added) -- + +mem_total_gb = psutil.virtual_memory().total // (1 << 30) + +def einsum_op_compvis(q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + +def einsum_op_slice_0(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + +def einsum_op_slice_1(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) + return r + +def einsum_op_mps_v1(q, k, v): + if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + return einsum_op_compvis(q, k, v) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + return einsum_op_slice_1(q, k, v, slice_size) + +def einsum_op_mps_v2(q, k, v): + if mem_total_gb > 8 and q.shape[1] <= 4096: + return einsum_op_compvis(q, k, v) + else: + return einsum_op_slice_0(q, k, v, 1) + +def einsum_op_tensor_mem(q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return einsum_op_slice_0(q, k, v, q.shape[0] // div) + return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + +def einsum_op(q, k, v): + if q.device.type == 'mps': + if mem_total_gb >= 32: + return einsum_op_mps_v1(q, k, v) + return einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return einsum_op_tensor_mem(q, k, v, 32) + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + hypernetwork = shared.loaded_hypernetwork + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k = self.to_k(hypernetwork_layers[0](context)) * self.scale + v = self.to_v(hypernetwork_layers[1](context)) + else: + k = self.to_k(context) * self.scale + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + +# -- End of code from https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py -- + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) -- cgit v1.2.1