aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md14
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py20
-rw-r--r--extensions-builtin/Lora/lora.py199
-rw-r--r--extensions-builtin/Lora/preload.py6
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py30
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py36
-rw-r--r--javascript/edit-attention.js110
-rw-r--r--javascript/extraNetworks.js20
-rw-r--r--javascript/hints.js5
-rw-r--r--javascript/ui.js5
-rw-r--r--modules/api/api.py13
-rw-r--r--modules/extra_networks_hypernet.py2
-rw-r--r--modules/extras.py228
-rw-r--r--modules/hypernetworks/hypernetwork.py2
-rw-r--r--modules/postprocessing.py103
-rw-r--r--modules/processing.py14
-rw-r--r--modules/script_callbacks.py15
-rw-r--r--modules/scripts.py28
-rw-r--r--modules/scripts_postprocessing.py147
-rw-r--r--modules/sd_disable_initialization.py4
-rw-r--r--modules/shared.py21
-rw-r--r--modules/ui.py315
-rw-r--r--modules/ui_common.py202
-rw-r--r--modules/ui_components.py1
-rw-r--r--modules/ui_extra_networks.py36
-rw-r--r--modules/ui_extra_networks_hypernets.py3
-rw-r--r--modules/ui_extra_networks_textual_inversion.py3
-rw-r--r--modules/ui_postprocessing.py57
-rw-r--r--scripts/postprocessing_codeformer.py36
-rw-r--r--scripts/postprocessing_gfpgan.py33
-rw-r--r--scripts/postprocessing_upscale.py106
-rw-r--r--scripts/xy_grid.py50
-rw-r--r--style.css22
-rw-r--r--test/basic_features/utils_test.py10
-rw-r--r--webui.py6
-rwxr-xr-xwebui.sh31
36 files changed, 1306 insertions, 627 deletions
diff --git a/README.md b/README.md
index 1ac794e8..9c0cd1ef 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Possible to change defaults/mix/max/step values for UI elements via text config
- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
+ - Can use a separate neural network to produce previews with almost none VRAM or compute requirement
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
- Styles, a way to save part of prompt and easily apply them via dropdown later
- Variations, a way to generate same image but with tiny differences
@@ -75,13 +76,22 @@ A browser interface based on Gradio library for Stable Diffusion.
- hypernetworks and embeddings options
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
- Clip skip
-- Use Hypernetworks
-- Use VAEs
+- Hypernetworks
+- Loras (same as Hypernetworks but more pretty)
+- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
+- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
+- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
+- Now without any bad letters!
+- Load checkpoints in safetensors format
+- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
+- Now with a license!
+- Reorder elements in the UI from settings screen
+-
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
new file mode 100644
index 00000000..8f2e753e
--- /dev/null
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -0,0 +1,20 @@
+from modules import extra_networks
+import lora
+
+class ExtraNetworkLora(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('lora')
+
+ def activate(self, p, params_list):
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ lora.load_loras(names, multipliers)
+
+ def deactivate(self, p):
+ pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
new file mode 100644
index 00000000..da1797dc
--- /dev/null
+++ b/extensions-builtin/Lora/lora.py
@@ -0,0 +1,199 @@
+import glob
+import os
+import re
+import torch
+
+from modules import shared, devices, sd_models
+
+re_digits = re.compile(r"\d+")
+re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
+re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
+re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+
+
+def convert_diffusers_name_to_compvis(key):
+ def match(match_list, regex):
+ r = re.match(regex, key)
+ if not r:
+ return False
+
+ match_list.clear()
+ match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+ return True
+
+ m = []
+
+ if match(m, re_unet_down_blocks):
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+
+ if match(m, re_unet_mid_blocks):
+ return f"diffusion_model_middle_block_1_{m[1]}"
+
+ if match(m, re_unet_up_blocks):
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+
+ if match(m, re_text_block):
+ return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+ return key
+
+
+class LoraOnDisk:
+ def __init__(self, name, filename):
+ self.name = name
+ self.filename = filename
+
+
+class LoraModule:
+ def __init__(self, name):
+ self.name = name
+ self.multiplier = 1.0
+ self.modules = {}
+ self.mtime = None
+
+
+class LoraUpDownModule:
+ def __init__(self):
+ self.up = None
+ self.down = None
+
+
+def assign_lora_names_to_compvis_modules(sd_model):
+ lora_layer_mapping = {}
+
+ for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+ lora_name = name.replace(".", "_")
+ lora_layer_mapping[lora_name] = module
+ module.lora_layer_name = lora_name
+
+ for name, module in shared.sd_model.model.named_modules():
+ lora_name = name.replace(".", "_")
+ lora_layer_mapping[lora_name] = module
+ module.lora_layer_name = lora_name
+
+ sd_model.lora_layer_mapping = lora_layer_mapping
+
+
+def load_lora(name, filename):
+ lora = LoraModule(name)
+ lora.mtime = os.path.getmtime(filename)
+
+ sd = sd_models.read_state_dict(filename)
+
+ keys_failed_to_match = []
+
+ for key_diffusers, weight in sd.items():
+ fullkey = convert_diffusers_name_to_compvis(key_diffusers)
+ key, lora_key = fullkey.split(".", 1)
+
+ sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+ if sd_module is None:
+ keys_failed_to_match.append(key_diffusers)
+ continue
+
+ if type(sd_module) == torch.nn.Linear:
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif type(sd_module) == torch.nn.Conv2d:
+ module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+ else:
+ assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+
+ with torch.no_grad():
+ module.weight.copy_(weight)
+
+ module.to(device=devices.device, dtype=devices.dtype)
+
+ lora_module = lora.modules.get(key, None)
+ if lora_module is None:
+ lora_module = LoraUpDownModule()
+ lora.modules[key] = lora_module
+
+ if lora_key == "lora_up.weight":
+ lora_module.up = module
+ elif lora_key == "lora_down.weight":
+ lora_module.down = module
+ else:
+ assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+
+ if len(keys_failed_to_match) > 0:
+ print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+
+ return lora
+
+
+def load_loras(names, multipliers=None):
+ already_loaded = {}
+
+ for lora in loaded_loras:
+ if lora.name in names:
+ already_loaded[lora.name] = lora
+
+ loaded_loras.clear()
+
+ loras_on_disk = [available_loras.get(name, None) for name in names]
+ if any([x is None for x in loras_on_disk]):
+ list_available_loras()
+
+ loras_on_disk = [available_loras.get(name, None) for name in names]
+
+ for i, name in enumerate(names):
+ lora = already_loaded.get(name, None)
+
+ lora_on_disk = loras_on_disk[i]
+ if lora_on_disk is not None:
+ if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
+ lora = load_lora(name, lora_on_disk.filename)
+
+ if lora is None:
+ print(f"Couldn't find Lora with name {name}")
+ continue
+
+ lora.multiplier = multipliers[i] if multipliers else 1.0
+ loaded_loras.append(lora)
+
+
+def lora_forward(module, input, res):
+ if len(loaded_loras) == 0:
+ return res
+
+ lora_layer_name = getattr(module, 'lora_layer_name', None)
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is not None:
+ res = res + module.up(module.down(input)) * lora.multiplier
+
+ return res
+
+
+def lora_Linear_forward(self, input):
+ return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+
+
+def lora_Conv2d_forward(self, input):
+ return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+
+
+def list_available_loras():
+ available_loras.clear()
+
+ os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
+
+ candidates = \
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
+ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
+
+ for filename in sorted(candidates):
+ if os.path.isdir(filename):
+ continue
+
+ name = os.path.splitext(os.path.basename(filename))[0]
+
+ available_loras[name] = LoraOnDisk(name, filename)
+
+
+available_loras = {}
+loaded_loras = []
+
+list_available_loras()
diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py
new file mode 100644
index 00000000..863dc5c0
--- /dev/null
+++ b/extensions-builtin/Lora/preload.py
@@ -0,0 +1,6 @@
+import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
new file mode 100644
index 00000000..60b9eb64
--- /dev/null
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -0,0 +1,30 @@
+import torch
+
+import lora
+import extra_networks_lora
+import ui_extra_networks_lora
+from modules import script_callbacks, ui_extra_networks, extra_networks
+
+
+def unload():
+ torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+ torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+
+
+def before_ui():
+ ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
+ extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+
+
+if not hasattr(torch.nn, 'Linear_forward_before_lora'):
+ torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+
+if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
+ torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+
+torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+
+script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_script_unloaded(unload)
+script_callbacks.on_before_ui(before_ui)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
new file mode 100644
index 00000000..54a80d36
--- /dev/null
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -0,0 +1,36 @@
+import json
+import os
+import lora
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Lora')
+
+ def refresh(self):
+ lora.list_available_loras()
+
+ def list_items(self):
+ for name, lora_on_disk in lora.available_loras.items():
+ path, ext = os.path.splitext(lora_on_disk.filename)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ break
+
+ yield {
+ "name": name,
+ "filename": path,
+ "preview": preview,
+ "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.lora_dir]
+
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index cec6a530..619bb1fa 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,74 +1,96 @@
-addEventListener('keydown', (event) => {
+function keyupEditAttention(event){
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
if (! (event.metaKey || event.ctrlKey)) return;
-
- let plus = "ArrowUp"
- let minus = "ArrowDown"
- if (event.key != plus && event.key != minus) return;
+ let isPlus = event.key == "ArrowUp"
+ let isMinus = event.key == "ArrowDown"
+ if (!isPlus && !isMinus) return;
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
- // If the user hasn't selected anything, let's select their current parenthesis block
- if (selectionStart === selectionEnd) {
+ let text = target.value;
+
+ function selectCurrentParenthesisBlock(OPEN, CLOSE){
+ if (selectionStart !== selectionEnd) return false;
+
// Find opening parenthesis around current cursor
- const before = target.value.substring(0, selectionStart);
- let beforeParen = before.lastIndexOf("(");
- if (beforeParen == -1) return;
- let beforeParenClose = before.lastIndexOf(")");
+ const before = text.substring(0, selectionStart);
+ let beforeParen = before.lastIndexOf(OPEN);
+ if (beforeParen == -1) return false;
+ let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
- beforeParen = before.lastIndexOf("(", beforeParen - 1);
- beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
+ beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
+ beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
}
// Find closing parenthesis around current cursor
- const after = target.value.substring(selectionStart);
- let afterParen = after.indexOf(")");
- if (afterParen == -1) return;
- let afterParenOpen = after.indexOf("(");
+ const after = text.substring(selectionStart);
+ let afterParen = after.indexOf(CLOSE);
+ if (afterParen == -1) return false;
+ let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
- afterParen = after.indexOf(")", afterParen + 1);
- afterParenOpen = after.indexOf("(", afterParenOpen + 1);
+ afterParen = after.indexOf(CLOSE, afterParen + 1);
+ afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
}
- if (beforeParen === -1 || afterParen === -1) return;
+ if (beforeParen === -1 || afterParen === -1) return false;
// Set the selection to the text between the parenthesis
- const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
+ const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd);
- }
+ return true;
+ }
+
+ // If the user hasn't selected anything, let's select their current parenthesis block
+ if(! selectCurrentParenthesisBlock('<', '>')){
+ selectCurrentParenthesisBlock('(', ')')
+ }
event.preventDefault();
- if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
- target.value = target.value.slice(0, selectionStart) +
- "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
- target.value.slice(selectionEnd);
+ closeCharacter = ')'
+ delta = opts.keyedit_precision_attention
+
+ if (selectionStart > 0 && text[selectionStart - 1] == '<'){
+ closeCharacter = '>'
+ delta = opts.keyedit_precision_extra
+ } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+
+ // do not include spaces at the end
+ while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
+ selectionEnd -= 1;
+ }
+ if(selectionStart == selectionEnd){
+ return
+ }
- target.focus();
- target.selectionStart = selectionStart + 1;
- target.selectionEnd = selectionEnd + 1;
+ text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
- } else {
- end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
- weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
- if (isNaN(weight)) return;
- if (event.key == minus) weight -= 0.1;
- if (event.key == plus) weight += 0.1;
+ selectionStart += 1;
+ selectionEnd += 1;
+ }
- weight = parseFloat(weight.toPrecision(12));
+ end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ if (isNaN(weight)) return;
- target.value = target.value.slice(0, selectionEnd + 1) +
- weight +
- target.value.slice(selectionEnd + 1 + end - 1);
+ weight += isPlus ? delta : -delta;
+ weight = parseFloat(weight.toPrecision(12));
+ if(String(weight).length == 1) weight += ".0"
- target.focus();
- target.selectionStart = selectionStart;
- target.selectionEnd = selectionEnd;
- }
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
+
+ target.focus();
+ target.value = text;
+ target.selectionStart = selectionStart;
+ target.selectionEnd = selectionEnd;
updateInput(target)
-});
+}
+
+addEventListener('keydown', (event) => {
+ keyupEditAttention(event);
+}); \ No newline at end of file
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 5e0d9714..c5a9adb3 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -2,8 +2,24 @@
function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
- gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh'))
- gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close'))
+ var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
+ var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
+ var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
+ var close = gradioApp().getElementById(tabname+'_extra_close')
+
+ search.classList.add('search')
+ tabs.appendChild(search)
+ tabs.appendChild(refresh)
+ tabs.appendChild(close)
+
+ search.addEventListener("input", function(evt){
+ searchTerm = search.value.toLowerCase()
+
+ gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
+ text = elem.querySelector('.name').textContent.toLowerCase()
+ elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
+ })
+ });
}
var activePromptTextarea = {};
diff --git a/javascript/hints.js b/javascript/hints.js
index ef410fba..3cf10e20 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -107,7 +107,10 @@ titles = {
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
- "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
+ "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
+ "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
+ "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
+ "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
}
diff --git a/javascript/ui.js b/javascript/ui.js
index 77256e15..ba72623c 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){
return res
}
-function get_extras_tab_index(){
- const [,,...args] = [...arguments]
- return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
-}
-
function get_img2img_tab_index() {
let res = args_to_array(arguments)
res.splice(-2)
diff --git a/modules/api/api.py b/modules/api/api.py
index f2e9e884..5d60fc0a 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.extras import run_extras
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
@@ -45,10 +44,8 @@ def validate_sampler_name(name):
def setUpscalers(req: dict):
reqDict = vars(req)
- reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
- reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
- reqDict.pop('upscaler_1')
- reqDict.pop('upscaler_2')
+ reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
+ reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
def decode_base64_to_image(encoding):
@@ -244,7 +241,7 @@ class Api:
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
- result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
@@ -260,7 +257,7 @@ class Api:
reqDict.pop('imageList')
with self.queue_lock:
- result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
+ result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index 6a0c4ba8..ff279a1f 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -17,5 +17,5 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
hypernetwork.load_hypernetworks(names, multipliers)
- def deactivate(p, self):
+ def deactivate(self, p):
pass
diff --git a/modules/extras.py b/modules/extras.py
index 1218f88f..36123aa5 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -1,230 +1,16 @@
-from __future__ import annotations
-import math
import os
-import sys
-import traceback
+import re
import shutil
-import numpy as np
-from PIL import Image
import torch
import tqdm
-from typing import Callable, List, OrderedDict, Tuple
-from functools import partial
-from dataclasses import dataclass
-
-from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
-from modules.shared import opts
-import modules.gfpgan_model
-from modules.ui import plaintext_to_html
-import modules.codeformer_model
+from modules import shared, images, sd_models, sd_vae
+from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
-class LruCache(OrderedDict):
- @dataclass(frozen=True)
- class Key:
- image_hash: int
- info_hash: int
- args_hash: int
-
- @dataclass
- class Value:
- image: Image.Image
- info: str
-
- def __init__(self, max_size: int = 5, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._max_size = max_size
-
- def get(self, key: LruCache.Key) -> LruCache.Value:
- ret = super().get(key)
- if ret is not None:
- self.move_to_end(key) # Move to end of eviction list
- return ret
-
- def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
- self[key] = value
- while len(self) > self._max_size:
- self.popitem(last=False)
-
-
-cached_images: LruCache = LruCache(max_size=5)
-
-
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
- devices.torch_gc()
-
- shared.state.begin()
- shared.state.job = 'extras'
-
- imageArr = []
- # Also keep track of original file names
- imageNameArr = []
- outputs = []
-
- if extras_mode == 1:
- #convert file to pillow image
- for img in image_folder:
- image = Image.open(img)
- imageArr.append(image)
- imageNameArr.append(os.path.splitext(img.orig_name)[0])
- elif extras_mode == 2:
- assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
-
- if input_dir == '':
- return outputs, "Please select an input directory.", ''
- image_list = shared.listfiles(input_dir)
- for img in image_list:
- try:
- image = Image.open(img)
- except Exception:
- continue
- imageArr.append(image)
- imageNameArr.append(img)
- else:
- imageArr.append(image)
- imageNameArr.append(None)
-
- if extras_mode == 2 and output_dir != '':
- outpath = output_dir
- else:
- outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- # Extra operation definitions
-
- def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- shared.state.job = 'extras-gfpgan'
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
-
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
-
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- return (res, info)
-
- def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- shared.state.job = 'extras-codeformer'
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
-
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
-
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
- return (res, info)
-
- def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- shared.state.job = 'extras-upscale'
- upscaler = shared.sd_upscalers[scaler_index]
- res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
- res = cropped
- return res
-
- def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
- nonlocal upscaling_resize
- if resize_mode == 1:
- upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
- crop_info = " (crop)" if upscaling_crop else ""
- info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
- return (image, info)
-
- @dataclass
- class UpscaleParams:
- upscaler_idx: int
- blend_alpha: float
-
- def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- blended_result: Image.Image = None
- image_hash: str = hash(np.array(image.getdata()).tobytes())
- for upscaler in params:
- upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
- upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = LruCache.Key(image_hash=image_hash,
- info_hash=hash(info),
- args_hash=hash(upscale_args))
- cached_entry = cached_images.get(cache_key)
- if cached_entry is None:
- res = upscale(image, *upscale_args)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
- cached_images.put(cache_key, LruCache.Value(image=res, info=info))
- else:
- res, info = cached_entry.image, cached_entry.info
-
- if blended_result is None:
- blended_result = res
- else:
- blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
- return (blended_result, info)
-
- # Build a list of operations to run
- facefix_ops: List[Callable] = []
- facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
- facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
-
- upscale_ops: List[Callable] = []
- upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
-
- if upscaling_resize != 0:
- step_params: List[UpscaleParams] = []
- step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
-
- upscale_ops.append(partial(run_upscalers_blend, step_params))
-
- extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
-
- for image, image_name in zip(imageArr, imageNameArr):
- if image is None:
- return outputs, "Please select an input image.", ''
-
- shared.state.textinfo = f'Processing image {image_name}'
-
- existing_pnginfo = image.info or {}
-
- image = image.convert("RGB")
- info = ""
- # Run each operation on each image
- for op in extras_ops:
- image, info = op(image, info)
-
- if opts.use_original_name_batch and image_name is not None:
- basename = os.path.splitext(os.path.basename(image_name))[0]
- else:
- basename = ''
-
- if opts.enable_pnginfo: # append info before save
- image.info = existing_pnginfo
- image.info["extras"] = info
-
- if save_output:
- # Add upscaler name as a suffix.
- suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
- # Add second upscaler if applicable.
- if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
- suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
-
- images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
-
- if extras_mode != 2 or show_extras_results :
- outputs.append(image)
-
- devices.torch_gc()
-
- return outputs, plaintext_to_html(info), ''
-
-def clear_cache():
- cached_images.clear()
-
def run_pnginfo(image):
if image is None:
@@ -285,7 +71,7 @@ def to_half(tensor, enable):
return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -430,6 +216,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
for key in theta_0.keys():
theta_0[key] = to_half(theta_0[key], save_as_half)
+ if discard_weights:
+ regex = re.compile(discard_weights)
+ for key in list(theta_0):
+ if re.search(regex, key):
+ theta_0.pop(key, None)
+
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = filename_generator() if custom_name == '' else custom_name
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 80a47c79..503534e2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -715,6 +715,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
do_not_save_samples=True,
)
+ p.disable_extra_networks = True
+
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
new file mode 100644
index 00000000..8514fea7
--- /dev/null
+++ b/modules/postprocessing.py
@@ -0,0 +1,103 @@
+import os
+
+from PIL import Image
+
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules.shared import opts
+
+
+def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+ devices.torch_gc()
+
+ shared.state.begin()
+ shared.state.job = 'extras'
+
+ image_data = []
+ image_names = []
+ outputs = []
+
+ if extras_mode == 1:
+ for img in image_folder:
+ image = Image.open(img)
+ image_data.append(image)
+ image_names.append(os.path.splitext(img.orig_name)[0])
+ elif extras_mode == 2:
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
+ assert input_dir, 'input directory not selected'
+
+ image_list = shared.listfiles(input_dir)
+ for filename in image_list:
+ try:
+ image = Image.open(filename)
+ except Exception:
+ continue
+ image_data.append(image)
+ image_names.append(filename)
+ else:
+ assert image, 'image not selected'
+
+ image_data.append(image)
+ image_names.append(None)
+
+ if extras_mode == 2 and output_dir != '':
+ outpath = output_dir
+ else:
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
+
+ infotext = ''
+
+ for image, name in zip(image_data, image_names):
+ shared.state.textinfo = name
+
+ existing_pnginfo = image.info or {}
+
+ pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
+
+ scripts.scripts_postproc.run(pp, args)
+
+ if opts.use_original_name_batch and name is not None:
+ basename = os.path.splitext(os.path.basename(name))[0]
+ else:
+ basename = ''
+
+ infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+
+ if opts.enable_pnginfo:
+ pp.image.info = existing_pnginfo
+ pp.image.info["postprocessing"] = infotext
+
+ if save_output:
+ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+
+ if extras_mode != 2 or show_extras_results:
+ outputs.append(pp.image)
+
+ devices.torch_gc()
+
+ return outputs, ui_common.plaintext_to_html(infotext), ''
+
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
+ """old handler for API"""
+
+ args = scripts.scripts_postproc.create_args_for_run({
+ "Upscale": {
+ "upscale_mode": resize_mode,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ },
+ "GFPGAN": {
+ "gfpgan_visibility": gfpgan_visibility,
+ },
+ "CodeFormer": {
+ "codeformer_visibility": codeformer_visibility,
+ "codeformer_weight": codeformer_weight,
+ },
+ })
+
+ return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
diff --git a/modules/processing.py b/modules/processing.py
index b5deeacf..bc541e2f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -140,6 +140,7 @@ class StableDiffusionProcessing:
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
+ self.disable_extra_networks = False
if not seed_enable_extras:
self.subseed = -1
@@ -532,6 +533,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
+
if p.scripts is not None:
p.scripts.process(p)
@@ -561,13 +564,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
cache[0] = (required_prompts, steps)
return cache[1]
- p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts)
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
- extra_networks.activate(p, extra_network_data)
+ if not p.disable_extra_networks:
+ extra_networks.activate(p, extra_network_data)
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
@@ -593,6 +595,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
+ prompts, _ = extra_networks.parse_prompts(prompts)
+
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
@@ -682,7 +686,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks:
+ extra_networks.deactivate(p, extra_network_data)
+
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index a9e19236..4bb45ec7 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -73,6 +73,7 @@ callback_map = dict(
callbacks_image_grid=[],
callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
+ callbacks_before_ui=[],
)
@@ -189,6 +190,14 @@ def script_unloaded_callback():
report_exception(c, 'script_unloaded')
+def before_ui_callback():
+ for c in reversed(callback_map['callbacks_before_ui']):
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'before_ui')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -313,3 +322,9 @@ def on_script_unloaded(callback):
the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback)
+
+
+def on_before_ui(callback):
+ """register a function to be called before the UI is created."""
+
+ add_callback(callback_map['callbacks_before_ui'], callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index 4ffc369b..03907a63 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -7,7 +7,7 @@ from collections import namedtuple
import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions, script_loading
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
@@ -150,8 +150,10 @@ def basedir():
return current_basedir
-scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+
+scripts_data = []
+postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
@@ -190,23 +192,31 @@ def list_files_with_name(filename):
def load_scripts():
global current_basedir
scripts_data.clear()
+ postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
+ def register_scripts_from_module(module):
+ for key, script_class in module.__dict__.items():
+ if type(script_class) != type:
+ continue
+
+ if issubclass(script_class, Script):
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+
for scriptfile in sorted(scripts_list):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
- module = script_loading.load_module(scriptfile.path)
-
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ script_module = script_loading.load_module(scriptfile.path)
+ register_scripts_from_module(script_module)
except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -413,6 +423,7 @@ class ScriptRunner:
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
scripts_current: ScriptRunner = None
@@ -423,12 +434,13 @@ def reload_script_body_only():
def reload_scripts():
- global scripts_txt2img, scripts_img2img
+ global scripts_txt2img, scripts_img2img, scripts_postproc
load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def IOComponent_init(self, *args, **kwargs):
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
new file mode 100644
index 00000000..25de02d0
--- /dev/null
+++ b/modules/scripts_postprocessing.py
@@ -0,0 +1,147 @@
+import os
+import gradio as gr
+
+from modules import errors, shared
+
+
+class PostprocessedImage:
+ def __init__(self, image):
+ self.image = image
+ self.info = {}
+
+
+class ScriptPostprocessing:
+ filename = None
+ controls = None
+ args_from = None
+ args_to = None
+
+ order = 1000
+ """scripts will be ordred by this value in postprocessing UI"""
+
+ name = None
+ """this function should return the title of the script."""
+
+ group = None
+ """A gr.Group component that has all script's UI inside it"""
+
+ def ui(self):
+ """
+ This function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be a dictionary that maps parameter names to components used in processing.
+ Values of those components will be passed to process() function.
+ """
+
+ pass
+
+ def process(self, pp: PostprocessedImage, **args):
+ """
+ This function is called to postprocess the image.
+ args contains a dictionary with all values returned by components from ui()
+ """
+
+ pass
+
+ def image_changed(self):
+ pass
+
+
+def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
+ try:
+ res = func(*args, **kwargs)
+ return res
+ except Exception as e:
+ errors.display(e, f"calling {filename}/{funcname}")
+
+ return default
+
+
+class ScriptPostprocessingRunner:
+ def __init__(self):
+ self.scripts = None
+ self.ui_created = False
+
+ def initialize_scripts(self, scripts_data):
+ self.scripts = []
+
+ for script_class, path, basedir, script_module in scripts_data:
+ script: ScriptPostprocessing = script_class()
+ script.filename = path
+
+ self.scripts.append(script)
+
+ def create_script_ui(self, script, inputs):
+ script.args_from = len(inputs)
+ script.args_to = len(inputs)
+
+ script.controls = wrap_call(script.ui, script.filename, "ui")
+
+ for control in script.controls.values():
+ control.custom_script_source = os.path.basename(script.filename)
+
+ inputs += list(script.controls.values())
+ script.args_to = len(inputs)
+
+ def scripts_in_preferred_order(self):
+ if self.scripts is None:
+ import modules.scripts
+ self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
+
+ scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+
+ def script_score(name):
+ name = name.lower()
+ for i, possible_match in enumerate(scripts_order):
+ if possible_match in name:
+ return i
+
+ return len(self.scripts)
+
+ script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
+
+ return sorted(self.scripts, key=lambda x: script_scores[x.name])
+
+ def setup_ui(self):
+ inputs = []
+
+ for script in self.scripts_in_preferred_order():
+ with gr.Box() as group:
+ self.create_script_ui(script, inputs)
+
+ script.group = group
+
+ self.ui_created = True
+ return inputs
+
+ def run(self, pp: PostprocessedImage, args):
+ for script in self.scripts_in_preferred_order():
+ shared.state.job = script.name
+
+ script_args = args[script.args_from:script.args_to]
+
+ process_args = {}
+ for (name, component), value in zip(script.controls.items(), script_args):
+ process_args[name] = value
+
+ script.process(pp, **process_args)
+
+ def create_args_for_run(self, scripts_args):
+ if not self.ui_created:
+ with gr.Blocks(analytics_enabled=False):
+ self.setup_ui()
+
+ scripts = self.scripts_in_preferred_order()
+ args = [None] * max([x.args_to for x in scripts])
+
+ for script in scripts:
+ script_args_dict = scripts_args.get(script.name, None)
+ if script_args_dict is not None:
+
+ for i, name in enumerate(script.controls):
+ args[script.args_from + i] = script_args_dict.get(name, None)
+
+ return args
+
+ def image_changed(self):
+ for script in self.scripts_in_preferred_order():
+ script.image_changed()
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index c72d8efc..e90aa9fe 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -41,7 +41,9 @@ class DisableInitialization:
return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
- return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res.name_or_path = pretrained_model_name_or_path
+ return res
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
diff --git a/modules/shared.py b/modules/shared.py
index 23328adf..a644c0be 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -101,6 +101,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
@@ -397,7 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
@@ -405,7 +407,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -442,9 +445,12 @@ options_templates.update(options_section(('ui', "User interface"), {
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
- 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
- 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
- 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
+ "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
+ "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
options_templates.update(options_section(('ui', "Live previews"), {
@@ -469,6 +475,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
}))
+options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+ 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+ 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+}))
+
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
diff --git a/modules/ui.py b/modules/ui.py
index fbc3efa0..85ae62c7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,7 +5,6 @@ import mimetypes
import os
import platform
import random
-import subprocess as sp
import sys
import tempfile
import time
@@ -20,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -41,6 +40,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+import modules.extras
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
@@ -75,6 +75,7 @@ css_hide_progressbar = """
.wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
+.wrap.cover-bg .z-20::before { content:"" }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
@@ -85,7 +86,6 @@ css_hide_progressbar = """
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
paste_symbol = '\u2199\ufe0f' # ↙
-folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
@@ -94,78 +94,14 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text):
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
- return text
+ return ui_common.plaintext_to_html(text)
+
def send_gradio_gallery_to_image(x):
if len(x) == 0:
return None
return image_from_url_text(x[0])
-def save_files(js_data, images, do_make_zip, index):
- import csv
- filenames = []
- fullfns = []
-
- #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
- class MyObject:
- def __init__(self, d=None):
- if d is not None:
- for key, value in d.items():
- setattr(self, key, value)
-
- data = json.loads(js_data)
-
- p = MyObject(data)
- path = opts.outdir_save
- save_to_dirs = opts.use_save_to_dirs_for_ui
- extension: str = opts.samples_format
- start_index = 0
-
- if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
- images = [images[index]]
- start_index = index
-
- os.makedirs(opts.outdir_save, exist_ok=True)
-
- with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
- at_start = file.tell() == 0
- writer = csv.writer(file)
- if at_start:
- writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
-
- for image_index, filedata in enumerate(images, start_index):
- image = image_from_url_text(filedata)
-
- is_grid = image_index < p.index_of_first_image
- i = 0 if is_grid else (image_index - p.index_of_first_image)
-
- fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
-
- filename = os.path.relpath(fullfn, path)
- filenames.append(filename)
- fullfns.append(fullfn)
- if txt_fullfn:
- filenames.append(os.path.basename(txt_fullfn))
- fullfns.append(txt_fullfn)
-
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
-
- # Make Zip
- if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
-
- from zipfile import ZipFile
- with ZipFile(zip_filepath, "w") as zip_file:
- for i in range(len(fullfns)):
- with open(fullfns[i], mode="rb") as f:
- zip_file.writestr(filenames[i], f.read())
- fullfns.insert(0, zip_filepath)
-
- return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
def visit(x, func, path=""):
if hasattr(x, 'children'):
for c in x.children:
@@ -443,19 +379,6 @@ def apply_setting(key, value):
opts.save(shared.config_filename)
return getattr(opts, key)
-
-def update_generation_info(generation_info, html_info, img_index):
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info, gr.update()
- return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info, gr.update()
-
-
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@@ -476,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
- inputs=[generation_info, html_info, html_info],
- outputs=[html_info, html_info],
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ],
- show_progress=False,
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+ return ui_common.create_output_panel(tabname, outdir)
def create_sampler_and_steps_selection(choices, tabname):
@@ -918,7 +741,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
- with FormRow(elem_id="img2img_checkboxes"):
+ with FormRow(elem_id="img2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
@@ -1105,86 +928,7 @@ def create_ui():
modules.scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='compact'):
- with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', elem_id="extras_single_tab"):
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
-
- with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
-
- with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
- show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
-
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
- with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
- with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
- with gr.Group():
- with gr.Row():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
- with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
-
- with gr.Group():
- extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
-
- with gr.Group():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
-
- with gr.Group():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
-
- with gr.Group():
- upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
-
- result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
-
- submit.click(
- fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
- _js="get_extras_tab_index",
- inputs=[
- dummy_component,
- dummy_component,
- extras_image,
- image_batch,
- extras_batch_input_dir,
- extras_batch_output_dir,
- show_extras_results,
- gfpgan_visibility,
- codeformer_visibility,
- codeformer_weight,
- upscaling_resize,
- upscaling_resize_w,
- upscaling_resize_h,
- upscaling_crop,
- extras_upscaler_1,
- extras_upscaler_2,
- extras_upscaler_2_visibility,
- upscale_before_face_fix,
- ],
- outputs=[
- result_images,
- html_info_x,
- html_info,
- ]
- )
- parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
- extras_image.change(
- fn=modules.extras.clear_cache,
- inputs=[], outputs=[]
- )
+ ui_postprocessing.create_ui()
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
@@ -1205,10 +949,19 @@ def create_ui():
outputs=[html, generation_info, html2],
)
+ def update_interp_description(value):
+ interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+ interp_descriptions = {
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+ }
+ return interp_descriptions[value]
+
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact'):
- gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
+ interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
@@ -1223,6 +976,7 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
@@ -1237,6 +991,9 @@ def create_ui():
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+ with FormRow():
+ discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
@@ -1248,7 +1005,7 @@ def create_ui():
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
- with gr.Row().style(equal_height=False):
+ with gr.Row(variant="compact").style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
@@ -1827,6 +1584,7 @@ def create_ui():
checkpoint_format,
config_source,
bake_in_vae,
+ discard_weights,
],
outputs=[
primary_model_name,
@@ -1897,7 +1655,7 @@ def create_ui():
if type(x) == gr.Dropdown:
def check_dropdown(val):
- if x.multiselect:
+ if getattr(x, 'multiselect', False):
return all([value in x.choices for value in val])
else:
return val in x.choices
@@ -1914,28 +1672,27 @@ def create_ui():
with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
return demo
def reload_javascript():
- with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
- javascript = f'<script>{jsfile.read()}</script>'
-
- scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
- for basedir, filename, path in scripts_list:
- with open(path, "r", encoding="utf8") as jsfile:
- javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
+ head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n'
+ inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
- javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
+ inline += f"set_theme('{cmd_opts.theme}');"
+
+ for script in modules.scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
- javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+ head += f'<script type="text/javascript">{inline}</script>\n'
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(
- b'</head>', f'{javascript}</head>'.encode("utf8"))
+ res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
res.init_headers()
return res
diff --git a/modules/ui_common.py b/modules/ui_common.py
new file mode 100644
index 00000000..9405ac1f
--- /dev/null
+++ b/modules/ui_common.py
@@ -0,0 +1,202 @@
+import json
+import html
+import os
+import platform
+import sys
+
+import gradio as gr
+import subprocess as sp
+
+from modules import call_queue, shared
+from modules.generation_parameters_copypaste import image_from_url_text
+import modules.images
+
+folder_symbol = '\U0001f4c2' # 📂
+
+
+def update_generation_info(generation_info, html_info, img_index):
+ try:
+ generation_info = json.loads(generation_info)
+ if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
+ except Exception:
+ pass
+ # if the json parse or anything else fails, just return the old html_info
+ return html_info, gr.update()
+
+
+def plaintext_to_html(text):
+ text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
+ return text
+
+
+def save_files(js_data, images, do_make_zip, index):
+ import csv
+ filenames = []
+ fullfns = []
+
+ #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+ class MyObject:
+ def __init__(self, d=None):
+ if d is not None:
+ for key, value in d.items():
+ setattr(self, key, value)
+
+ data = json.loads(js_data)
+
+ p = MyObject(data)
+ path = shared.opts.outdir_save
+ save_to_dirs = shared.opts.use_save_to_dirs_for_ui
+ extension: str = shared.opts.samples_format
+ start_index = 0
+
+ if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
+ images = [images[index]]
+ start_index = index
+
+ os.makedirs(shared.opts.outdir_save, exist_ok=True)
+
+ with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
+ at_start = file.tell() == 0
+ writer = csv.writer(file)
+ if at_start:
+ writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
+
+ for image_index, filedata in enumerate(images, start_index):
+ image = image_from_url_text(filedata)
+
+ is_grid = image_index < p.index_of_first_image
+ i = 0 if is_grid else (image_index - p.index_of_first_image)
+
+ fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+
+ filename = os.path.relpath(fullfn, path)
+ filenames.append(filename)
+ fullfns.append(fullfn)
+ if txt_fullfn:
+ filenames.append(os.path.basename(txt_fullfn))
+ fullfns.append(txt_fullfn)
+
+ writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+
+ # Make Zip
+ if do_make_zip:
+ zip_filepath = os.path.join(path, "images.zip")
+
+ from zipfile import ZipFile
+ with ZipFile(zip_filepath, "w") as zip_file:
+ for i in range(len(fullfns)):
+ with open(fullfns[i], mode="rb") as f:
+ zip_file.writestr(filenames[i], f.read())
+ fullfns.insert(0, zip_filepath)
+
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
+
+
+def create_output_panel(tabname, outdir):
+ from modules import shared
+ import modules.generation_parameters_copypaste as parameters_copypaste
+
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+ open_folder_button.click(
+ fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
+ )
+
+ save.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ save_zip.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 46324425..9aec3097 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -47,3 +47,4 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
def get_block_name(self):
return "colorpicker"
+
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 253e90f7..af2b8071 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -10,7 +10,7 @@ extra_pages = []
def register_page(page):
- """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions"""
+ """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
extra_pages.append(page)
@@ -18,6 +18,7 @@ def register_page(page):
class ExtraNetworksPage:
def __init__(self, title):
self.title = title
+ self.name = title.lower()
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
@@ -34,7 +35,11 @@ class ExtraNetworksPage:
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
- res = "<div class='extra-network-cards'>" + items_html + "</div>"
+ res = f"""
+<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
+{items_html}
+</div>
+"""
return res
@@ -49,7 +54,7 @@ class ExtraNetworksPage:
args = {
"preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
- "prompt": json.dumps(item["prompt"]),
+ "prompt": item["prompt"],
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
@@ -74,21 +79,38 @@ class ExtraNetworksUi:
self.tabname = None
+def pages_in_preferred_order(pages):
+ tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
+
+ def tab_name_score(name):
+ name = name.lower()
+ for i, possible_match in enumerate(tab_order):
+ if possible_match in name:
+ return i
+
+ return len(pages)
+
+ tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}
+
+ return sorted(pages, key=lambda x: tab_scores[x.name])
+
+
def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
- ui.stored_extra_pages = extra_pages.copy()
+ ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
- button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
- button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
-
for page in ui.stored_extra_pages:
with gr.Tab(page.title):
page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem)
+ filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+ button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 312dbaf0..65d000cf 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -1,3 +1,4 @@
+import json
import os
from modules import shared, ui_extra_networks
@@ -25,7 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
"name": name,
"filename": path,
"preview": preview,
- "prompt": f"<hypernet:{name}:1.0>",
+ "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index e4a6e3bf..dbd23d2d 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -1,3 +1,4 @@
+import json
import os
from modules import ui_extra_networks, sd_hijack
@@ -24,7 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
"name": embedding.name,
"filename": embedding.filename,
"preview": preview,
- "prompt": embedding.name,
+ "prompt": json.dumps(embedding.name),
"local_preview": path + ".preview.png",
}
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
new file mode 100644
index 00000000..b418d955
--- /dev/null
+++ b/modules/ui_postprocessing.py
@@ -0,0 +1,57 @@
+import gradio as gr
+from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+
+
+def create_ui():
+ tab_index = gr.State(value=0)
+
+ with gr.Row().style(equal_height=False, variant='compact'):
+ with gr.Column(variant='compact'):
+ with gr.Tabs(elem_id="mode_extras"):
+ with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+ submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
+ script_inputs = scripts.scripts_postproc.setup_ui()
+
+ with gr.Column():
+ result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+
+ tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+ tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+ tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+
+ submit.click(
+ fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+ inputs=[
+ tab_index,
+ extras_image,
+ image_batch,
+ extras_batch_input_dir,
+ extras_batch_output_dir,
+ show_extras_results,
+ *script_inputs
+ ],
+ outputs=[
+ result_images,
+ html_info_x,
+ html_info,
+ ]
+ )
+
+ parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+ extras_image.change(
+ fn=scripts.scripts_postproc.image_changed,
+ inputs=[], outputs=[]
+ )
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
new file mode 100644
index 00000000..a7d80d40
--- /dev/null
+++ b/scripts/postprocessing_codeformer.py
@@ -0,0 +1,36 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, codeformer_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
+ name = "CodeFormer"
+ order = 3000
+
+ def ui(self):
+ with FormRow():
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+
+ return {
+ "codeformer_visibility": codeformer_visibility,
+ "codeformer_weight": codeformer_weight,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
+ if codeformer_visibility == 0:
+ return
+
+ restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
+
+ if codeformer_visibility < 1.0:
+ res = Image.blend(pp.image, res, codeformer_visibility)
+
+ pp.image = res
+ pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
+ pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
new file mode 100644
index 00000000..d854f3f7
--- /dev/null
+++ b/scripts/postprocessing_gfpgan.py
@@ -0,0 +1,33 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, gfpgan_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
+ name = "GFPGAN"
+ order = 2000
+
+ def ui(self):
+ with FormRow():
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+
+ return {
+ "gfpgan_visibility": gfpgan_visibility,
+ }
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
+ if gfpgan_visibility == 0:
+ return
+
+ restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
+
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(pp.image, res, gfpgan_visibility)
+
+ pp.image = res
+ pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
new file mode 100644
index 00000000..095d29b2
--- /dev/null
+++ b/scripts/postprocessing_upscale.py
@@ -0,0 +1,106 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, shared
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+upscale_cache = {}
+
+
+class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
+ name = "Upscale"
+ order = 1000
+
+ def ui(self):
+ selected_tab = gr.State(value=0)
+
+ with gr.Tabs(elem_id="extras_resize_mode"):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
+ with FormRow():
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+
+ with FormRow():
+ extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+ with FormRow():
+ extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+
+ tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
+ tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
+
+ return {
+ "upscale_mode": selected_tab,
+ "upscale_by": upscaling_resize,
+ "upscale_to_width": upscaling_resize_w,
+ "upscale_to_height": upscaling_resize_h,
+ "upscale_crop": upscaling_crop,
+ "upscaler_1_name": extras_upscaler_1,
+ "upscaler_2_name": extras_upscaler_2,
+ "upscaler_2_visibility": extras_upscaler_2_visibility,
+ }
+
+ def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
+ if upscale_mode == 1:
+ upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
+ info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
+ else:
+ info["Postprocess upscale by"] = upscale_by
+
+ cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ cached_image = upscale_cache.pop(cache_key, None)
+
+ if cached_image is not None:
+ image = cached_image
+ else:
+ image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
+
+ upscale_cache[cache_key] = image
+ if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:
+ upscale_cache.pop(next(iter(upscale_cache), None), None)
+
+ if upscale_mode == 1 and upscale_crop:
+ cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
+ cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
+ image = cropped
+ info["Postprocess crop to"] = f"{image.width}x{image.height}"
+
+ return image
+
+ def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+ if upscaler_1_name == "None":
+ upscaler_1_name = None
+
+ upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)
+ assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'
+
+ if not upscaler1:
+ return
+
+ if upscaler_2_name == "None":
+ upscaler_2_name = None
+
+ upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
+ assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
+
+ upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ pp.info[f"Postprocess upscaler"] = upscaler1.name
+
+ if upscaler2 and upscaler_2_visibility > 0:
+ second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+ upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
+
+ pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+
+ pp.image = upscaled_image
+
+ def image_changed(self):
+ upscale_cache.clear()
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index b1badec9..1a452355 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -165,12 +165,16 @@ class AxisOption:
self.confirm = confirm
self.cost = cost
self.choices = choices
- self.is_img2img = False
class AxisOptionImg2Img(AxisOption):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ self.is_img2img = True
+
+class AxisOptionTxt2Img(AxisOption):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
self.is_img2img = False
@@ -180,10 +184,12 @@ axis_options = [
AxisOption("Var. seed", int, apply_field("subseed")),
AxisOption("Var. strength", float, apply_field("subseed_strength")),
AxisOption("Steps", int, apply_field("steps")),
+ AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
AxisOption("CFG Scale", float, apply_field("cfg_scale")),
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
- AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+ AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
AxisOption("Sigma Churn", float, apply_field("s_churn")),
AxisOption("Sigma min", float, apply_field("s_tmin")),
@@ -192,8 +198,8 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
- AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]),
- AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
+ AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
+ AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
]
@@ -288,42 +294,41 @@ class Script(scripts.Script):
return "X/Y plot"
def ui(self, is_img2img):
- current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img]
+ self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]
with gr.Row():
with gr.Column(scale=19):
with gr.Row():
- x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
+ x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)
with gr.Row():
- y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
+ y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)
- with gr.Row(variant="compact"):
+ with gr.Row(variant="compact", elem_id="axis_options"):
draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")
def swap_axes(x_type, x_values, y_type, y_values):
- nonlocal current_axis_options
- return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values
+ return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values
swap_args = [x_type, x_values, y_type, y_values]
swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
def fill(x_type):
- axis = axis_options[x_type]
+ axis = self.current_axis_options[x_type]
return ", ".join(axis.choices()) if axis.choices else gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
def select_axis(x_type):
- return gr.Button.update(visible=axis_options[x_type].choices is not None)
+ return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
@@ -398,10 +403,10 @@ class Script(scripts.Script):
return valslist
- x_opt = axis_options[x_type]
+ x_opt = self.current_axis_options[x_type]
xs = process_axis(x_opt, x_values)
- y_opt = axis_options[y_type]
+ y_opt = self.current_axis_options[y_type]
ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list):
@@ -422,10 +427,21 @@ class Script(scripts.Script):
total_steps = p.steps * len(xs) * len(ys)
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
- total_steps *= 2
+ if x_opt.label == "Hires steps":
+ total_steps += sum(xs) * len(ys)
+ elif y_opt.label == "Hires steps":
+ total_steps += sum(ys) * len(xs)
+ elif p.hr_second_pass_steps:
+ total_steps += p.hr_second_pass_steps * len(xs) * len(ys)
+ else:
+ total_steps *= 2
+
+ total_steps *= p.n_iter
- print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
- shared.total_tqdm.updateTotal(total_steps * p.n_iter)
+ image_cell_count = p.n_iter * p.batch_size
+ cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
+ print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})")
+ shared.total_tqdm.updateTotal(total_steps)
grid_infotext = [None]
diff --git a/style.css b/style.css
index 04bf2982..bf8260d7 100644
--- a/style.css
+++ b/style.css
@@ -589,7 +589,7 @@ canvas[key="mask"] {
/* Extensions */
-#tab_extensions table``{
+#tab_extensions table{
border-collapse: collapse;
}
@@ -707,12 +707,24 @@ footer {
#txt2img_checkboxes, #img2img_checkboxes{
margin-bottom: 0.5em;
+ margin-left: 0em;
}
#txt2img_checkboxes > div, #img2img_checkboxes > div{
flex: 0;
white-space: nowrap;
min-width: auto;
}
+#txt2img_hires_fix{
+ margin-left: -0.8em;
+}
+
+#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
+ margin-left: 0em;
+}
+
+#axis_options {
+ margin-left: 0em;
+}
.inactive{
opacity: 0.5;
@@ -774,6 +786,14 @@ footer {
margin: 0.3em;
}
+
+
+#txt2img_extra_networks .search, #img2img_extra_networks .search{
+ display: inline-block;
+ max-width: 16em;
+ margin: 0.3em;
+}
+
.extra-network-cards .nocards{
margin: 1.25em 0.5em 0.5em 0.5em;
}
diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py
index 94e00253..0bfc28a0 100644
--- a/test/basic_features/utils_test.py
+++ b/test/basic_features/utils_test.py
@@ -12,8 +12,6 @@ class UtilsTests(unittest.TestCase):
self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
- self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
- self.url_artists = "http://localhost:7860/sdapi/v1/artists"
self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
def test_options_get(self):
@@ -56,15 +54,9 @@ class UtilsTests(unittest.TestCase):
def test_prompt_styles(self):
self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
-
- def test_artist_categories(self):
- self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
-
- def test_artists(self):
- self.assertEqual(requests.get(self.url_artists).status_code, 200)
def test_embeddings(self):
- self.assertEqual(requests.get(self.url_artists).status_code, 200)
+ self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
if __name__ == "__main__":
unittest.main()
diff --git a/webui.py b/webui.py
index e8dd822a..7cf5885e 100644
--- a/webui.py
+++ b/webui.py
@@ -22,7 +22,6 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
import modules.codeformer_model as codeformer
-import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
@@ -165,8 +164,13 @@ def webui():
if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr()
+ modules.script_callbacks.before_ui_callback()
+
shared.demo = modules.ui.create_ui()
+ if cmd_opts.gradio_queue:
+ shared.demo.queue(64)
+
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
diff --git a/webui.sh b/webui.sh
index 1edf921d..8cdad22d 100755
--- a/webui.sh
+++ b/webui.sh
@@ -104,6 +104,23 @@ then
fi
# Check prerequisites
+gpu_info=$(lspci 2>/dev/null | grep VGA)
+case "$gpu_info" in
+ *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ ;;
+ *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
+ printf "\n%s\n" "${delimiter}"
+ printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
+ printf "\n%s\n" "${delimiter}"
+ ;;
+ *)
+ ;;
+esac
+if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+then
+ export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+fi
+
for preq in "${GIT}" "${python_cmd}"
do
if ! hash "${preq}" &>/dev/null
@@ -164,16 +181,6 @@ then
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
- printf "\n%s\n" "${delimiter}"
- gpu_info=$(lspci 2>/dev/null | grep VGA)
- if echo "$gpu_info" | grep -q "AMD"
- then
- if [[ -z "${TORCH_COMMAND}" ]]
- then
- export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
- fi
- HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
- else
- exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
- fi
+ printf "\n%s\n" "${delimiter}"
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi