From 948533950c9db5069a874d925fadd50bac00fdb5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 11 Oct 2022 11:09:51 +0300
Subject: replace duplicate code with a function

---
 modules/hypernetwork.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index 498bc9d8..7bbc443e 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -64,21 +64,26 @@ def load_hypernetwork(filename):
         shared.loaded_hypernetwork = None


+def apply_hypernetwork(hypernetwork, context):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is None:
+        return context, context
+
+    context_k = hypernetwork_layers[0](context)
+    context_v = hypernetwork_layers[1](context)
+    return context_k, context_v
+
+
 def attention_CrossAttention_forward(self, x, context=None, mask=None):
     h = self.heads

     q = self.to_q(x)
     context = default(context, x)

-    hypernetwork = shared.loaded_hypernetwork
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is not None:
-        k = self.to_k(hypernetwork_layers[0](context))
-        v = self.to_v(hypernetwork_layers[1](context))
-    else:
-        k = self.to_k(context)
-        v = self.to_v(context)
+    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)

     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
--
cgit v1.2.1
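
For illustration, a minimal standalone sketch of the new helper in isolation. apply_hypernetwork is copied from the diff above; the DummyHypernetwork class and the nn.Linear layer pair are illustrative stand-ins for the real hypernetwork object, which (per the lookup in the diff) exposes a layers dict keyed by the context's embedding width.

import torch
import torch.nn as nn


def apply_hypernetwork(hypernetwork, context):
    # Pick the (key, value) layer pair registered for this embedding width, if any.
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is None:
        # No matching hypernetwork: keys and values both come from the raw context.
        return context, context

    context_k = hypernetwork_layers[0](context)
    context_v = hypernetwork_layers[1](context)
    return context_k, context_v


class DummyHypernetwork:
    # Illustrative stand-in: maps embedding width -> (key layer, value layer).
    def __init__(self, dim):
        self.layers = {dim: (nn.Linear(dim, dim), nn.Linear(dim, dim))}


context = torch.randn(2, 77, 768)  # (batch, tokens, embedding width)

k_ctx, v_ctx = apply_hypernetwork(DummyHypernetwork(768), context)
assert k_ctx.shape == v_ctx.shape == context.shape

k_ctx, v_ctx = apply_hypernetwork(None, context)  # nothing loaded: pass-through
assert k_ctx is context and v_ctx is context

Returning (context, context) in the fallback case is what lets the caller collapse the old if/else into three unconditional lines: to_k and to_v are always applied to whatever pair the helper returns.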