From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:17:52 +0300 Subject: added support for hypernetworks (???) --- modules/hypernetwork.py | 55 ++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_optimizations.py | 17 ++++++++++-- modules/shared.py | 9 ++++++- 3 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 modules/hypernetwork.py (limited to 'modules') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 00000000..9ed1eed9 --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..879d8424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), -- cgit v1.2.1