diff options
58 files changed, 2346 insertions, 904 deletions
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index e9370cc0..3dafaf8d 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -41,6 +41,7 @@ jobs: --skip-prepare-environment --skip-torch-cuda-test --test-server + --do-not-download-clip --no-half --disable-opt-split-attention --use-cpu all diff --git a/.github/workflows/warns_merge_master.yml b/.github/workflows/warns_merge_master.yml new file mode 100644 index 00000000..ae2aab6b --- /dev/null +++ b/.github/workflows/warns_merge_master.yml @@ -0,0 +1,19 @@ +name: Pull requests can't target master branch + +"on": + pull_request: + types: + - opened + - synchronize + - reopened + branches: + - master + +jobs: + check: + runs-on: ubuntu-latest + steps: + - name: Warning marge into master + run: | + echo -e "::warning::This pull request directly merge into \"master\" branch, normally development happens on \"dev\" branch." + exit 1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 925403a9..63a2c7d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,68 @@ +## 1.5.0
+
+### Features:
+ * SD XL support
+ * user metadata system for custom networks
+ * extended Lora metadata editor: set activation text, default weight, view tags, training info
+ * Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
+ * show github stars for extenstions
+ * img2img batch mode can read extra stuff from png info
+ * img2img batch works with subdirectories
+ * hotkeys to move prompt elements: alt+left/right
+ * restyle time taken/VRAM display
+ * add textual inversion hashes to infotext
+ * optimization: cache git extension repo information
+ * move generate button next to the generated picture for mobile clients
+ * hide cards for networks of incompatible Stable Diffusion version in Lora extra networks interface
+ * skip installing packages with pip if they all are already installed - startup speedup of about 2 seconds
+
+### Minor:
+ * checkbox to check/uncheck all extensions in the Installed tab
+ * add gradio user to infotext and to filename patterns
+ * allow gif for extra network previews
+ * add options to change colors in grid
+ * use natural sort for items in extra networks
+ * Mac: use empty_cache() from torch 2 to clear VRAM
+ * added automatic support for installing the right libraries for Navi3 (AMD)
+ * add option SWIN_torch_compile to accelerate SwinIR upscale
+ * suppress printing TI embedding info at start to console by default
+ * speedup extra networks listing
+ * added `[none]` filename token.
+ * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)
+ * add always_discard_next_to_last_sigma option to XYZ plot
+ * automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.
+
+### Extensions and API:
+ * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop
+ * allow Script to have custom metaclass
+ * add model exists status check /sdapi/v1/options
+ * rename --add-stop-route to --api-server-stop
+ * add `before_hr` script callback
+ * add callback `after_extra_networks_activate`
+ * disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable
+ * return http 404 when thumb file not found
+ * allow replacing extensions index with environment variable
+
+### Bug Fixes:
+ * fix for catch errors when retrieving extension index #11290
+ * fix very slow loading speed of .safetensors files when reading from network drives
+ * API cache cleanup
+ * fix UnicodeEncodeError when writing to file CLIP Interrogator batch mode
+ * fix warning of 'has_mps' deprecated from PyTorch
+ * fix problem with extra network saving images as previews losing generation info
+ * fix throwing exception when trying to resize image with I;16 mode
+ * fix for #11534: canvas zoom and pan extension hijacking shortcut keys
+ * fixed launch script to be runnable from any directory
+ * don't add "Seed Resize: -1x-1" to API image metadata
+ * correctly remove end parenthesis with ctrl+up/down
+ * fixing --subpath on newer gradio version
+ * fix: check fill size none zero when resize (fixes #11425)
+ * use submit and blur for quick settings textbox
+ * save img2img batch with images.save_image()
+ * prevent running preload.py for disabled extensions
+ * fix: previously, model name was added together with directory name to infotext and to [model_name] filename pattern; directory name is now not included
+
+
## 1.4.1
### Bug Fixes:
@@ -168,5 +168,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
+- LyCORIS - KohakuBlueleaf
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 66ee9c85..ba2945c6 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -1,5 +1,5 @@ from modules import extra_networks, shared
-import lora
+import networks
class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,24 +9,38 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+ if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
- multipliers = []
+ te_multipliers = []
+ unet_multipliers = []
+ dyn_dims = []
for params in params_list:
assert params.items
- names.append(params.items[0])
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+ names.append(params.positional[0])
- lora.load_loras(names, multipliers)
+ te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+ te_multiplier = float(params.named.get("te", te_multiplier))
+
+ unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
+ unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+ dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+ dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+ te_multipliers.append(te_multiplier)
+ unet_multipliers.append(unet_multiplier)
+ dyn_dims.append(dyn_dim)
+
+ networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
if shared.opts.lora_add_hashes_to_infotext:
- lora_hashes = []
- for item in lora.loaded_loras:
- shorthash = item.lora_on_disk.shorthash
+ network_hashes = []
+ for item in networks.loaded_networks:
+ shorthash = item.network_on_disk.shorthash
if not shorthash:
continue
@@ -36,10 +50,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): alias = alias.replace(":", "").replace(",", "")
- lora_hashes.append(f"{alias}: {shorthash}")
+ network_hashes.append(f"{alias}: {shorthash}")
- if lora_hashes:
- p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+ if network_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 302490fb..9365aa74 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -1,532 +1,9 @@ -import os
-import re
-import torch
-from typing import Union
+import networks
-from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
+list_available_loras = networks.list_available_networks
-metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
-
-re_digits = re.compile(r"\d+")
-re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
-re_compiled = {}
-
-suffix_conversion = {
- "attentions": {},
- "resnets": {
- "conv1": "in_layers_2",
- "conv2": "out_layers_3",
- "time_emb_proj": "emb_layers_1",
- "conv_shortcut": "skip_connection",
- }
-}
-
-
-def convert_diffusers_name_to_compvis(key, is_sd2):
- def match(match_list, regex_text):
- regex = re_compiled.get(regex_text)
- if regex is None:
- regex = re.compile(regex_text)
- re_compiled[regex_text] = regex
-
- r = re.match(regex, key)
- if not r:
- return False
-
- match_list.clear()
- match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
- return True
-
- m = []
-
- if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
- suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
- return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
-
- if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
- suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
- return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
-
- if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
- suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
- return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
-
- if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
- return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
-
- if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
- return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
-
- if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
- if is_sd2:
- if 'mlp_fc1' in m[1]:
- return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
- elif 'mlp_fc2' in m[1]:
- return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
- else:
- return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
-
- return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
-
- if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
- if 'mlp_fc1' in m[1]:
- return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
- elif 'mlp_fc2' in m[1]:
- return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
- else:
- return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
-
- return key
-
-
-class LoraOnDisk:
- def __init__(self, name, filename):
- self.name = name
- self.filename = filename
- self.metadata = {}
- self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
-
- if self.is_safetensors:
- try:
- self.metadata = sd_models.read_metadata_from_safetensors(filename)
- except Exception as e:
- errors.display(e, f"reading lora {filename}")
-
- if self.metadata:
- m = {}
- for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
- m[k] = v
-
- self.metadata = m
-
- self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
- self.alias = self.metadata.get('ss_output_name', self.name)
-
- self.hash = None
- self.shorthash = None
- self.set_hash(
- self.metadata.get('sshs_model_hash') or
- hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
- ''
- )
-
- def set_hash(self, v):
- self.hash = v
- self.shorthash = self.hash[0:12]
-
- if self.shorthash:
- available_lora_hash_lookup[self.shorthash] = self
-
- def read_hash(self):
- if not self.hash:
- self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
-
- def get_alias(self):
- if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
- return self.name
- else:
- return self.alias
-
-
-class LoraModule:
- def __init__(self, name, lora_on_disk: LoraOnDisk):
- self.name = name
- self.lora_on_disk = lora_on_disk
- self.multiplier = 1.0
- self.modules = {}
- self.mtime = None
-
- self.mentioned_name = None
- """the text that was used to add lora to prompt - can be either name or an alias"""
-
-
-class LoraUpDownModule:
- def __init__(self):
- self.up = None
- self.down = None
- self.alpha = None
-
-
-def assign_lora_names_to_compvis_modules(sd_model):
- lora_layer_mapping = {}
-
- if shared.sd_model.is_sdxl:
- for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
- if not hasattr(embedder, 'wrapped'):
- continue
-
- for name, module in embedder.wrapped.named_modules():
- lora_name = f'{i}_{name.replace(".", "_")}'
- lora_layer_mapping[lora_name] = module
- module.lora_layer_name = lora_name
- else:
- for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
- lora_name = name.replace(".", "_")
- lora_layer_mapping[lora_name] = module
- module.lora_layer_name = lora_name
-
- for name, module in shared.sd_model.model.named_modules():
- lora_name = name.replace(".", "_")
- lora_layer_mapping[lora_name] = module
- module.lora_layer_name = lora_name
-
- sd_model.lora_layer_mapping = lora_layer_mapping
-
-
-def load_lora(name, lora_on_disk):
- lora = LoraModule(name, lora_on_disk)
- lora.mtime = os.path.getmtime(lora_on_disk.filename)
-
- sd = sd_models.read_state_dict(lora_on_disk.filename)
-
- # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
- if not hasattr(shared.sd_model, 'lora_layer_mapping'):
- assign_lora_names_to_compvis_modules(shared.sd_model)
-
- keys_failed_to_match = {}
- is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
-
- for key_lora, weight in sd.items():
- key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
-
- key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
- sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
-
- if sd_module is None:
- m = re_x_proj.match(key)
- if m:
- sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
-
- # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
- if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
- key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
- sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
- elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
- key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
- sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
-
- if sd_module is None:
- keys_failed_to_match[key_lora] = key
- continue
-
- lora_module = lora.modules.get(key, None)
- if lora_module is None:
- lora_module = LoraUpDownModule()
- lora.modules[key] = lora_module
-
- if lora_key == "alpha":
- lora_module.alpha = weight.item()
- continue
-
- if type(sd_module) == torch.nn.Linear:
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.MultiheadAttention:
- module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
- elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
- elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
- module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
- else:
- print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
- continue
- raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")
-
- with torch.no_grad():
- module.weight.copy_(weight)
-
- module.to(device=devices.cpu, dtype=devices.dtype)
-
- if lora_key == "lora_up.weight":
- lora_module.up = module
- elif lora_key == "lora_down.weight":
- lora_module.down = module
- else:
- raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")
-
- if keys_failed_to_match:
- print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
-
- return lora
-
-
-def load_loras(names, multipliers=None):
- already_loaded = {}
-
- for lora in loaded_loras:
- if lora.name in names:
- already_loaded[lora.name] = lora
-
- loaded_loras.clear()
-
- loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any(x is None for x in loras_on_disk):
- list_available_loras()
-
- loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
-
- failed_to_load_loras = []
-
- for i, name in enumerate(names):
- lora = already_loaded.get(name, None)
-
- lora_on_disk = loras_on_disk[i]
-
- if lora_on_disk is not None:
- if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
- try:
- lora = load_lora(name, lora_on_disk)
- except Exception as e:
- errors.display(e, f"loading Lora {lora_on_disk.filename}")
- continue
-
- lora.mentioned_name = name
-
- lora_on_disk.read_hash()
-
- if lora is None:
- failed_to_load_loras.append(name)
- print(f"Couldn't find Lora with name {name}")
- continue
-
- lora.multiplier = multipliers[i] if multipliers else 1.0
- loaded_loras.append(lora)
-
- if failed_to_load_loras:
- sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
-
-
-def lora_calc_updown(lora, module, target):
- with torch.no_grad():
- up = module.up.weight.to(target.device, dtype=target.dtype)
- down = module.down.weight.to(target.device, dtype=target.dtype)
-
- if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
- updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
- elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
- updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
- else:
- updown = up @ down
-
- updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-
- return updown
-
-
-def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
- weights_backup = getattr(self, "lora_weights_backup", None)
-
- if weights_backup is None:
- return
-
- if isinstance(self, torch.nn.MultiheadAttention):
- self.in_proj_weight.copy_(weights_backup[0])
- self.out_proj.weight.copy_(weights_backup[1])
- else:
- self.weight.copy_(weights_backup)
-
-
-def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
- """
- Applies the currently selected set of Loras to the weights of torch layer self.
- If weights already have this particular set of loras applied, does nothing.
- If not, restores orginal weights from backup and alters weights according to loras.
- """
-
- lora_layer_name = getattr(self, 'lora_layer_name', None)
- if lora_layer_name is None:
- return
-
- current_names = getattr(self, "lora_current_names", ())
- wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
-
- weights_backup = getattr(self, "lora_weights_backup", None)
- if weights_backup is None:
- if isinstance(self, torch.nn.MultiheadAttention):
- weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
- else:
- weights_backup = self.weight.to(devices.cpu, copy=True)
-
- self.lora_weights_backup = weights_backup
-
- if current_names != wanted_names:
- lora_restore_weights_from_backup(self)
-
- for lora in loaded_loras:
- module = lora.modules.get(lora_layer_name, None)
- if module is not None and hasattr(self, 'weight'):
- self.weight += lora_calc_updown(lora, module, self.weight)
- continue
-
- module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
- module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
- module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
- module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
-
- if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
- updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
- updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
- updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
- updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
-
- self.in_proj_weight += updown_qkv
- self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
- continue
-
- if module is None:
- continue
-
- print(f'failed to calculate lora weights for layer {lora_layer_name}')
-
- self.lora_current_names = wanted_names
-
-
-def lora_forward(module, input, original_forward):
- """
- Old way of applying Lora by executing operations during layer's forward.
- Stacking many loras this way results in big performance degradation.
- """
-
- if len(loaded_loras) == 0:
- return original_forward(module, input)
-
- input = devices.cond_cast_unet(input)
-
- lora_restore_weights_from_backup(module)
- lora_reset_cached_weight(module)
-
- res = original_forward(module, input)
-
- lora_layer_name = getattr(module, 'lora_layer_name', None)
- for lora in loaded_loras:
- module = lora.modules.get(lora_layer_name, None)
- if module is None:
- continue
-
- module.up.to(device=devices.device)
- module.down.to(device=devices.device)
-
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-
- return res
-
-
-def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- self.lora_current_names = ()
- self.lora_weights_backup = None
-
-
-def lora_Linear_forward(self, input):
- if shared.opts.lora_functional:
- return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
-
- lora_apply_weights(self)
-
- return torch.nn.Linear_forward_before_lora(self, input)
-
-
-def lora_Linear_load_state_dict(self, *args, **kwargs):
- lora_reset_cached_weight(self)
-
- return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
-
-
-def lora_Conv2d_forward(self, input):
- if shared.opts.lora_functional:
- return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
-
- lora_apply_weights(self)
-
- return torch.nn.Conv2d_forward_before_lora(self, input)
-
-
-def lora_Conv2d_load_state_dict(self, *args, **kwargs):
- lora_reset_cached_weight(self)
-
- return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
-
-
-def lora_MultiheadAttention_forward(self, *args, **kwargs):
- lora_apply_weights(self)
-
- return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
-
-
-def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
- lora_reset_cached_weight(self)
-
- return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
-
-
-def list_available_loras():
- available_loras.clear()
- available_lora_aliases.clear()
- forbidden_lora_aliases.clear()
- available_lora_hash_lookup.clear()
- forbidden_lora_aliases.update({"none": 1, "Addams": 1})
-
- os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
-
- candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
- for filename in candidates:
- if os.path.isdir(filename):
- continue
-
- name = os.path.splitext(os.path.basename(filename))[0]
- try:
- entry = LoraOnDisk(name, filename)
- except OSError: # should catch FileNotFoundError and PermissionError etc.
- errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
- continue
-
- available_loras[name] = entry
-
- if entry.alias in available_lora_aliases:
- forbidden_lora_aliases[entry.alias.lower()] = 1
-
- available_lora_aliases[name] = entry
- available_lora_aliases[entry.alias] = entry
-
-
-re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
-
-
-def infotext_pasted(infotext, params):
- if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
- return # if the other extension is active, it will handle those fields, no need to do anything
-
- added = []
-
- for k in params:
- if not k.startswith("AddNet Model "):
- continue
-
- num = k[13:]
-
- if params.get("AddNet Module " + num) != "LoRA":
- continue
-
- name = params.get("AddNet Model " + num)
- if name is None:
- continue
-
- m = re_lora_name.match(name)
- if m:
- name = m.group(1)
-
- multiplier = params.get("AddNet Weight A " + num, "1.0")
-
- added.append(f"<lora:{name}:{multiplier}>")
-
- if added:
- params["Prompt"] += "\n" + "".join(added)
-
-
-available_loras = {}
-available_lora_aliases = {}
-available_lora_hash_lookup = {}
-forbidden_lora_aliases = {}
-loaded_loras = []
-
-list_available_loras()
+available_loras = networks.available_networks
+available_lora_aliases = networks.available_network_aliases
+available_lora_hash_lookup = networks.available_network_hash_lookup
+forbidden_lora_aliases = networks.forbidden_network_aliases
+loaded_loras = networks.loaded_networks
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py new file mode 100644 index 00000000..279b34bc --- /dev/null +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -0,0 +1,21 @@ +import torch
+
+
+def make_weight_cp(t, wa, wb):
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
+def rebuild_conventional(up, down, shape, dyn_dim=None):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ if dyn_dim is not None:
+ up = up[:, :dyn_dim]
+ down = down[:dyn_dim, :]
+ return (up @ down).reshape(shape)
+
+
+def rebuild_cp_decomposition(up, down, mid):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py new file mode 100644 index 00000000..8ecfa29a --- /dev/null +++ b/extensions-builtin/Lora/network.py @@ -0,0 +1,154 @@ +import os
+from collections import namedtuple
+import enum
+
+from modules import sd_models, cache, errors, hashes, shared
+
+NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
+
+
+class SdVersion(enum.Enum):
+ Unknown = 1
+ SD1 = 2
+ SD2 = 3
+ SDXL = 4
+
+
+class NetworkOnDisk:
+ def __init__(self, name, filename):
+ self.name = name
+ self.filename = filename
+ self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
+ def read_metadata():
+ metadata = sd_models.read_metadata_from_safetensors(filename)
+ metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
+
+ return metadata
+
+ if self.is_safetensors:
+ try:
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
+ except Exception as e:
+ errors.display(e, f"reading lora {filename}")
+
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+
+ self.metadata = m
+
+ self.alias = self.metadata.get('ss_output_name', self.name)
+
+ self.hash = None
+ self.shorthash = None
+ self.set_hash(
+ self.metadata.get('sshs_model_hash') or
+ hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
+ ''
+ )
+
+ self.sd_version = self.detect_version()
+
+ def detect_version(self):
+ if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
+ return SdVersion.SDXL
+ elif str(self.metadata.get('ss_v2', "")) == "True":
+ return SdVersion.SD2
+ elif len(self.metadata):
+ return SdVersion.SD1
+
+ return SdVersion.Unknown
+
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
+ if self.shorthash:
+ import networks
+ networks.available_network_hash_lookup[self.shorthash] = self
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ import networks
+ if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
+ return self.name
+ else:
+ return self.alias
+
+
+class Network: # LoraModule
+ def __init__(self, name, network_on_disk: NetworkOnDisk):
+ self.name = name
+ self.network_on_disk = network_on_disk
+ self.te_multiplier = 1.0
+ self.unet_multiplier = 1.0
+ self.dyn_dim = None
+ self.modules = {}
+ self.mtime = None
+
+ self.mentioned_name = None
+ """the text that was used to add the network to prompt - can be either name or an alias"""
+
+
+class ModuleType:
+ def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
+ return None
+
+
+class NetworkModule:
+ def __init__(self, net: Network, weights: NetworkWeights):
+ self.network = net
+ self.network_key = weights.network_key
+ self.sd_key = weights.sd_key
+ self.sd_module = weights.sd_module
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.dim = None
+ self.bias = weights.w.get("bias")
+ self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
+ self.scale = weights.w["scale"].item() if "scale" in weights.w else None
+
+ def multiplier(self):
+ if 'transformer' in self.sd_key[:20]:
+ return self.network.te_multiplier
+ else:
+ return self.network.unet_multiplier
+
+ def calc_scale(self):
+ if self.scale is not None:
+ return self.scale
+ if self.dim is not None and self.alpha is not None:
+ return self.alpha / self.dim
+
+ return 1.0
+
+ def finalize_updown(self, updown, orig_weight, output_shape):
+ if self.bias is not None:
+ updown = updown.reshape(self.bias.shape)
+ updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = updown.reshape(output_shape)
+
+ if len(output_shape) == 4:
+ updown = updown.reshape(output_shape)
+
+ if orig_weight.size().numel() == updown.size().numel():
+ updown = updown.reshape(orig_weight.shape)
+
+ return updown * self.calc_scale() * self.multiplier()
+
+ def calc_updown(self, target):
+ raise NotImplementedError()
+
+ def forward(self, x, y):
+ raise NotImplementedError()
+
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py new file mode 100644 index 00000000..109b4c2c --- /dev/null +++ b/extensions-builtin/Lora/network_full.py @@ -0,0 +1,22 @@ +import network
+
+
+class ModuleTypeFull(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["diff"]):
+ return NetworkModuleFull(net, weights)
+
+ return None
+
+
+class NetworkModuleFull(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.weight = weights.w.get("diff")
+
+ def calc_updown(self, orig_weight):
+ output_shape = self.weight.shape
+ updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py new file mode 100644 index 00000000..5fcb0695 --- /dev/null +++ b/extensions-builtin/Lora/network_hada.py @@ -0,0 +1,55 @@ +import lyco_helpers
+import network
+
+
+class ModuleTypeHada(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
+ return NetworkModuleHada(net, weights)
+
+ return None
+
+
+class NetworkModuleHada(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ if hasattr(self.sd_module, 'weight'):
+ self.shape = self.sd_module.weight.shape
+
+ self.w1a = weights.w["hada_w1_a"]
+ self.w1b = weights.w["hada_w1_b"]
+ self.dim = self.w1b.shape[0]
+ self.w2a = weights.w["hada_w2_a"]
+ self.w2b = weights.w["hada_w2_b"]
+
+ self.t1 = weights.w.get("hada_t1")
+ self.t2 = weights.w.get("hada_t2")
+
+ def calc_updown(self, orig_weight):
+ w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [w1a.size(0), w1b.size(1)]
+
+ if self.t1 is not None:
+ output_shape = [w1a.size(1), w1b.size(1)]
+ t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
+ output_shape += t1.shape[2:]
+ else:
+ if len(w1b.shape) == 4:
+ output_shape += w1b.shape[2:]
+ updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
+
+ if self.t2 is not None:
+ t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+ else:
+ updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
+
+ updown = updown1 * updown2
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py new file mode 100644 index 00000000..7edc4249 --- /dev/null +++ b/extensions-builtin/Lora/network_ia3.py @@ -0,0 +1,30 @@ +import network
+
+
+class ModuleTypeIa3(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["weight"]):
+ return NetworkModuleIa3(net, weights)
+
+ return None
+
+
+class NetworkModuleIa3(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.w = weights.w["weight"]
+ self.on_input = weights.w["on_input"].item()
+
+ def calc_updown(self, orig_weight):
+ w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [w.size(0), orig_weight.size(1)]
+ if self.on_input:
+ output_shape.reverse()
+ else:
+ w = w.reshape(-1, 1)
+
+ updown = orig_weight * w
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py new file mode 100644 index 00000000..340acdab --- /dev/null +++ b/extensions-builtin/Lora/network_lokr.py @@ -0,0 +1,64 @@ +import torch
+
+import lyco_helpers
+import network
+
+
+class ModuleTypeLokr(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
+ has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
+ if has_1 and has_2:
+ return NetworkModuleLokr(net, weights)
+
+ return None
+
+
+def make_kron(orig_shape, w1, w2):
+ if len(w2.shape) == 4:
+ w1 = w1.unsqueeze(2).unsqueeze(2)
+ w2 = w2.contiguous()
+ return torch.kron(w1, w2).reshape(orig_shape)
+
+
+class NetworkModuleLokr(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.w1 = weights.w.get("lokr_w1")
+ self.w1a = weights.w.get("lokr_w1_a")
+ self.w1b = weights.w.get("lokr_w1_b")
+ self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
+ self.w2 = weights.w.get("lokr_w2")
+ self.w2a = weights.w.get("lokr_w2_a")
+ self.w2b = weights.w.get("lokr_w2_b")
+ self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
+ self.t2 = weights.w.get("lokr_t2")
+
+ def calc_updown(self, orig_weight):
+ if self.w1 is not None:
+ w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ else:
+ w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = w1a @ w1b
+
+ if self.w2 is not None:
+ w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ elif self.t2 is None:
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = w2a @ w2b
+ else:
+ t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
+
+ output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
+ if len(orig_weight.shape) == 4:
+ output_shape = orig_weight.shape
+
+ updown = make_kron(output_shape, w1, w2)
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py new file mode 100644 index 00000000..26c0a72c --- /dev/null +++ b/extensions-builtin/Lora/network_lora.py @@ -0,0 +1,86 @@ +import torch
+
+import lyco_helpers
+import network
+from modules import devices
+
+
+class ModuleTypeLora(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
+ return NetworkModuleLora(net, weights)
+
+ return None
+
+
+class NetworkModuleLora(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.up_model = self.create_module(weights.w, "lora_up.weight")
+ self.down_model = self.create_module(weights.w, "lora_down.weight")
+ self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
+
+ self.dim = weights.w["lora_down.weight"].shape[0]
+
+ def create_module(self, weights, key, none_ok=False):
+ weight = weights.get(key)
+
+ if weight is None and none_ok:
+ return None
+
+ is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+ is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+
+ if is_linear:
+ weight = weight.reshape(weight.shape[0], -1)
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif is_conv and key == "lora_down.weight" or key == "dyn_up":
+ if len(weight.shape) == 2:
+ weight = weight.reshape(weight.shape[0], -1, 1, 1)
+
+ if weight.shape[2] != 1 or weight.shape[3] != 1:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+ else:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ elif is_conv and key == "lora_mid.weight":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
+ elif is_conv and key == "lora_up.weight" or key == "dyn_down":
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ else:
+ raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
+
+ with torch.no_grad():
+ if weight.shape != module.weight.shape:
+ weight = weight.reshape(module.weight.shape)
+ module.weight.copy_(weight)
+
+ module.to(device=devices.cpu, dtype=devices.dtype)
+ module.weight.requires_grad_(False)
+
+ return module
+
+ def calc_updown(self, orig_weight):
+ up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ output_shape = [up.size(0), down.size(1)]
+ if self.mid_model is not None:
+ # cp-decomposition
+ mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
+ output_shape += mid.shape[2:]
+ else:
+ if len(down.shape) == 4:
+ output_shape += down.shape[2:]
+ updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
+
+ return self.finalize_updown(updown, orig_weight, output_shape)
+
+ def forward(self, x, y):
+ self.up_model.to(device=devices.device)
+ self.down_model.to(device=devices.device)
+
+ return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
+
+
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py new file mode 100644 index 00000000..af8188e3 --- /dev/null +++ b/extensions-builtin/Lora/networks.py @@ -0,0 +1,463 @@ +import os
+import re
+
+import network
+import network_lora
+import network_hada
+import network_ia3
+import network_lokr
+import network_full
+
+import torch
+from typing import Union
+
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+
+module_types = [
+ network_lora.ModuleTypeLora(),
+ network_hada.ModuleTypeHada(),
+ network_ia3.ModuleTypeIa3(),
+ network_lokr.ModuleTypeLokr(),
+ network_full.ModuleTypeFull(),
+]
+
+
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+ "attentions": {},
+ "resnets": {
+ "conv1": "in_layers_2",
+ "conv2": "out_layers_3",
+ "time_emb_proj": "emb_layers_1",
+ "conv_shortcut": "skip_connection",
+ }
+}
+
+
+def convert_diffusers_name_to_compvis(key, is_sd2):
+ def match(match_list, regex_text):
+ regex = re_compiled.get(regex_text)
+ if regex is None:
+ regex = re.compile(regex_text)
+ re_compiled[regex_text] = regex
+
+ r = re.match(regex, key)
+ if not r:
+ return False
+
+ match_list.clear()
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+ return True
+
+ m = []
+
+ if match(m, r"lora_unet_conv_in(.*)"):
+ return f'diffusion_model_input_blocks_0_0{m[0]}'
+
+ if match(m, r"lora_unet_conv_out(.*)"):
+ return f'diffusion_model_out_2{m[0]}'
+
+ if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
+ return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
+
+ if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+ return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+ return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+ if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+ return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+ if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+ if is_sd2:
+ if 'mlp_fc1' in m[1]:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+ elif 'mlp_fc2' in m[1]:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+ else:
+ return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
+ return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+ if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
+ if 'mlp_fc1' in m[1]:
+ return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+ elif 'mlp_fc2' in m[1]:
+ return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+ else:
+ return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
+ return key
+
+
+def assign_network_names_to_compvis_modules(sd_model):
+ network_layer_mapping = {}
+
+ if shared.sd_model.is_sdxl:
+ for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
+ if not hasattr(embedder, 'wrapped'):
+ continue
+
+ for name, module in embedder.wrapped.named_modules():
+ network_name = f'{i}_{name.replace(".", "_")}'
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+ else:
+ for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+ network_name = name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+
+ for name, module in shared.sd_model.model.named_modules():
+ network_name = name.replace(".", "_")
+ network_layer_mapping[network_name] = module
+ module.network_layer_name = network_name
+
+ sd_model.network_layer_mapping = network_layer_mapping
+
+
+def load_network(name, network_on_disk):
+ net = network.Network(name, network_on_disk)
+ net.mtime = os.path.getmtime(network_on_disk.filename)
+
+ sd = sd_models.read_state_dict(network_on_disk.filename)
+
+ # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
+ if not hasattr(shared.sd_model, 'network_layer_mapping'):
+ assign_network_names_to_compvis_modules(shared.sd_model)
+
+ keys_failed_to_match = {}
+ is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+
+ matched_networks = {}
+
+ for key_network, weight in sd.items():
+ key_network_without_network_parts, network_part = key_network.split(".", 1)
+
+ key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ if sd_module is None:
+ m = re_x_proj.match(key)
+ if m:
+ sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
+
+ # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
+ if sd_module is None and "lora_unet" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+ if sd_module is None:
+ keys_failed_to_match[key_network] = key
+ continue
+
+ if key not in matched_networks:
+ matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
+
+ matched_networks[key].w[network_part] = weight
+
+ for key, weights in matched_networks.items():
+ net_module = None
+ for nettype in module_types:
+ net_module = nettype.create_module(net, weights)
+ if net_module is not None:
+ break
+
+ if net_module is None:
+ raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
+
+ net.modules[key] = net_module
+
+ if keys_failed_to_match:
+ print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+
+ return net
+
+
+def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+ already_loaded = {}
+
+ for net in loaded_networks:
+ if net.name in names:
+ already_loaded[net.name] = net
+
+ loaded_networks.clear()
+
+ networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+ if any(x is None for x in networks_on_disk):
+ list_available_networks()
+
+ networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+
+ failed_to_load_networks = []
+
+ for i, name in enumerate(names):
+ net = already_loaded.get(name, None)
+
+ network_on_disk = networks_on_disk[i]
+
+ if network_on_disk is not None:
+ if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
+ try:
+ net = load_network(name, network_on_disk)
+ except Exception as e:
+ errors.display(e, f"loading network {network_on_disk.filename}")
+ continue
+
+ net.mentioned_name = name
+
+ network_on_disk.read_hash()
+
+ if net is None:
+ failed_to_load_networks.append(name)
+ print(f"Couldn't find network with name {name}")
+ continue
+
+ net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
+ net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
+ net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
+ loaded_networks.append(net)
+
+ if failed_to_load_networks:
+ sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+
+
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ weights_backup = getattr(self, "network_weights_backup", None)
+
+ if weights_backup is None:
+ return
+
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
+
+
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+ """
+ Applies the currently selected set of networks to the weights of torch layer self.
+ If weights already have this particular set of networks applied, does nothing.
+ If not, restores orginal weights from backup and alters weights according to networks.
+ """
+
+ network_layer_name = getattr(self, 'network_layer_name', None)
+ if network_layer_name is None:
+ return
+
+ current_names = getattr(self, "network_current_names", ())
+ wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
+
+ weights_backup = getattr(self, "network_weights_backup", None)
+ if weights_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+ else:
+ weights_backup = self.weight.to(devices.cpu, copy=True)
+
+ self.network_weights_backup = weights_backup
+
+ if current_names != wanted_names:
+ network_restore_weights_from_backup(self)
+
+ for net in loaded_networks:
+ module = net.modules.get(network_layer_name, None)
+ if module is not None and hasattr(self, 'weight'):
+ with torch.no_grad():
+ updown = module.calc_updown(self.weight)
+
+ if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ # inpainting model. zero pad updown to make channel[1] 4 to 9
+ updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+
+ self.weight += updown
+ continue
+
+ module_q = net.modules.get(network_layer_name + "_q_proj", None)
+ module_k = net.modules.get(network_layer_name + "_k_proj", None)
+ module_v = net.modules.get(network_layer_name + "_v_proj", None)
+ module_out = net.modules.get(network_layer_name + "_out_proj", None)
+
+ if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+ with torch.no_grad():
+ updown_q = module_q.calc_updown(self.in_proj_weight)
+ updown_k = module_k.calc_updown(self.in_proj_weight)
+ updown_v = module_v.calc_updown(self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+ updown_out = module_out.calc_updown(self.out_proj.weight)
+
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += updown_out
+ continue
+
+ if module is None:
+ continue
+
+ print(f'failed to calculate network weights for layer {network_layer_name}')
+
+ self.network_current_names = wanted_names
+
+
+def network_forward(module, input, original_forward):
+ """
+ Old way of applying Lora by executing operations during layer's forward.
+ Stacking many loras this way results in big performance degradation.
+ """
+
+ if len(loaded_networks) == 0:
+ return original_forward(module, input)
+
+ input = devices.cond_cast_unet(input)
+
+ network_restore_weights_from_backup(module)
+ network_reset_cached_weight(module)
+
+ y = original_forward(module, input)
+
+ network_layer_name = getattr(module, 'network_layer_name', None)
+ for lora in loaded_networks:
+ module = lora.modules.get(network_layer_name, None)
+ if module is None:
+ continue
+
+ y = module.forward(y, input)
+
+ return y
+
+
+def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
+ self.network_current_names = ()
+ self.network_weights_backup = None
+
+
+def network_Linear_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.Linear_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.Linear_forward_before_network(self, input)
+
+
+def network_Linear_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_Conv2d_forward(self, input):
+ if shared.opts.lora_functional:
+ return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+
+ network_apply_weights(self)
+
+ return torch.nn.Conv2d_forward_before_network(self, input)
+
+
+def network_Conv2d_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_forward(self, *args, **kwargs):
+ network_apply_weights(self)
+
+ return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+
+
+def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+ network_reset_cached_weight(self)
+
+ return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def list_available_networks():
+ available_networks.clear()
+ available_network_aliases.clear()
+ forbidden_network_aliases.clear()
+ available_network_hash_lookup.clear()
+ forbidden_network_aliases.update({"none": 1, "Addams": 1})
+
+ os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
+
+ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
+ candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
+ for filename in candidates:
+ if os.path.isdir(filename):
+ continue
+
+ name = os.path.splitext(os.path.basename(filename))[0]
+ try:
+ entry = network.NetworkOnDisk(name, filename)
+ except OSError: # should catch FileNotFoundError and PermissionError etc.
+ errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
+ continue
+
+ available_networks[name] = entry
+
+ if entry.alias in available_network_aliases:
+ forbidden_network_aliases[entry.alias.lower()] = 1
+
+ available_network_aliases[name] = entry
+ available_network_aliases[entry.alias] = entry
+
+
+re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+ if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+ return # if the other extension is active, it will handle those fields, no need to do anything
+
+ added = []
+
+ for k in params:
+ if not k.startswith("AddNet Model "):
+ continue
+
+ num = k[13:]
+
+ if params.get("AddNet Module " + num) != "LoRA":
+ continue
+
+ name = params.get("AddNet Model " + num)
+ if name is None:
+ continue
+
+ m = re_network_name.match(name)
+ if m:
+ name = m.group(1)
+
+ multiplier = params.get("AddNet Weight A " + num, "1.0")
+
+ added.append(f"<lora:{name}:{multiplier}>")
+
+ if added:
+ params["Prompt"] += "\n" + "".join(added)
+
+
+available_networks = {}
+available_network_aliases = {}
+loaded_networks = []
+available_network_hash_lookup = {}
+forbidden_network_aliases = {}
+
+list_available_networks()
diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py index 863dc5c0..50961be3 100644 --- a/extensions-builtin/Lora/preload.py +++ b/extensions-builtin/Lora/preload.py @@ -4,3 +4,4 @@ from modules import paths def preload(parser):
parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
+ parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index e650f469..cd28afc9 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -4,69 +4,76 @@ import torch import gradio as gr
from fastapi import FastAPI
-import lora
+import network
+import networks
+import lora # noqa:F401
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
- torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
- torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
- torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
- torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
+ torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
+ torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
+ torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
+ torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
+ torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
+ torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
- extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+ extra_network = extra_networks_lora.ExtraNetworkLora()
+ extra_networks.register_extra_network(extra_network)
+ extra_networks.register_extra_network_alias(extra_network, "lyco")
-if not hasattr(torch.nn, 'Linear_forward_before_lora'):
- torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
- torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+if not hasattr(torch.nn, 'Linear_forward_before_network'):
+ torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
- torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
+ torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
- torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
+ torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
- torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
+ torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
- torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
+ torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-torch.nn.Linear.forward = lora.lora_Linear_forward
-torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
-torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
-torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
+ torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+torch.nn.Linear.forward = networks.network_Linear_forward
+torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
+torch.nn.Conv2d.forward = networks.network_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+
+script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
-script_callbacks.on_infotext_pasted(lora.infotext_pasted)
+script_callbacks.on_infotext_pasted(networks.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+ "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
+ "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
+ "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
}))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
- "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+ "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
-def create_lora_json(obj: lora.LoraOnDisk):
+def create_lora_json(obj: network.NetworkOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
@@ -75,17 +82,17 @@ def create_lora_json(obj: lora.LoraOnDisk): }
-def api_loras(_: gr.Blocks, app: FastAPI):
+def api_networks(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
- return [create_lora_json(obj) for obj in lora.available_loras.values()]
+ return [create_lora_json(obj) for obj in networks.available_networks.values()]
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
- return lora.list_available_loras()
+ return networks.list_available_networks()
-script_callbacks.on_app_started(api_loras)
+script_callbacks.on_app_started(api_networks)
re_lora = re.compile("<lora:([^:]+):")
@@ -98,19 +105,19 @@ def infotext_pasted(infotext, d): hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
- def lora_replacement(m):
+ def network_replacement(m):
alias = m.group(1)
shorthash = hashes.get(alias)
if shorthash is None:
return m.group(0)
- lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
- if lora_on_disk is None:
+ network_on_disk = networks.available_network_hash_lookup.get(shorthash)
+ if network_on_disk is None:
return m.group(0)
- return f'<lora:{lora_on_disk.get_alias()}:'
+ return f'<lora:{network_on_disk.get_alias()}:'
- d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+ d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
script_callbacks.on_infotext_pasted(infotext_pasted)
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py new file mode 100644 index 00000000..2ca997f7 --- /dev/null +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -0,0 +1,216 @@ +import datetime
+import html
+import random
+
+import gradio as gr
+import re
+
+from modules import ui_extra_networks_user_metadata
+
+
+def is_non_comma_tagset(tags):
+ average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)
+
+ return average_tag_length >= 16
+
+
+re_word = re.compile(r"[-_\w']+")
+re_comma = re.compile(r" *, *")
+
+
+def build_tags(metadata):
+ tags = {}
+
+ for _, tags_dict in metadata.get("ss_tag_frequency", {}).items():
+ for tag, tag_count in tags_dict.items():
+ tag = tag.strip()
+ tags[tag] = tags.get(tag, 0) + int(tag_count)
+
+ if tags and is_non_comma_tagset(tags):
+ new_tags = {}
+
+ for text, text_count in tags.items():
+ for word in re.findall(re_word, text):
+ if len(word) < 3:
+ continue
+
+ new_tags[word] = new_tags.get(word, 0) + text_count
+
+ tags = new_tags
+
+ ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)
+
+ return [(tag, tags[tag]) for tag in ordered_tags]
+
+
+class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+ def __init__(self, ui, tabname, page):
+ super().__init__(ui, tabname, page)
+
+ self.select_sd_version = None
+
+ self.taginfo = None
+ self.edit_activation_text = None
+ self.slider_preferred_weight = None
+ self.edit_notes = None
+
+ def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["sd version"] = sd_version
+ user_metadata["activation text"] = activation_text
+ user_metadata["preferred weight"] = preferred_weight
+ user_metadata["notes"] = notes
+
+ self.write_user_metadata(name, user_metadata)
+
+ def get_metadata_table(self, name):
+ table = super().get_metadata_table(name)
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+
+ keys = {
+ 'ss_sd_model_name': "Model:",
+ 'ss_clip_skip': "Clip skip:",
+ 'ss_network_module': "Kohya module:",
+ }
+
+ for key, label in keys.items():
+ value = metadata.get(key, None)
+ if value is not None and str(value) != "None":
+ table.append((label, html.escape(value)))
+
+ ss_training_started_at = metadata.get('ss_training_started_at')
+ if ss_training_started_at:
+ table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))
+
+ ss_bucket_info = metadata.get("ss_bucket_info")
+ if ss_bucket_info and "buckets" in ss_bucket_info:
+ resolutions = {}
+ for _, bucket in ss_bucket_info["buckets"].items():
+ resolution = bucket["resolution"]
+ resolution = f'{resolution[1]}x{resolution[0]}'
+
+ resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket["count"])
+
+ resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True)
+ resolutions_text = html.escape(", ".join(resolutions_list[0:4]))
+ if len(resolutions) > 4:
+ resolutions_text += ", ..."
+ resolutions_text = f"<span title='{html.escape(', '.join(resolutions_list))}'>{resolutions_text}</span>"
+
+ table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text))
+
+ image_count = 0
+ for _, params in metadata.get("ss_dataset_dirs", {}).items():
+ image_count += int(params.get("img_count", 0))
+
+ if image_count:
+ table.append(("Dataset size:", image_count))
+
+ return table
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+ values = super().put_values_into_components(name)
+
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+
+ tags = build_tags(metadata)
+ gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]
+
+ return [
+ *values[0:5],
+ item.get("sd_version", "Unknown"),
+ gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
+ user_metadata.get('activation text', ''),
+ float(user_metadata.get('preferred weight', 0.0)),
+ gr.update(visible=True if tags else False),
+ gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
+ ]
+
+ def generate_random_prompt(self, name):
+ item = self.page.items.get(name, {})
+ metadata = item.get("metadata") or {}
+ tags = build_tags(metadata)
+
+ return self.generate_random_prompt_from_tags(tags)
+
+ def generate_random_prompt_from_tags(self, tags):
+ max_count = None
+ res = []
+ for tag, count in tags:
+ if not max_count:
+ max_count = count
+
+ v = random.random() * max_count
+ if count > v:
+ res.append(tag)
+
+ return ", ".join(sorted(res))
+
+ def create_extra_default_items_in_left_column(self):
+
+ # this would be a lot better as gr.Radio but I can't make it work
+ self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ self.taginfo = gr.HighlightedText(label="Training dataset tags")
+ self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
+ self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+
+ with gr.Row() as row_random_prompt:
+ with gr.Column(scale=8):
+ random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
+
+ with gr.Column(scale=1, min_width=120):
+ generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)
+
+ def select_tag(activation_text, evt: gr.SelectData):
+ tag = evt.value[0]
+
+ words = re.split(re_comma, activation_text)
+ if tag in words:
+ words = [x for x in words if x != tag and x.strip()]
+ return ", ".join(words)
+
+ return activation_text + ", " + tag if activation_text else tag
+
+ self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)
+
+ self.create_default_buttons()
+
+ viewed_components = [
+ self.edit_name,
+ self.edit_description,
+ self.html_filedata,
+ self.html_preview,
+ self.edit_notes,
+ self.select_sd_version,
+ self.taginfo,
+ self.edit_activation_text,
+ self.slider_preferred_weight,
+ row_random_prompt,
+ random_prompt,
+ ]
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ edited_components = [
+ self.edit_description,
+ self.select_sd_version,
+ self.edit_activation_text,
+ self.slider_preferred_weight,
+ self.edit_notes,
+ ]
+
+ self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index da49790b..3629e5c0 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,8 +1,11 @@ -import json
import os
-import lora
+
+import network
+import networks
from modules import shared, ui_extra_networks
+from modules.ui_extra_networks import quote_js
+from ui_edit_user_metadata import LoraUserMetadataEditor
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
@@ -10,27 +13,66 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): super().__init__('Lora')
def refresh(self):
- lora.list_available_loras()
+ networks.list_available_networks()
- def list_items(self):
- for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
- path, ext = os.path.splitext(lora_on_disk.filename)
+ def create_item(self, name, index=None, enable_filter=True):
+ lora_on_disk = networks.available_networks.get(name)
+
+ path, ext = os.path.splitext(lora_on_disk.filename)
+
+ alias = lora_on_disk.get_alias()
+
+ item = {
+ "name": name,
+ "filename": lora_on_disk.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(lora_on_disk.filename),
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "metadata": lora_on_disk.metadata,
+ "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+ "sd_version": lora_on_disk.sd_version.name,
+ }
- alias = lora_on_disk.get_alias()
+ self.read_user_metadata(item)
+ activation_text = item["user_metadata"].get("activation text")
+ preferred_weight = item["user_metadata"].get("preferred weight", 0.0)
+ item["prompt"] = quote_js(f"<lora:{alias}:") + " + " + (str(preferred_weight) if preferred_weight else "opts.extra_networks_default_multiplier") + " + " + quote_js(">")
- yield {
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(lora_on_disk.filename),
- "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
- "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+ if activation_text:
+ item["prompt"] += " + " + quote_js(" " + activation_text)
+
+ sd_version = item["user_metadata"].get("sd version")
+ if sd_version in network.SdVersion.__members__:
+ item["sd_version"] = sd_version
+ sd_version = network.SdVersion[sd_version]
+ else:
+ sd_version = lora_on_disk.sd_version
+
+ if shared.opts.lora_show_all or not enable_filter:
+ pass
+ elif sd_version == network.SdVersion.Unknown:
+ model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
+ if model_version.name in shared.opts.lora_hide_unknown_for_versions:
+ return None
+ elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
+ return None
+ elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
+ return None
+ elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
+ return None
+
+ return item
+
+ def list_items(self):
+ for index, name in enumerate(networks.available_networks):
+ item = self.create_item(name, index)
- }
+ if item is not None:
+ yield item
def allowed_directories_for_previews(self):
- return [shared.cmd_opts.lora_dir]
+ return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat]
+ def create_user_metadata_editor(self, ui, tabname):
+ return LoraUserMetadataEditor(ui, tabname, self)
diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js new file mode 100644 index 00000000..12cae4b7 --- /dev/null +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -0,0 +1,26 @@ +var isSetupForMobile = false; + +function isMobile() { + for (var tab of ["txt2img", "img2img"]) { + var imageTab = gradioApp().getElementById(tab + '_results'); + if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) { + return true; + } + } + + return false; +} + +function reportWindowSize() { + var currentlyMobile = isMobile(); + if (currentlyMobile == isSetupForMobile) return; + isSetupForMobile = currentlyMobile; + + for (var tab of ["txt2img", "img2img"]) { + var button = gradioApp().getElementById(tab + '_generate_box'); + var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column'); + target.insertBefore(button, target.firstElementChild); + } +} + +window.addEventListener("resize", reportWindowSize); diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 68a84c3a..39674666 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,11 +1,11 @@ -<div class='card' style={style} onclick={card_clicked} {sort_keys}> +<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}> {background_image} - {metadata_button} + <div class="button-row"> + {metadata_button} + {edit_button} + </div> <div class='actions'> <div class='additional'> - <ul> - <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a> - </ul> <span style="display:none" class='search_term{search_only}'>{search_term}</span> </div> <span class='name'>{name}</span> diff --git a/html/image-update.svg b/html/image-update.svg deleted file mode 100644 index 3abf12df..00000000 --- a/html/image-update.svg +++ /dev/null @@ -1,7 +0,0 @@ -<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"> - <filter id='shadow' color-interpolation-filters="sRGB"> - <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/> - <feDropShadow flood-color="black" dx="0" dy="0" flood-opacity="0.9" stdDeviation="0.5"/> - </filter> - <path style="filter:url(#shadow);" fill="#FFFFFF" d="M13.18 19C13.35 19.72 13.64 20.39 14.03 21H5C3.9 21 3 20.11 3 19V5C3 3.9 3.9 3 5 3H19C20.11 3 21 3.9 21 5V11.18C20.5 11.07 20 11 19.5 11C19.33 11 19.17 11 19 11.03V5H5V19H13.18M11.21 15.83L9.25 13.47L6.5 17H13.03C13.14 15.54 13.73 14.22 14.64 13.19L13.96 12.29L11.21 15.83M19 13.5V12L16.75 14.25L19 16.5V15C20.38 15 21.5 16.12 21.5 17.5C21.5 17.9 21.41 18.28 21.24 18.62L22.33 19.71C22.75 19.08 23 18.32 23 17.5C23 15.29 21.21 13.5 19 13.5M19 20C17.62 20 16.5 18.88 16.5 17.5C16.5 17.1 16.59 16.72 16.76 16.38L15.67 15.29C15.25 15.92 15 16.68 15 17.5C15 19.71 16.79 21.5 19 21.5V23L21.25 20.75L19 18.5V20Z" /> -</svg> diff --git a/javascript/badScaleChecker.js b/javascript/badScaleChecker.js new file mode 100644 index 00000000..625ad309 --- /dev/null +++ b/javascript/badScaleChecker.js @@ -0,0 +1,108 @@ +(function() { + var ignore = localStorage.getItem("bad-scale-ignore-it") == "ignore-it"; + + function getScale() { + var ratio = 0, + screen = window.screen, + ua = navigator.userAgent.toLowerCase(); + + if (window.devicePixelRatio !== undefined) { + ratio = window.devicePixelRatio; + } else if (~ua.indexOf('msie')) { + if (screen.deviceXDPI && screen.logicalXDPI) { + ratio = screen.deviceXDPI / screen.logicalXDPI; + } + } else if (window.outerWidth !== undefined && window.innerWidth !== undefined) { + ratio = window.outerWidth / window.innerWidth; + } + + return ratio == 0 ? 0 : Math.round(ratio * 100); + } + + var showing = false; + + var div = document.createElement("div"); + div.style.position = "fixed"; + div.style.top = "0px"; + div.style.left = "0px"; + div.style.width = "100vw"; + div.style.backgroundColor = "firebrick"; + div.style.textAlign = "center"; + div.style.zIndex = 99; + + var b = document.createElement("b"); + b.innerHTML = 'Bad Scale: ??% '; + + div.appendChild(b); + + var note1 = document.createElement("p"); + note1.innerHTML = "Change your browser or your computer settings!"; + note1.title = 'Just make sure "computer-scale" * "browser-scale" = 100% ,\n' + + "you can keep your computer-scale and only change this page's scale,\n" + + "for example: your computer-scale is 125%, just use [\"CTRL\"+\"-\"] to make your browser-scale of this page to 80%."; + div.appendChild(note1); + + var note2 = document.createElement("p"); + note2.innerHTML = " Otherwise, it will cause this page to not function properly!"; + note2.title = "When you click \"Copy image to: [inpaint sketch]\" in some img2img's tab,\n" + + "if scale<100% the canvas will be invisible,\n" + + "else if scale>100% this page will take large amount of memory and CPU performance."; + div.appendChild(note2); + + var btn = document.createElement("button"); + btn.innerHTML = "Click here to ignore"; + + div.appendChild(btn); + + function tryShowTopBar(scale) { + if (showing) return; + + b.innerHTML = 'Bad Scale: ' + scale + '% '; + + var updateScaleTimer = setInterval(function() { + var newScale = getScale(); + b.innerHTML = 'Bad Scale: ' + newScale + '% '; + if (newScale == 100) { + var p = div.parentNode; + if (p != null) p.removeChild(div); + showing = false; + clearInterval(updateScaleTimer); + check(); + } + }, 999); + + btn.onclick = function() { + clearInterval(updateScaleTimer); + var p = div.parentNode; + if (p != null) p.removeChild(div); + ignore = true; + showing = false; + localStorage.setItem("bad-scale-ignore-it", "ignore-it"); + }; + + document.body.appendChild(div); + } + + function check() { + if (!ignore) { + var timer = setInterval(function() { + var scale = getScale(); + if (scale != 100 && !ignore) { + tryShowTopBar(scale); + clearInterval(timer); + } + if (ignore) { + clearInterval(timer); + } + }, 999); + } + } + + if (document.readyState != "complete") { + document.onreadystatechange = function() { + if (document.readyState != "complete") check(); + }; + } else { + check(); + } +})(); diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index b87bca3e..5582a6e5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -113,7 +113,7 @@ function setupExtraNetworks() { onUiLoaded(setupExtraNetworks); -var re_extranet = /<([^:]+:[^:]+):[\d.]+>/; +var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g; function tryToRemoveExtraNetworkFromPrompt(textarea, text) { @@ -121,15 +121,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var replaced = false; var newTextareaText; if (m) { + var extraTextAfterNet = m[2]; var partToSearch = m[1]; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) { + var foundAtPosition = -1; + newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { m = found.match(re_extranet); if (m[1] == partToSearch) { replaced = true; + foundAtPosition = pos; return ""; } return found; }); + + if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); + } } else { newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { if (found == text) { @@ -182,19 +189,20 @@ function extraNetworksSearchButton(tabs_id, event) { var globalPopup = null; var globalPopupInner = null; +function closePopup() { + if (!globalPopup) return; + + globalPopup.style.display = "none"; +} function popup(contents) { if (!globalPopup) { globalPopup = document.createElement('div'); - globalPopup.onclick = function() { - globalPopup.style.display = "none"; - }; + globalPopup.onclick = closePopup; globalPopup.classList.add('global-popup'); var close = document.createElement('div'); close.classList.add('global-popup-close'); - close.onclick = function() { - globalPopup.style.display = "none"; - }; + close.onclick = closePopup; close.title = "Close"; globalPopup.appendChild(close); @@ -205,7 +213,7 @@ function popup(contents) { globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); + gradioApp().querySelector('.main').appendChild(globalPopup); } globalPopupInner.innerHTML = ''; @@ -263,3 +271,43 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) { event.stopPropagation(); } + +var extraPageUserMetadataEditors = {}; + +function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) { + var id = tabname + '_' + extraPage + '_edit_user_metadata'; + + var editor = extraPageUserMetadataEditors[id]; + if (!editor) { + editor = {}; + editor.page = gradioApp().getElementById(id); + editor.nameTextarea = gradioApp().querySelector("#" + id + "_name" + ' textarea'); + editor.button = gradioApp().querySelector("#" + id + "_button"); + extraPageUserMetadataEditors[id] = editor; + } + + editor.nameTextarea.value = cardName; + updateInput(editor.nameTextarea); + + editor.button.click(); + + popup(editor.page); + + event.stopPropagation(); +} + +function extraNetworksRefreshSingleCard(page, tabname, name) { + requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) { + if (data && data.html) { + var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function + + var newDiv = document.createElement('DIV'); + newDiv.innerHTML = data.html; + var newCard = newDiv.firstElementChild; + + newCard.style = ''; + card.parentElement.insertBefore(newCard, card); + card.parentElement.removeChild(card); + } + }); +} diff --git a/javascript/hints.js b/javascript/hints.js index dc75ce31..4167cb28 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -84,8 +84,6 @@ var titles = { "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.", "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.", - "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).", - "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", @@ -110,7 +108,6 @@ var titles = { "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", - "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.", "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." @@ -18,6 +18,7 @@ run_pip = launch_utils.run_pip check_run_python = launch_utils.check_run_python
git_clone = launch_utils.git_clone
git_pull_recursive = launch_utils.git_pull_recursive
+list_extensions = launch_utils.list_extensions
run_extension_installer = launch_utils.run_extension_installer
prepare_environment = launch_utils.prepare_environment
configure_for_tests = launch_utils.configure_for_tests
diff --git a/modules/api/api.py b/modules/api/api.py index 11045292..2a4cd8a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,6 @@ import base64 import io +import os import time import datetime import uvicorn @@ -98,14 +99,16 @@ def encode_pil_to_base64(image): def api_middleware(app: FastAPI): - rich_available = True + rich_available = False try: - import anyio # importing just so it can be placed on silent list - import starlette # importing just so it can be placed on silent list - from rich.console import Console - console = Console() + if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + console = Console() + rich_available = True except Exception: - rich_available = False + pass @app.middleware("http") async def log_and_time(req: Request, call_next): @@ -116,14 +119,14 @@ def api_middleware(app: FastAPI): endpoint = req.scope.get('path', 'err') if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( - t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), - code = res.status_code, - ver = req.scope.get('http_version', '0.0'), - cli = req.scope.get('client', ('0:0.0.0', 0))[0], - prot = req.scope.get('scheme', 'err'), - method = req.scope.get('method', 'err'), - endpoint = endpoint, - duration = duration, + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get('http_version', '0.0'), + cli=req.scope.get('client', ('0:0.0.0', 0))[0], + prot=req.scope.get('scheme', 'err'), + method=req.scope.get('method', 'err'), + endpoint=endpoint, + duration=duration, )) return res @@ -134,7 +137,7 @@ def api_middleware(app: FastAPI): "body": vars(e).get('body', ''), "errors": str(e), } - if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions message = f"API error: {request.method}: {request.url} {err}" if rich_available: print(message) diff --git a/modules/api/models.py b/modules/api/models.py index b5683071..bf97b1a3 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,4 +1,5 @@ import inspect + from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal @@ -207,11 +208,12 @@ class PreprocessResponse(BaseModel): fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) - optType = opts.typemap.get(type(metadata.default), type(value)) + optType = opts.typemap.get(type(metadata.default), type(metadata.default)) - if (metadata is not None): - fields.update({key: (Optional[optType], Field( - default=metadata.default ,description=metadata.label))}) + if metadata.default is None: + pass + elif metadata is not None: + fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))}) else: fields.update({key: (Optional[optType], Field())}) diff --git a/modules/cache.py b/modules/cache.py new file mode 100644 index 00000000..71fe6302 --- /dev/null +++ b/modules/cache.py @@ -0,0 +1,120 @@ +import json
+import os.path
+import threading
+import time
+
+from modules.paths import data_path, script_path
+
+cache_filename = os.path.join(data_path, "cache.json")
+cache_data = None
+cache_lock = threading.Lock()
+
+dump_cache_after = None
+dump_cache_thread = None
+
+
+def dump_cache():
+ """
+ Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
+ """
+
+ global dump_cache_after
+ global dump_cache_thread
+
+ def thread_func():
+ global dump_cache_after
+ global dump_cache_thread
+
+ while dump_cache_after is not None and time.time() < dump_cache_after:
+ time.sleep(1)
+
+ with cache_lock:
+ with open(cache_filename, "w", encoding="utf8") as file:
+ json.dump(cache_data, file, indent=4)
+
+ dump_cache_after = None
+ dump_cache_thread = None
+
+ with cache_lock:
+ dump_cache_after = time.time() + 5
+ if dump_cache_thread is None:
+ dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
+ dump_cache_thread.start()
+
+
+def cache(subsection):
+ """
+ Retrieves or initializes a cache for a specific subsection.
+
+ Parameters:
+ subsection (str): The subsection identifier for the cache.
+
+ Returns:
+ dict: The cache data for the specified subsection.
+ """
+
+ global cache_data
+
+ if cache_data is None:
+ with cache_lock:
+ if cache_data is None:
+ if not os.path.isfile(cache_filename):
+ cache_data = {}
+ else:
+ try:
+ with open(cache_filename, "r", encoding="utf8") as file:
+ cache_data = json.load(file)
+ except Exception:
+ os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
+ print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
+ cache_data = {}
+
+ s = cache_data.get(subsection, {})
+ cache_data[subsection] = s
+
+ return s
+
+
+def cached_data_for_file(subsection, title, filename, func):
+ """
+ Retrieves or generates data for a specific file, using a caching mechanism.
+
+ Parameters:
+ subsection (str): The subsection of the cache to use.
+ title (str): The title of the data entry in the subsection of the cache.
+ filename (str): The path to the file to be checked for modifications.
+ func (callable): A function that generates the data if it is not available in the cache.
+
+ Returns:
+ dict or None: The cached or generated data, or None if data generation fails.
+
+ The `cached_data_for_file` function implements a caching mechanism for data stored in files.
+ It checks if the data associated with the given `title` is present in the cache and compares the
+ modification time of the file with the cached modification time. If the file has been modified,
+ the cache is considered invalid and the data is regenerated using the provided `func`.
+ Otherwise, the cached data is returned.
+
+ If the data generation fails, None is returned to indicate the failure. Otherwise, the generated
+ or cached data is returned as a dictionary.
+ """
+
+ existing_cache = cache(subsection)
+ ondisk_mtime = os.path.getmtime(filename)
+
+ entry = existing_cache.get(title)
+ if entry:
+ cached_mtime = entry.get("mtime", 0)
+ if ondisk_mtime > cached_mtime:
+ entry = None
+
+ if not entry or 'value' not in entry:
+ value = func()
+ if value is None:
+ return None
+
+ entry = {'mtime': ondisk_mtime, 'value': value}
+ existing_cache[title] = entry
+
+ dump_cache()
+
+ return entry['value']
diff --git a/modules/call_queue.py b/modules/call_queue.py index 3b94f8a4..61aa240f 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -85,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
- elapsed_text = f"{elapsed_s:.2f}s"
+ elapsed_text = f"{elapsed_s:.1f} sec."
if elapsed_m > 0:
- elapsed_text = f"{elapsed_m}m "+elapsed_text
+ elapsed_text = f"{elapsed_m} min. "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
@@ -95,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
- sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+ sys_pct = sys_peak/max(sys_total, 1) * 100
- vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+ toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
+ toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
+ toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
+
+ text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
+ text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
+ text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)"
+
+ vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>"
else:
vram_html = ''
# last item is always HTML
- res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
+ res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
return tuple(res)
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index ae78f469..e401f641 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -15,6 +15,7 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
+parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
diff --git a/modules/extensions.py b/modules/extensions.py index abc6e2b1..c561159a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os
import threading
-from modules import shared, errors
+from modules import shared, errors, cache
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -21,6 +21,7 @@ def active(): class Extension:
lock = threading.Lock()
+ cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
@@ -36,15 +37,29 @@ class Extension: self.remote = None
self.have_info_from_repo = False
+ def to_dict(self):
+ return {x: getattr(self, x) for x in self.cached_fields}
+
+ def from_dict(self, d):
+ for field in self.cached_fields:
+ setattr(self, field, d[field])
+
def read_info_from_repo(self):
if self.is_builtin or self.have_info_from_repo:
return
- with self.lock:
- if self.have_info_from_repo:
- return
+ def read_from_repo():
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+
+ self.do_read_info_from_repo()
+
+ return self.to_dict()
- self.do_read_info_from_repo()
+ d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+ self.from_dict(d)
+ self.status = 'unknown'
def do_read_info_from_repo(self):
repo = None
@@ -58,7 +73,6 @@ class Extension: self.remote = None
else:
try:
- self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
commit = repo.head.commit
self.commit_date = commit.committed_date
diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 41799b0a..6ae07e91 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -4,16 +4,22 @@ from collections import defaultdict from modules import errors
extra_network_registry = {}
+extra_network_aliases = {}
def initialize():
extra_network_registry.clear()
+ extra_network_aliases.clear()
def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_extra_network_alias(extra_network, alias):
+ extra_network_aliases[alias] = extra_network
+
+
def register_default_extra_networks():
from modules.extra_networks_hypernet import ExtraNetworkHypernet
register_extra_network(ExtraNetworkHypernet())
@@ -82,20 +88,26 @@ def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call
activate for all remaining registered networks with an empty argument list"""
+ activated = []
+
for extra_network_name, extra_network_args in extra_network_data.items():
extra_network = extra_network_registry.get(extra_network_name, None)
+
+ if extra_network is None:
+ extra_network = extra_network_aliases.get(extra_network_name, None)
+
if extra_network is None:
print(f"Skipping unknown extra network: {extra_network_name}")
continue
try:
extra_network.activate(p, extra_network_args)
+ activated.append(extra_network)
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
for extra_network_name, extra_network in extra_network_registry.items():
- args = extra_network_data.get(extra_network_name, None)
- if args is not None:
+ if extra_network in activated:
continue
try:
diff --git a/modules/hashes.py b/modules/hashes.py index ec1187fe..b7a33b42 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -1,43 +1,11 @@ import hashlib
-import json
import os.path
-import filelock
-
from modules import shared
-from modules.paths import data_path, script_path
-
-
-cache_filename = os.path.join(data_path, "cache.json")
-cache_data = None
-
-
-def dump_cache():
- with filelock.FileLock(f"{cache_filename}.lock"):
- with open(cache_filename, "w", encoding="utf8") as file:
- json.dump(cache_data, file, indent=4)
-
-
-def cache(subsection):
- global cache_data
-
- if cache_data is None:
- with filelock.FileLock(f"{cache_filename}.lock"):
- if not os.path.isfile(cache_filename):
- cache_data = {}
- else:
- try:
- with open(cache_filename, "r", encoding="utf8") as file:
- cache_data = json.load(file)
- except Exception:
- os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
- print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
- cache_data = {}
-
- s = cache_data.get(subsection, {})
- cache_data[subsection] = s
+import modules.cache
- return s
+dump_cache = modules.cache.dump_cache
+cache = modules.cache.cache
def calculate_sha256(filename):
diff --git a/modules/images.py b/modules/images.py index 4bdedb7f..38aa933d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -363,7 +363,7 @@ class FilenameGenerator: 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
- 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
@@ -380,6 +380,7 @@ class FilenameGenerator: 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(),
+ 'none': lambda self: '', # Overrides the default so you can get just the sequence number
}
default_time_format = '%Y%m%d%H%M%S'
@@ -601,13 +602,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else:
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
+ file_decoration = namegen.apply(file_decoration) + suffix
+
add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number:
file_decoration = f"-{file_decoration}"
- file_decoration = namegen.apply(file_decoration) + suffix
-
if add_number:
basecount = get_next_sequence_number(path, basename)
fullfn = None
diff --git a/modules/img2img.py b/modules/img2img.py index 664e2688..a811e7a4 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -240,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 01ea7c91..03552bc2 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,4 +1,5 @@ # this scripts installs necessary requirements and launches main program in webui.py
+import re
import subprocess
import os
import sys
@@ -9,6 +10,9 @@ from functools import lru_cache from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
+from modules import timer
+
+timer.startup_timer.record("start")
args, _ = cmd_args.parser.parse_known_args()
@@ -69,10 +73,12 @@ def git_tag(): return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
except Exception:
try:
- from pathlib import Path
- changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
- with changelog_md.open(encoding="utf-8") as file:
- return next((line.strip() for line in file if line.strip()), "<none>")
+
+ changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")
+ with open(changelog_md, "r", encoding="utf-8") as file:
+ line = next((line.strip() for line in file if line.strip()), "<none>")
+ line = line.replace("## ", "")
+ return line
except Exception:
return "<none>"
@@ -224,6 +230,44 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
+
+
+def requrements_met(requirements_file):
+ """
+ Does a simple parse of a requirements.txt file to determine if all rerqirements in it
+ are already installed. Returns True if so, False if not installed or parsing fails.
+ """
+
+ import importlib.metadata
+ import packaging.version
+
+ with open(requirements_file, "r", encoding="utf8") as file:
+ for line in file:
+ if line.strip() == "":
+ continue
+
+ m = re.match(re_requirement, line)
+ if m is None:
+ return False
+
+ package = m.group(1).strip()
+ version_required = (m.group(2) or "").strip()
+
+ if version_required == "":
+ continue
+
+ try:
+ version_installed = importlib.metadata.version(package)
+ except Exception:
+ return False
+
+ if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
+ return False
+
+ return True
+
+
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
@@ -309,7 +353,9 @@ def prepare_environment(): if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
- run_pip(f"install -r \"{requirements_file}\"", "requirements")
+
+ if not requrements_met(requirements_file):
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
diff --git a/modules/processing.py b/modules/processing.py index eb4a60eb..a74a5302 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -14,7 +14,7 @@ from skimage import exposure from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -538,6 +538,40 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+ samples = []
+
+ for i in range(batch.shape[0]):
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if check_for_nans:
+ try:
+ devices.test_for_nans(sample, "vae")
+ except devices.NansException as e:
+ if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ raise e
+
+ errors.print_error_explanation(
+ "A tensor with all NaNs was produced in VAE.\n"
+ "Web UI will now convert VAE into 32-bit float and retry.\n"
+ "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
+ "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ )
+
+ devices.dtype_vae = torch.float32
+ model.first_stage_model.to(devices.dtype_vae)
+ batch = batch.to(devices.dtype_vae)
+
+ sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+ if target_device is not None:
+ sample = sample.to(target_device)
+
+ samples.append(sample)
+
+ return samples
+
+
def decode_first_stage(model, x):
x = model.decode_first_stage(x.to(devices.dtype_vae))
@@ -587,7 +621,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
- "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+ "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -747,9 +781,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds()
- if len(model_hijack.comments) > 0:
- for comment in model_hijack.comments:
- comments[comment] = 1
+ for comment in model_hijack.comments:
+ comments[comment] = 1
+
+ p.extra_generation_params.update(model_hijack.extra_generation_params)
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
@@ -757,10 +792,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
- for x in x_samples_ddim:
- devices.test_for_nans(x, "vae")
-
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -1028,7 +1060,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): image = sd_samplers.sample_to_image(image, index, approximation=0)
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
if latent_scale_mode is not None:
for i in range(samples.shape[0]):
@@ -1302,7 +1334,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(shared.device)
+ image = image.to(shared.device, dtype=devices.dtype_vae)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/script_loading.py b/modules/script_loading.py index 306a1f35..0d55f193 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -12,11 +12,12 @@ def load_module(path): return module
-def preload_extensions(extensions_dir, parser):
+def preload_extensions(extensions_dir, parser, extension_list=None):
if not os.path.isdir(extensions_dir):
return
- for dirname in sorted(os.listdir(extensions_dir)):
+ extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
+ for dirname in sorted(extensions):
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
if not os.path.isfile(preload_script):
continue
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2b274c18..f5615967 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -159,7 +159,6 @@ def undo_weighted_forward(sd_model): class StableDiffusionModelHijack:
fixes = None
- comments = []
layers = None
circular_enabled = False
clip = None
@@ -168,6 +167,9 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
def __init__(self):
+ self.extra_generation_params = {}
+ self.comments = []
+
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
@@ -274,6 +276,7 @@ class StableDiffusionModelHijack: def clear_comments(self):
self.comments = []
+ self.extra_generation_params = {}
def get_prompt_lengths(self, text):
if self.clip is None:
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index b3771909..5443e609 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -234,9 +234,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.process_tokens(tokens, multipliers)
zs.append(z)
- if len(used_embeddings) > 0:
- embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
- self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+ if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:
+ hashes = []
+ for name, embedding in used_embeddings.items():
+ shorthash = embedding.shorthash
+ if not shorthash:
+ continue
+
+ name = name.replace(":", "").replace(",", "")
+ hashes.append(f"{name}: {shorthash}")
+
+ if hashes:
+ self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if getattr(self.wrapped, 'return_pooled', False):
return torch.hstack(zs), zs[0].pooled
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index ca1daf45..2101f1a0 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict):
for y in cond.keys():
- cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ if isinstance(cond[y], list):
+ cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+ else:
+ cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
@@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
+
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
diff --git a/modules/sd_models.py b/modules/sd_models.py index 729f03d7..fb31a793 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
model.is_sdxl = hasattr(model, 'conditioner')
+ model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
+ model.is_sd1 = not model.is_sdxl and not model.is_sd2
+
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@@ -323,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer timer.record("apply half()")
- devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -491,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None
try:
- with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
sd_model = instantiate_from_config(sd_config.model)
except Exception:
pass
diff --git a/modules/shared.py b/modules/shared.py index b28933a0..aa72c9c8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,6 +11,7 @@ import gradio as gr import torch
import tqdm
+import launch
import modules.interrogate
import modules.memmon
import modules.styles
@@ -26,7 +27,7 @@ demo = None parser = cmd_args.parser
-script_loading.preload_extensions(extensions_dir, parser)
+script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
@@ -426,6 +427,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+ "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
}))
@@ -473,12 +475,15 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
- "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
- "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
+ "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
+ "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
+ "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
+ "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cbe975b7..6166c76f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -49,6 +49,8 @@ class Embedding: self.sd_checkpoint_name = None
self.optimizer_state_dict = None
self.filename = None
+ self.hash = None
+ self.shorthash = None
def save(self, filename):
embedding_data = {
@@ -82,6 +84,10 @@ class Embedding: self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
class DirWithTextualInversionEmbeddings:
def __init__(self, path):
@@ -199,6 +205,7 @@ class EmbeddingDatabase: embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
embedding.filename = path
+ embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -249,7 +256,7 @@ class EmbeddingDatabase: self.word_embeddings.update(sorted_word_embeddings)
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
- if self.previously_displayed_embeddings != displayed_embeddings:
+ if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if self.skipped_embeddings:
diff --git a/modules/txt2img.py b/modules/txt2img.py index d0be2e73..29d94e8c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -70,4 +70,4 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
diff --git a/modules/ui.py b/modules/ui.py index 39d226ad..07ecee7b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,8 +83,7 @@ detect_image_size_symbol = '\U0001F4D0' # 📐 up_down_symbol = '\u2195\ufe0f' # ↕️
-def plaintext_to_html(text):
- return ui_common.plaintext_to_html(text)
+plaintext_to_html = ui_common.plaintext_to_html
def send_gradio_gallery_to_image(x):
diff --git a/modules/ui_common.py b/modules/ui_common.py index 57c2d0ad..11eb2a4b 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -29,9 +29,10 @@ def update_generation_info(generation_info, html_info, img_index): return html_info, gr.update()
-def plaintext_to_html(text):
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
- return text
+def plaintext_to_html(text, classname=None):
+ content = "<br>\n".join(html.escape(x) for x in text.split('\n'))
+
+ return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"
def save_files(js_data, images, do_make_zip, index):
@@ -157,7 +158,7 @@ Requested path was: {f} with gr.Group():
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index dff522ef..f3e4fba7 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,5 +1,5 @@ import json
-import os.path
+import os
import threading
import time
from datetime import datetime
@@ -513,14 +513,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" def preload_extensions_git_metadata():
- t0 = time.time()
for extension in extensions.extensions:
extension.read_info_from_repo()
- print(
- f"preload_extensions_git_metadata for "
- f"{len(extensions.extensions)} extensions took "
- f"{time.time() - t0:.2f}s"
- )
def create_ui():
@@ -570,7 +564,8 @@ def create_ui(): with gr.TabItem("Available", id="available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
- available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
+ extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
+ available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 693cafb6..49612298 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,14 +2,16 @@ import os.path import urllib.parse
from pathlib import Path
-from modules import shared
+from modules import shared, ui_extra_networks_user_metadata, errors
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
import json
import html
+from fastapi.exceptions import HTTPException
from modules.generation_parameters_copypaste import image_from_url_text
+from modules.ui_components import ToolButton
extra_pages = []
allowed_dirs = set()
@@ -26,6 +28,9 @@ def register_page(page): def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
+ if not os.path.isfile(filename):
+ raise HTTPException(status_code=404, detail="File not found")
+
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
@@ -48,25 +53,71 @@ def get_metadata(page: str = "", item: str = ""): if metadata is None:
return JSONResponse({})
- return JSONResponse({"metadata": metadata})
+ return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
+
+
+def get_single_card(page: str = "", tabname: str = "", name: str = ""):
+ from starlette.responses import JSONResponse
+
+ page = next(iter([x for x in extra_pages if x.name == page]), None)
+
+ try:
+ item = page.create_item(name, enable_filter=False)
+ page.items[name] = item
+ except Exception as e:
+ errors.display(e, "creating item for extra network")
+ item = page.items.get(name)
+
+ page.read_user_metadata(item)
+ item_html = page.create_html_for_item(item, tabname)
+
+ return JSONResponse({"html": item_html})
def add_pages_to_demo(app):
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
+ app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
+
+
+def quote_js(s):
+ s = s.replace('\\', '\\\\')
+ s = s.replace('"', '\\"')
+ return f'"{s}"'
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
self.name = title.lower()
+ self.id_page = self.name.replace(" ", "_")
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
self.metadata = {}
+ self.items = {}
def refresh(self):
pass
+ def read_user_metadata(self, item):
+ filename = item.get("filename", None)
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ desc = metadata.get("description", None)
+ if desc is not None:
+ item["description"] = desc
+
+ item["user_metadata"] = metadata
+
def link_preview(self, filename):
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
@@ -83,7 +134,6 @@ class ExtraNetworksPage: return ""
def create_html(self, tabname):
- view = shared.opts.extra_networks_default_view
items_html = ''
self.metadata = {}
@@ -119,11 +169,15 @@ class ExtraNetworksPage: </button>
""" for subdir in subdirs])
- for item in self.list_items():
+ self.items = {x["name"]: x for x in self.list_items()}
+ for item in self.items.values():
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata
+ if "user_metadata" not in item:
+ self.read_user_metadata(item)
+
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
@@ -133,16 +187,19 @@ class ExtraNetworksPage: self_name_id = self.name.replace(" ", "_")
res = f"""
-<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
{subdirs_html}
</div>
-<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
{items_html}
</div>
"""
return res
+ def create_item(self, name, index=None):
+ raise NotImplementedError()
+
def list_items(self):
raise NotImplementedError()
@@ -158,7 +215,7 @@ class ExtraNetworksPage: onclick = item.get("onclick", None)
if onclick is None:
- onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+ onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
@@ -166,7 +223,9 @@ class ExtraNetworksPage: metadata_button = ""
metadata = item.get("metadata")
if metadata:
- metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
+ metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
+
+ edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
local_path = ""
filename = item.get("filename", "")
@@ -190,16 +249,17 @@ class ExtraNetworksPage: args = {
"background_image": background_image,
- "style": f"'display: none; {height}{width}'",
+ "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
"prompt": item.get("prompt", None),
- "tabname": json.dumps(tabname),
- "local_preview": json.dumps(item["local_preview"]),
+ "tabname": quote_js(tabname),
+ "local_preview": quote_js(item["local_preview"]),
"name": item["name"],
- "description": (item.get("description") or ""),
+ "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
"card_clicked": onclick,
- "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
+ "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
+ "edit_button": edit_button,
"search_only": " search_only" if search_only else "",
"sort_keys": sort_keys,
}
@@ -247,6 +307,9 @@ class ExtraNetworksPage: pass
return None
+ def create_user_metadata_editor(self, ui, tabname):
+ return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)
+
def initialize():
extra_pages.clear()
@@ -297,23 +360,26 @@ def create_ui(container, button, tabname): ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
+ ui.user_metadata_editors = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
- page_id = page.title.lower().replace(" ", "_")
-
- with gr.Tab(page.title, id=page_id):
- elem_id = f"{tabname}_{page_id}_cards_html"
+ with gr.Tab(page.title, id=page.id_page):
+ elem_id = f"{tabname}_{page.id_page}_cards_html"
page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
- page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
+ page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
+
+ editor = page.create_user_metadata_editor(ui, tabname)
+ editor.create_ui()
+ ui.user_metadata_editors.append(editor)
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
- gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
+ ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@@ -363,6 +429,8 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery):
def save_preview(index, images, filename):
+ # this function is here for backwards compatibility and likely will be removed soon
+
if len(images) == 0:
print("There is no image in gallery to save as a preview.")
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
@@ -394,3 +462,7 @@ def setup_ui(ui, gallery): outputs=[*ui.pages]
)
+ for editor in ui.user_metadata_editors:
+ editor.setup_ui(gallery)
+
+
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 8b9ab71b..76780cfd 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -1,8 +1,8 @@ import html
-import json
import os
from modules import shared, ui_extra_networks, sd_models
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,21 +12,23 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def refresh(self):
shared.refresh_checkpoints()
+ def create_item(self, name, index=None):
+ checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
+ path, ext = os.path.splitext(checkpoint.filename)
+ return {
+ "name": checkpoint.name_for_extra,
+ "filename": checkpoint.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
+ "local_preview": f"{path}.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
+ }
+
def list_items(self):
- checkpoint: sd_models.CheckpointInfo
- for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
- path, ext = os.path.splitext(checkpoint.filename)
- yield {
- "name": checkpoint.name_for_extra,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
- "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
- "local_preview": f"{path}.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
-
- }
+ for index, name in enumerate(sd_models.checkpoints_list):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 7c19b532..e53ccb42 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,7 +1,7 @@ -import json
import os
from modules import shared, ui_extra_networks
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def refresh(self):
shared.reload_hypernetworks()
+ def create_item(self, name, index=None):
+ full_path = shared.hypernetworks[name]
+ path, ext = os.path.splitext(full_path)
+
+ return {
+ "name": name,
+ "filename": full_path,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(path),
+ "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
+ }
+
def list_items(self):
- for index, (name, path) in enumerate(shared.hypernetworks.items()):
- path, ext = os.path.splitext(path)
-
- yield {
- "name": name,
- "filename": path,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(path),
- "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
- "local_preview": f"{path}.preview.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
-
- }
+ for index, name in enumerate(shared.hypernetworks):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return [shared.cmd_opts.hypernetwork_dir]
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 58a61c55..d1794e50 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ -import json
import os
from modules import ui_extra_networks, sd_hijack, shared
+from modules.ui_extra_networks import quote_js
class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
@@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+ def create_item(self, name, index=None):
+ embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
+
+ path, ext = os.path.splitext(embedding.filename)
+ return {
+ "name": name,
+ "filename": embedding.filename,
+ "preview": self.find_preview(path),
+ "description": self.find_description(path),
+ "search_term": self.search_terms_from_path(embedding.filename),
+ "prompt": quote_js(embedding.name),
+ "local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
+ }
+
def list_items(self):
- for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
- path, ext = os.path.splitext(embedding.filename)
- yield {
- "name": embedding.name,
- "filename": embedding.filename,
- "preview": self.find_preview(path),
- "description": self.find_description(path),
- "search_term": self.search_terms_from_path(embedding.filename),
- "prompt": json.dumps(embedding.name),
- "local_preview": f"{path}.preview.{shared.opts.samples_format}",
- "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
-
- }
+ for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
+ yield self.create_item(name, index)
def allowed_directories_for_previews(self):
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py new file mode 100644 index 00000000..63d4b503 --- /dev/null +++ b/modules/ui_extra_networks_user_metadata.py @@ -0,0 +1,195 @@ +import datetime
+import html
+import json
+import os.path
+
+import gradio as gr
+
+from modules import generation_parameters_copypaste, images, sysinfo, errors
+
+
+class UserMetadataEditor:
+
+ def __init__(self, ui, tabname, page):
+ self.ui = ui
+ self.tabname = tabname
+ self.page = page
+ self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata"
+
+ self.box = None
+
+ self.edit_name_input = None
+ self.button_edit = None
+
+ self.edit_name = None
+ self.edit_description = None
+ self.edit_notes = None
+ self.html_filedata = None
+ self.html_preview = None
+ self.html_status = None
+
+ self.button_cancel = None
+ self.button_replace_preview = None
+ self.button_save = None
+
+ def get_user_metadata(self, name):
+ item = self.page.items.get(name, {})
+
+ user_metadata = item.get('user_metadata', None)
+ if user_metadata is None:
+ user_metadata = {}
+ item['user_metadata'] = user_metadata
+
+ return user_metadata
+
+ def create_extra_default_items_in_left_column(self):
+ pass
+
+ def create_default_editor_elems(self):
+ with gr.Row():
+ with gr.Column(scale=2):
+ self.edit_name = gr.HTML(elem_classes="extra-network-name")
+ self.edit_description = gr.Textbox(label="Description", lines=4)
+ self.html_filedata = gr.HTML()
+
+ self.create_extra_default_items_in_left_column()
+
+ with gr.Column(scale=1, min_width=0):
+ self.html_preview = gr.HTML()
+
+ def create_default_buttons(self):
+
+ with gr.Row(elem_classes="edit-user-metadata-buttons"):
+ self.button_cancel = gr.Button('Cancel')
+ self.button_replace_preview = gr.Button('Replace preview', variant='primary')
+ self.button_save = gr.Button('Save', variant='primary')
+
+ self.html_status = gr.HTML(elem_classes="edit-user-metadata-status")
+
+ self.button_cancel.click(fn=None, _js="closePopup")
+
+ def get_card_html(self, name):
+ item = self.page.items.get(name, {})
+
+ preview_url = item.get("preview", None)
+
+ if not preview_url:
+ filename, _ = os.path.splitext(item["filename"])
+ preview_url = self.page.find_preview(filename)
+ item["preview"] = preview_url
+
+ if preview_url:
+ preview = f'''
+ <div class='card standalone-card-preview'>
+ <img src="{html.escape(preview_url)}" class="preview">
+ </div>
+ '''
+ else:
+ preview = "<div class='card standalone-card-preview'></div>"
+
+ return preview
+
+ def get_metadata_table(self, name):
+ item = self.page.items.get(name, {})
+ try:
+ filename = item["filename"]
+
+ stats = os.stat(filename)
+ params = [
+ ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+ ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
+ ]
+
+ return params
+ except Exception as e:
+ errors.display(e, f"reading info for {name}")
+ return []
+
+ def put_values_into_components(self, name):
+ user_metadata = self.get_user_metadata(name)
+
+ try:
+ params = self.get_metadata_table(name)
+ except Exception as e:
+ errors.display(e, f"reading metadata info for {name}")
+ params = []
+
+ table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
+
+ return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
+
+ def write_user_metadata(self, name, metadata):
+ item = self.page.items.get(name, {})
+ filename = item.get("filename", None)
+ basename, ext = os.path.splitext(filename)
+
+ with open(basename + '.json', "w", encoding="utf8") as file:
+ json.dump(metadata, file)
+
+ def save_user_metadata(self, name, desc, notes):
+ user_metadata = self.get_user_metadata(name)
+ user_metadata["description"] = desc
+ user_metadata["notes"] = notes
+
+ self.write_user_metadata(name, user_metadata)
+
+ def setup_save_handler(self, button, func, components):
+ button\
+ .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\
+ .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[])
+
+ def create_editor(self):
+ self.create_default_editor_elems()
+
+ self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+ self.create_default_buttons()
+
+ self.button_edit\
+ .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\
+ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+ self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])
+
+ def create_ui(self):
+ with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box:
+ self.box = box
+
+ self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name")
+ self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button")
+
+ self.create_editor()
+
+ def save_preview(self, index, gallery, name):
+ if len(gallery) == 0:
+ return self.get_card_html(name), "There is no image in gallery to save as a preview."
+
+ item = self.page.items.get(name, {})
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(gallery) - 1 if index >= len(gallery) else index
+
+ img_info = gallery[index if index >= 0 else 0]
+ image = generation_parameters_copypaste.image_from_url_text(img_info)
+ geninfo, items = images.read_info_from_image(image)
+
+ images.save_image_with_geninfo(image, geninfo, item["local_preview"])
+
+ return self.get_card_html(name), ''
+
+ def setup_ui(self, gallery):
+ self.button_replace_preview.click(
+ fn=self.save_preview,
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
+ inputs=[self.edit_name_input, gallery, self.edit_name_input],
+ outputs=[self.html_preview, self.html_status]
+ ).then(
+ fn=None,
+ _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}",
+ inputs=[self.edit_name_input],
+ outputs=[]
+ )
+
+
+
diff --git a/requirements_versions.txt b/requirements_versions.txt index b826bf43..d07ab456 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -8,7 +8,7 @@ einops==0.4.1 fastapi==0.94.0
gfpgan==1.3.8
gradio==3.32.0
-httpcore<=0.15
+httpcore==0.15
inflection==0.5.1
jsonmerge==1.8.0
kornia==0.6.7
@@ -17,7 +17,7 @@ numpy==1.23.5 omegaconf==2.2.3
open-clip-torch==2.20.0
piexif==1.1.3
-psutil~=5.9.5
+psutil==5.9.5
pytorch_lightning==1.9.4
realesrgan==0.3.0
resize-right==0.0.2
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 7821cc65..1010845e 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -144,11 +144,20 @@ def apply_face_restore(p, opt, x): p.restore_faces = is_active
-def apply_override(field):
+def apply_override(field, boolean: bool = False):
def fun(p, x, xs):
+ if boolean:
+ x = True if x.lower() == "true" else False
p.override_settings[field] = x
return fun
+
+def boolean_choice(reverse: bool = False):
+ def choice():
+ return ["False", "True"] if reverse else ["True", "False"]
+ return choice
+
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -235,6 +244,7 @@ axis_options = [ AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
+ AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
]
@@ -227,20 +227,39 @@ button.custom-button{ align-self: end;
}
-.performance {
+.html-log .comments{
+ padding-top: 0.5em;
+}
+
+.html-log .comments:empty{
+ padding-top: 0;
+}
+
+.html-log .performance {
font-size: 0.85em;
color: #444;
+ display: flex;
}
-.performance p{
+.html-log .performance p{
display: inline-block;
}
-.performance .time {
- margin-right: 0;
+.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr {
+ margin-bottom: 0;
+ color: var(--block-title-text-color);
+}
+
+.html-log .performance p.time {
+}
+
+.html-log .performance p.vram {
+ margin-left: auto;
}
-.performance .vram {
+.html-log .performance .measurement{
+ color: var(--body-text-color);
+ font-weight: bold;
}
#txt2img_generate, #img2img_generate {
@@ -531,6 +550,9 @@ table.popup-table .link{ background-color: rgba(20, 20, 20, 0.95);
}
+.global-popup *{
+ box-sizing: border-box;
+}
.global-popup-close:before {
content: "×";
@@ -761,8 +783,7 @@ footer { margin: 0 0.15em;
}
.extra-networks .tab-nav .search,
-.extra-networks .tab-nav .sort,
-.extra-networks .tab-nav .sortorder{
+.extra-networks .tab-nav .sort{
display: inline-block;
margin: 0.3em;
align-self: center;
@@ -782,117 +803,67 @@ footer { width: auto;
}
-.extra-network-cards .nocards, .extra-network-thumbs .nocards{
+.extra-network-cards .nocards{
margin: 1.25em 0.5em 0.5em 0.5em;
}
-.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{
+.extra-network-cards .nocards h1{
font-size: 1.5em;
margin-bottom: 1em;
}
-.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{
+.extra-network-cards .nocards li{
margin-left: 0.5em;
}
-.extra-network-cards .card .metadata-button:before, .extra-network-thumbs .card .metadata-button:before{
- content: "🛈";
-}
-.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
+.extra-network-cards .card .button-row{
display: none;
position: absolute;
color: white;
right: 0;
}
-.extra-network-cards .card .metadata-button {
- text-shadow: 2px 2px 3px black;
- padding: 0.25em;
- font-size: 22pt;
- width: 1.5em;
-}
-.extra-network-thumbs .card .metadata-button {
- text-shadow: 1px 1px 2px black;
- padding: 0;
- font-size: 16pt;
- width: 1em;
- top: -0.25em;
-}
-.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
- display: inline-block;
-}
-.extra-network-cards .card .metadata-button:hover, .extra-network-thumbs .card .metadata-button:hover{
- color: red;
-}
-
-
-.extra-network-thumbs {
+.extra-network-cards .card:hover .button-row{
display: flex;
- flex-flow: row wrap;
- gap: 10px;
}
-.extra-network-thumbs .card {
- height: 6em;
- width: 6em;
- cursor: pointer;
- background-image: url('./file=html/card-no-preview.png');
- background-size: cover;
- background-position: center center;
- position: relative;
+.extra-network-cards .card .card-button{
+ color: white;
}
-.extra-network-thumbs .card .preview{
- position: absolute;
- object-fit: cover;
- width: 100%;
- height:100%;
+.extra-network-cards .card .metadata-button:before{
+ content: "🛈";
}
-.extra-network-thumbs .card:hover .additional a {
- display: inline-block;
+.extra-network-cards .card .edit-button:before{
+ content: "🛠";
}
-.extra-network-thumbs .actions .additional a {
- background-image: url('./file=html/image-update.svg');
- background-repeat: no-repeat;
- background-size: cover;
- background-position: center center;
- position: absolute;
- top: 0;
- left: 0;
- width: 24px;
- height: 24px;
- display: none;
- font-size: 0;
- text-align: -9999;
+.extra-network-cards .card .card-button {
+ text-shadow: 2px 2px 3px black;
+ padding: 0.25em 0.1em;
+ font-size: 200%;
+ width: 1.5em;
}
+.extra-network-cards .card .card-button:hover{
+ color: red;
+}
+
-.extra-network-thumbs .actions .name {
+.standalone-card-preview.card .preview{
position: absolute;
- bottom: 0;
- font-size: 10px;
- padding: 3px;
+ object-fit: cover;
width: 100%;
- overflow: hidden;
- white-space: nowrap;
- text-overflow: ellipsis;
- background: rgba(0,0,0,.5);
- color: white;
-}
-
-.extra-network-thumbs .card:hover .actions .name {
- white-space: normal;
- word-break: break-all;
+ height:100%;
}
-.extra-network-cards .card{
+.extra-network-cards .card, .standalone-card-preview.card{
display: inline-block;
- margin: 0.5em;
- width: 16em;
- height: 24em;
+ margin: 0.5rem;
+ width: 16rem;
+ height: 24rem;
box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);
- border-radius: 0.2em;
+ border-radius: 0.2rem;
position: relative;
background-size: auto 100%;
@@ -926,10 +897,6 @@ footer { color: white;
}
-.extra-network-cards .card .actions:hover{
- box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important;
-}
-
.extra-network-cards .card .actions .name{
font-size: 1.7em;
font-weight: bold;
@@ -970,3 +937,37 @@ footer { width: 100%;
height:100%;
}
+
+div.block.gradio-box.edit-user-metadata {
+ width: 56em;
+ background: var(--body-background-fill);
+ padding: 2em !important;
+}
+
+.edit-user-metadata .extra-network-name{
+ font-size: 18pt;
+ color: var(--body-text-color);
+}
+
+.edit-user-metadata .file-metadata{
+ color: var(--body-text-color);
+}
+
+.edit-user-metadata .file-metadata th{
+ text-align: left;
+}
+
+.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{
+ padding: 0.3em 1em;
+}
+
+.edit-user-metadata .wrap.translucent{
+ background: var(--body-background-fill);
+}
+.edit-user-metadata .gradio-highlightedtext span{
+ word-break: break-word;
+}
+
+.edit-user-metadata-buttons{
+ margin-top: 1.5em;
+}
@@ -31,21 +31,22 @@ if log_level: logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors, devices # noqa: F401
-
+from modules import timer
startup_timer = timer.startup_timer
+startup_timer.record("launcher")
import torch
import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
-
-
startup_timer.record("import torch")
import gradio # noqa: F401
startup_timer.record("import gradio")
+from modules import paths, timer, import_hook, errors, devices # noqa: F401
+startup_timer.record("setup paths")
+
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
@@ -4,8 +4,15 @@ # change the variables in webui-user.sh instead # ################################################# + +use_venv=1 +if [[ $venv_dir == "-" ]]; then + use_venv=0 +fi + SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + # If run from macOS, load defaults from webui-macos-env.sh if [[ "$OSTYPE" == "darwin"* ]]; then if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]] @@ -47,7 +54,7 @@ then fi # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) -if [[ -z "${venv_dir}" ]] +if [[ -z "${venv_dir}" ]] && [[ $use_venv -eq 1 ]] then venv_dir="venv" fi @@ -164,7 +171,7 @@ do fi done -if ! "${python_cmd}" -c "import venv" &>/dev/null +if [[ $use_venv -eq 1 ]] && ! "${python_cmd}" -c "import venv" &>/dev/null then printf "\n%s\n" "${delimiter}" printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" @@ -184,7 +191,7 @@ else cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } fi -if [[ -z "${VIRTUAL_ENV}" ]]; +if [[ $use_venv -eq 1 ]] && [[ -z "${VIRTUAL_ENV}" ]]; then printf "\n%s\n" "${delimiter}" printf "Create and activate python venv" @@ -207,7 +214,7 @@ then fi else printf "\n%s\n" "${delimiter}" - printf "python venv already activate: ${VIRTUAL_ENV}" + printf "python venv already activate or run without venv: ${VIRTUAL_ENV}" printf "\n%s\n" "${delimiter}" fi |