From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 7 Oct 2022 10:17:52 +0300
Subject: added support for hypernetworks (???)

---
 modules/hypernetwork.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 modules/hypernetwork.py

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
new file mode 100644
index 00000000..9ed1eed9
--- /dev/null
+++ b/modules/hypernetwork.py
@@ -0,0 +1,55 @@
+import glob
+import os
+import torch
+from modules import devices
+
+
+class HypernetworkModule(torch.nn.Module):
+    def __init__(self, dim, state_dict):
+        super().__init__()
+
+        self.linear1 = torch.nn.Linear(dim, dim * 2)
+        self.linear2 = torch.nn.Linear(dim * 2, dim)
+
+        self.load_state_dict(state_dict, strict=True)
+        self.to(devices.device)
+
+    def forward(self, x):
+        return x + (self.linear2(self.linear1(x)))
+
+
+class Hypernetwork:
+    filename = None
+    name = None
+
+    def __init__(self, filename):
+        self.filename = filename
+        self.name = os.path.splitext(os.path.basename(filename))[0]
+        self.layers = {}
+
+        state_dict = torch.load(filename, map_location='cpu')
+        for size, sd in state_dict.items():
+            self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
+
+
+def load_hypernetworks(path):
+    res = {}
+
+    for filename in glob.iglob(path + '**/*.pt', recursive=True):
+        hn = Hypernetwork(filename)
+        res[hn.name] = hn
+
+    return res
+
+
+def apply(self, x, context=None, mask=None, original=None):
+
+    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
+        if context.shape[1] == 77 and CrossAttention.noise_cond:
+            context = context + (torch.randn_like(context) * 0.1)
+        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
+        k = self.to_k(h_k(context))
+        v = self.to_v(h_v(context))
+    else:
+        k = self.to_k(context)
+        v = self.to_v(context)
--
cgit v1.2.1


From 97bc0b9504572d2df80598d0b694703bcd626de6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 7 Oct 2022 13:22:50 +0300
Subject: do not stop working on failed hypernetwork load

---
 modules/hypernetwork.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index 9ed1eed9..c5cf4afa 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -1,5 +1,8 @@
 import glob
 import os
+import sys
+import traceback
+
 import torch
 from modules import devices
 
@@ -36,8 +39,12 @@ def load_hypernetworks(path):
     res = {}
 
     for filename in glob.iglob(path + '**/*.pt', recursive=True):
-        hn = Hypernetwork(filename)
-        res[hn.name] = hn
+        try:
+            hn = Hypernetwork(filename)
+            res[hn.name] = hn
+        except Exception:
+            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
 
     return res
 
--
cgit v1.2.1


From f7c787eb7c295c27439f4fbdf78c26b8389560be Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 7 Oct 2022 16:39:51 +0300
Subject: make it possible to use hypernetworks without opt split attention

---
 modules/hypernetwork.py | 42 ++++++++++++++++++++++++++++++++++--------
 1 file changed, 34 insertions(+), 8 deletions(-)

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index c5cf4afa..c7b86682 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -4,7 +4,12 @@ import sys
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
 
-def apply(self, x, context=None, mask=None, original=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
     else:
         k = self.to_k(context)
         v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
--
cgit v1.2.1


From 772db721a52da374d627b60994222051f26c27a7 Mon Sep 17 00:00:00 2001
From: ddPn08
Date: Fri, 7 Oct 2022 23:02:07 +0900
Subject: fix glob path in hypernetwork.py

---
 modules/hypernetwork.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index c7b86682..7f062242 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -43,7 +43,7 @@ class Hypernetwork:
 def load_hypernetworks(path):
     res = {}
 
-    for filename in glob.iglob(path + '**/*.pt', recursive=True):
+    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
         try:
             hn = Hypernetwork(filename)
             res[hn.name] = hn
--
cgit v1.2.1


From 122d42687b97ec4df4c2a8c335d2de385cd1f1a1 Mon Sep 17 00:00:00 2001
From: Fampai
Date: Sat, 8 Oct 2022 22:37:35 -0400
Subject: Fix VRAM Issue by only loading in hypernetwork when selected in
 settings

---
 modules/hypernetwork.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index 7f062242..19f1c227 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -40,18 +40,25 @@ class Hypernetwork:
             self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
 
 
-def load_hypernetworks(path):
+def list_hypernetworks(path):
     res = {}
-
     for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
+        name = os.path.splitext(os.path.basename(filename))[0]
+        res[name] = filename
+    return res
+
+
+def load_hypernetwork(filename):
+    print(f"Loading hypernetwork {filename}")
+    path = shared.hypernetworks.get(filename, None)
+    if (path is not None):
         try:
-            hn = Hypernetwork(filename)
-            res[hn.name] = hn
+            shared.loaded_hypernetwork = Hypernetwork(path)
         except Exception:
-            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
+            print(f"Error loading hypernetwork {path}", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
-
-    return res
+    else:
+        shared.loaded_hypernetwork = None
 
 
 def attention_CrossAttention_forward(self, x, context=None, mask=None):
@@ -60,7 +67,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
     q = self.to_q(x)
     context = default(context, x)
 
-    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork = shared.loaded_hypernetwork
     hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
 
     if hypernetwork_layers is not None:
--
cgit v1.2.1


From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 9 Oct 2022 14:33:22 +0300
Subject: fix btoken hypernetworks in XY plot

---
 modules/hypernetwork.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

(limited to 'modules/hypernetwork.py')

diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index 19f1c227..498bc9d8 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -49,15 +49,18 @@ def list_hypernetworks(path):
 
 
 def load_hypernetwork(filename):
-    print(f"Loading hypernetwork {filename}")
     path = shared.hypernetworks.get(filename, None)
-    if (path is not None):
+    if path is not None:
+        print(f"Loading hypernetwork {filename}")
         try:
             shared.loaded_hypernetwork = Hypernetwork(path)
         except Exception:
             print(f"Error loading hypernetwork {path}", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)
     else:
+        if shared.loaded_hypernetwork is not None:
+            print(f"Unloading hypernetwork")
+
         shared.loaded_hypernetwork = None
 
 
--
cgit v1.2.1
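
Taken together, these patches leave modules/hypernetwork.py with three public pieces: list_hypernetworks() maps names to .pt files, load_hypernetwork() populates shared.loaded_hypernetwork, and attention_CrossAttention_forward() replaces ldm's attention forward pass. The glue that installs them lives outside this file and is not part of the log above; what follows is only a minimal sketch of that wiring at webui startup, assuming a cmd_opts.hypernetwork_dir option and an opts.sd_hypernetwork setting (neither is shown in these diffs):

import ldm.modules.attention

from modules import hypernetwork, shared

# Scan the hypernetwork directory once; list_hypernetworks() (after 122d4268)
# returns a name -> filename dict without instantiating any torch modules.
shared.hypernetworks = hypernetwork.list_hypernetworks(shared.cmd_opts.hypernetwork_dir)  # assumed option

# Load only the hypernetwork picked in settings into shared.loaded_hypernetwork;
# a name missing from the dict unloads the current one instead (after 542a3d3a).
hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)  # assumed setting name

# Monkey-patch ldm's CrossAttention so every attention layer routes its
# context through the loaded hypernetwork's k/v projection modules.
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward

Because only the selected network is ever instantiated, switching hypernetworks keeps at most one set of HypernetworkModule weights on the GPU, which is the point of the VRAM fix in 122d4268.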