diff options
author | missionfloyd <missionfloyd@users.noreply.github.com> | 2023-05-25 18:53:33 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-25 18:53:33 -0600 |
commit | 6645f23c4c715b1bc704c88a499b2f4224d7f1e6 (patch) | |
tree | 6aeb51e366254fe8993856a3db341690bb39dca5 /extensions-builtin | |
parent | 43bdaa2f0eda79c685792b06a2bd84c65806a48f (diff) | |
parent | a6e653be26cc05f4438145fa0082816e9fbbf5fc (diff) |
Merge branch 'dev' into reorder-hotkeys
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/LDSR/ldsr_model_arch.py | 13 | ||||
-rw-r--r-- | extensions-builtin/LDSR/scripts/ldsr_model.py | 7 | ||||
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_autoencoder.py | 28 | ||||
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 66 | ||||
-rw-r--r-- | extensions-builtin/Lora/extra_networks_lora.py | 18 | ||||
-rw-r--r-- | extensions-builtin/Lora/lora.py | 80 | ||||
-rw-r--r-- | extensions-builtin/Lora/scripts/lora_script.py | 38 | ||||
-rw-r--r-- | extensions-builtin/Lora/ui_extra_networks_lora.py | 5 | ||||
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 19 | ||||
-rw-r--r-- | extensions-builtin/ScuNET/scunet_model_arch.py | 11 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/scripts/swinir_model.py | 9 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/swinir_model_arch.py | 6 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/swinir_model_arch_v2.py | 58 | ||||
-rw-r--r-- | extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js | 52 |
14 files changed, 256 insertions, 154 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index bc11cc6e..7f450086 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -88,7 +88,7 @@ class LDSR: x_t = None logs = None - for n in range(n_runs): + for _ in range(n_runs): if custom_shape is not None: x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) @@ -110,7 +110,6 @@ class LDSR: diffusion_steps = int(steps) eta = 1.0 - down_sample_method = 'Lanczos' gc.collect() if torch.cuda.is_available: @@ -131,11 +130,11 @@ class LDSR: im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) else: print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - + logs = self.run(model["model"], im_padded, diffusion_steps, eta) sample = logs["sample"] @@ -158,7 +157,7 @@ class LDSR: def get_cond(selected_path): - example = dict() + example = {} up_f = 4 c = selected_path.convert('RGB') c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) @@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s @torch.no_grad() def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = dict() + log = {} z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, return_first_stage_outputs=True, @@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) log["sample_noquant"] = x_sample_noquant log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: + except Exception: pass log["sample"] = x_sample diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index da19cff1..c4da79f3 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks -import sd_hijack_autoencoder, sd_hijack_ddpm_v1 +import sd_hijack_autoencoder # noqa: F401 +import sd_hijack_ddpm_v1 # noqa: F401 class UpscalerLDSR(Upscaler): @@ -44,9 +45,9 @@ class UpscalerLDSR(Upscaler): if local_safetensors_path is not None and os.path.exists(local_safetensors_path): model = local_safetensors_path else: - model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True) + model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) - yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True) + yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True) try: return LDSR(model, yaml) diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 8e03c7f8..81c5101b 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -1,16 +1,21 @@ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder - +import numpy as np import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager + +from torch.optim.lr_scheduler import LambdaLR + +from ldm.modules.ema import LitEma from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config import ldm.models.autoencoder +from packaging import version class VQModel(pl.LightningModule): def __init__(self, @@ -19,7 +24,7 @@ class VQModel(pl.LightningModule): n_embed, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -57,7 +62,7 @@ class VQModel(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @@ -76,11 +81,11 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -165,7 +170,7 @@ class VQModel(pl.LightningModule): def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + self._validation_step(batch, batch_idx, suffix="_ema") return log_dict def _validation_step(self, batch, batch_idx, suffix=""): @@ -232,7 +237,7 @@ class VQModel(pl.LightningModule): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -249,7 +254,8 @@ class VQModel(pl.LightningModule): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log @@ -264,7 +270,7 @@ class VQModel(pl.LightningModule): class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) + super().__init__(*args, embed_dim=embed_dim, **kwargs) self.embed_dim = embed_dim def encode(self, x): @@ -282,5 +288,5 @@ class VQModelInterface(VQModel): dec = self.decoder(quant) return dec -setattr(ldm.models.autoencoder, "VQModel", VQModel) -setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) +ldm.models.autoencoder.VQModel = VQModel +ldm.models.autoencoder.VQModelInterface = VQModelInterface diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 5c0488e5..631a08ef 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule): if monitor is not None: self.monitor = monitor if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) @@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1): self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1): z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1): if hasattr(self, "split_input_params"): assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids + assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) @@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] @@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, @@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1): use_ddim = ddim_steps is not None - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: @@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. @@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] @@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): logs['bbox_image'] = cond_img return logs -setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) -setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) -setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) -setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) +ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1 +ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1 +ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1 +ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1 diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ccb249ac..b5fea4d2 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -23,5 +23,23 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): lora.load_loras(names, multipliers)
+ if shared.opts.lora_add_hashes_to_infotext:
+ lora_hashes = []
+ for item in lora.loaded_loras:
+ shorthash = item.lora_on_disk.shorthash
+ if not shorthash:
+ continue
+
+ alias = item.mentioned_name
+ if not alias:
+ continue
+
+ alias = alias.replace(":", "").replace(",", "")
+
+ lora_hashes.append(f"{alias}: {shorthash}")
+
+ if lora_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+
def deactivate(self, p):
pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index b5d0c98f..eec14712 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -1,10 +1,9 @@ -import glob
import os
import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors, scripts
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -77,9 +76,9 @@ class LoraOnDisk: self.name = name
self.filename = filename
self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- _, ext = os.path.splitext(filename)
- if ext.lower() == ".safetensors":
+ if self.is_safetensors:
try:
self.metadata = sd_models.read_metadata_from_safetensors(filename)
except Exception as e:
@@ -95,14 +94,43 @@ class LoraOnDisk: self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
+ self.hash = None
+ self.shorthash = None
+ self.set_hash(
+ self.metadata.get('sshs_model_hash') or
+ hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
+ ''
+ )
+
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
+ if self.shorthash:
+ available_lora_hash_lookup[self.shorthash] = self
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
+ return self.name
+ else:
+ return self.alias
+
class LoraModule:
- def __init__(self, name):
+ def __init__(self, name, lora_on_disk: LoraOnDisk):
self.name = name
+ self.lora_on_disk = lora_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
+ self.mentioned_name = None
+ """the text that was used to add lora to prompt - can be either name or an alias"""
+
class LoraUpDownModule:
def __init__(self):
@@ -127,11 +155,11 @@ def assign_lora_names_to_compvis_modules(sd_model): sd_model.lora_layer_mapping = lora_layer_mapping
-def load_lora(name, filename):
- lora = LoraModule(name)
- lora.mtime = os.path.getmtime(filename)
+def load_lora(name, lora_on_disk):
+ lora = LoraModule(name, lora_on_disk)
+ lora.mtime = os.path.getmtime(lora_on_disk.filename)
- sd = sd_models.read_state_dict(filename)
+ sd = sd_models.read_state_dict(lora_on_disk.filename)
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
@@ -177,7 +205,7 @@ def load_lora(name, filename): else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
@@ -189,10 +217,10 @@ def load_lora(name, filename): elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
- print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+ print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora
@@ -207,30 +235,41 @@ def load_loras(names, multipliers=None): loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+ failed_to_load_loras = []
+
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
+
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
try:
- lora = load_lora(name, lora_on_disk.filename)
+ lora = load_lora(name, lora_on_disk)
except Exception as e:
errors.display(e, f"loading Lora {lora_on_disk.filename}")
continue
+ lora.mentioned_name = name
+
+ lora_on_disk.read_hash()
+
if lora is None:
+ failed_to_load_loras.append(name)
print(f"Couldn't find Lora with name {name}")
continue
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
+ if len(failed_to_load_loras) > 0:
+ sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
+
def lora_calc_updown(lora, module, target):
with torch.no_grad():
@@ -314,7 +353,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
@@ -348,8 +387,8 @@ def lora_forward(module, input, original_forward): def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
@@ -398,7 +437,8 @@ def list_available_loras(): available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
- forbidden_lora_aliases.update({"none": 1})
+ available_lora_hash_lookup.clear()
+ forbidden_lora_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
@@ -428,7 +468,7 @@ def infotext_pasted(infotext, params): added = []
- for k, v in params.items():
+ for k in params:
if not k.startswith("AddNet Model "):
continue
@@ -452,8 +492,10 @@ def infotext_pasted(infotext, params): if added:
params["Prompt"] += "\n" + "".join(added)
+
available_loras = {}
available_lora_aliases = {}
+available_lora_hash_lookup = {}
forbidden_lora_aliases = {}
loaded_loras = []
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 060bda05..e650f469 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,3 +1,5 @@ +import re
+
import torch
import gradio as gr
from fastapi import FastAPI
@@ -53,8 +55,9 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
- "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+ "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
@@ -77,6 +80,37 @@ def api_loras(_: gr.Blocks, app: FastAPI): async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
+ @app.post("/sdapi/v1/refresh-loras")
+ async def refresh_loras():
+ return lora.list_available_loras()
+
script_callbacks.on_app_started(api_loras)
+re_lora = re.compile("<lora:([^:]+):")
+
+
+def infotext_pasted(infotext, d):
+ hashes = d.get("Lora hashes")
+ if not hashes:
+ return
+
+ hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
+ hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
+
+ def lora_replacement(m):
+ alias = m.group(1)
+ shorthash = hashes.get(alias)
+ if shorthash is None:
+ return m.group(0)
+
+ lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
+ if lora_on_disk is None:
+ return m.group(0)
+
+ return f'<lora:{lora_on_disk.get_alias()}:'
+
+ d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+
+
+script_callbacks.on_infotext_pasted(infotext_pasted)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 2050e3fa..259e99ac 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -16,10 +16,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
- if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
- alias = name
- else:
- alias = lora_on_disk.alias
+ alias = lora_on_disk.get_alias()
yield {
"name": name,
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index c7fd5739..45d9297b 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -10,10 +10,9 @@ from tqdm import tqdm from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import devices, modelloader +from modules import devices, modelloader, script_callbacks from scunet_model_arch import SCUNet as net from modules.shared import opts -from modules import images class UpscalerScuNET(modules.upscaler.Upscaler): @@ -122,8 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') if "http" in path: - filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, - progress=True) + filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True) else: filename = path if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: @@ -133,8 +131,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler): model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model.load_state_dict(torch.load(filename), strict=True) model.eval() - for k, v in model.named_parameters(): + for _, v in model.named_parameters(): v.requires_grad = False model = model.to(device) return model + + +def on_ui_settings(): + import gradio as gr + from modules import shared + + shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) + shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 43ca8d36..b51a8806 100644 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -61,7 +61,9 @@ class WMSA(nn.Module): Returns: output: tensor shape [b h w c] """ - if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + if self.type != 'W': + x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) @@ -85,8 +87,9 @@ class WMSA(nn.Module): output = self.linear(output) output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), - dims=(1, 2)) + if self.type != 'W': + output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) + return output def relative_embedding(self): @@ -262,4 +265,4 @@ class SCUNet(nn.Module): nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0)
\ No newline at end of file + nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index e8783bca..1c7bf325 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,3 @@ -import contextlib import os import numpy as np @@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import cmd_opts, opts, state +from modules.shared import opts, state from swinir_model_arch import SwinIR as net from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -45,14 +44,14 @@ class UpscalerSwinIR(Upscaler): img = upscale(img, model) try: torch.cuda.empty_cache() - except: + except Exception: pass return img def load_model(self, path, scale=4): if "http" in path: dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") - filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) + filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True) else: filename = path if filename is None or not os.path.exists(filename): @@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): for w_idx in w_idx_list: if state.interrupted or state.skipped: break - + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 863f42db..93b93274 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -644,7 +644,7 @@ class SwinIR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, @@ -805,7 +805,7 @@ class SwinIR(nn.Module): def forward(self, x): H, W = x.shape[2:] x = self.check_image_size(x) - + self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range @@ -844,7 +844,7 @@ class SwinIR(nn.Module): H, W = self.patches_resolution flops += H * W * 3 * self.embed_dim * 9 flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): + for layer in self.layers: flops += layer.flops() flops += H * W * 3 * self.embed_dim * self.embed_dim flops += self.upsample.flops() diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 0e28ae6e..dad22cca 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -74,7 +74,7 @@ class WindowAttention(nn.Module): """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = None
self.register_buffer("attn_mask", attn_mask)
-
+
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
- return attn_mask
+ return attn_mask
def forward(self, x, x_size):
H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module): attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
+
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module): H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
- return flops
+ return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module): nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
-
+
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module): flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
- return flops
+ return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module): num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential): else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
-
+
class Upsample_hf(nn.Sequential):
"""Upsample module.
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential): m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
+ super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential): H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
-
-
+
+
class Swin2SR(nn.Module):
r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module): """
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module): )
self.layers.append(layer)
-
+
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module): )
self.layers_hf.append(layer)
-
+
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module): self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
+ nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
@@ -846,7 +846,7 @@ class Swin2SR(nn.Module): nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@@ -905,7 +905,7 @@ class Swin2SR(nn.Module): x = self.patch_unembed(x, x_size)
return x
-
+
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
@@ -919,7 +919,7 @@ class Swin2SR(nn.Module): x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
- return x
+ return x
def forward(self, x):
H, W = x.shape[2:]
@@ -951,7 +951,7 @@ class Swin2SR(nn.Module): x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
-
+
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
@@ -977,15 +977,15 @@ class Swin2SR(nn.Module): x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
-
+
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
+
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
+
else:
return x[:, :, :H*self.upscale, :W*self.upscale]
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module): H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
@@ -1014,4 +1014,4 @@ if __name__ == '__main__': x = torch.randn((1, 3, height, width))
x = model(x)
- print(x.shape)
\ No newline at end of file + print(x.shape)
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 5c7a836a..114cf94c 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -4,39 +4,39 @@ // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. function checkBrackets(textArea, counterElt) { - var counts = {}; - (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => { - counts[bracket] = (counts[bracket] || 0) + 1; - }); - var errors = []; + var counts = {}; + (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => { + counts[bracket] = (counts[bracket] || 0) + 1; + }); + var errors = []; - function checkPair(open, close, kind) { - if (counts[open] !== counts[close]) { - errors.push( - `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` - ); + function checkPair(open, close, kind) { + if (counts[open] !== counts[close]) { + errors.push( + `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` + ); + } } - } - checkPair('(', ')', 'round brackets'); - checkPair('[', ']', 'square brackets'); - checkPair('{', '}', 'curly brackets'); - counterElt.title = errors.join('\n'); - counterElt.classList.toggle('error', errors.length !== 0); + checkPair('(', ')', 'round brackets'); + checkPair('[', ']', 'square brackets'); + checkPair('{', '}', 'curly brackets'); + counterElt.title = errors.join('\n'); + counterElt.classList.toggle('error', errors.length !== 0); } function setupBracketChecking(id_prompt, id_counter) { - var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); - var counter = gradioApp().getElementById(id_counter) + var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); + var counter = gradioApp().getElementById(id_counter); - if (textarea && counter) { - textarea.addEventListener("input", () => checkBrackets(textarea, counter)); - } + if (textarea && counter) { + textarea.addEventListener("input", () => checkBrackets(textarea, counter)); + } } -onUiLoaded(function () { - setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); - setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); - setupBracketChecking('img2img_prompt', 'img2img_token_counter'); - setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); +onUiLoaded(function() { + setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); + setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); + setupBracketChecking('img2img_prompt', 'img2img_token_counter'); + setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); }); |