From 1c6efdbba774d603c592debaccd6f5ad827bd1b2 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:16:01 -0700 Subject: inference working but SLOW --- extensions-builtin/Lora/networks.py | 42 ++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index bd1f1b75..e5e73450 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,6 +169,10 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict + + #if key_network_without_network_parts == "oft_unet": + # print(key_network_without_network_parts) + # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -185,15 +189,39 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - elif sd_module is None and "oft_unet" in key_network_without_network_parts: - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") - sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is None: key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] + # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] + # TODO: Change matchedm odules based on whether all linear, conv, etc + + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + #key_no_suffix = key.rsplit("_to_", 1)[0] + ## Match all modules of class CrossAttention + #replace_module_list = [] + #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: + # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if module_type in module.__class__.__name__] + + #matched_module = replace_module_list.get(key_no_suffix, None) + #if key.endswith('to_q'): + # sd_module = matched_module.to_q or None + #if key.endswith('to_k'): + # sd_module = matched_module.to_k or None + #if key.endswith('to_v'): + # sd_module = matched_module.to_v or None + #if key.endswith('to_out_0'): + # sd_module = matched_module.to_out[0] or None + #if key.endswith('to_out_1'): + # sd_module = matched_module.to_out[1] or None + + if sd_module is None: keys_failed_to_match[key_network] = key continue @@ -214,6 +242,14 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module + + # replaces forward method of original Linear + # applied_to_count = 0 + #for key, created_module in net.modules.items(): + # if isinstance(created_module, network_oft.NetworkModuleOFT): + # net_module.apply_to() + #applied_to_count += 1 + # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): -- cgit v1.2.1