aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.eslintrc.js7
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.yml4
-rw-r--r--CHANGELOG.md2
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py7
-rw-r--r--extensions-builtin/LDSR/sd_hijack_autoencoder.py2
-rw-r--r--extensions-builtin/LDSR/vqvae_quantize.py147
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py6
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js431
-rw-r--r--javascript/aspectRatioOverlay.js2
-rw-r--r--javascript/contextMenus.js4
-rw-r--r--javascript/generationParams.js2
-rw-r--r--javascript/hints.js81
-rw-r--r--javascript/imageMaskFix.js2
-rw-r--r--javascript/imageviewer.js2
-rw-r--r--javascript/imageviewerGamepad.js8
-rw-r--r--javascript/notification.js2
-rw-r--r--javascript/ui.js2
-rw-r--r--modules/api/api.py14
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/call_queue.py22
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/codeformer_model.py10
-rw-r--r--modules/config_states.py12
-rw-r--r--modules/errors.py24
-rw-r--r--modules/extensions.py19
-rw-r--r--modules/generation_parameters_copypaste.py18
-rw-r--r--modules/gfpgan_model.py6
-rw-r--r--modules/gitpython_hack.py42
-rw-r--r--modules/hypernetworks/hypernetwork.py14
-rw-r--r--modules/images.py49
-rw-r--r--modules/img2img.py3
-rw-r--r--modules/interrogate.py5
-rw-r--r--modules/launch_utils.py12
-rw-r--r--modules/localization.py6
-rw-r--r--modules/paths.py1
-rw-r--r--modules/processing.py11
-rw-r--r--modules/realesrgan_model.py14
-rw-r--r--modules/safe.py26
-rw-r--r--modules/script_callbacks.py29
-rw-r--r--modules/script_loading.py7
-rw-r--r--modules/scripts.py35
-rw-r--r--modules/sd_hijack.py35
-rw-r--r--modules/sd_hijack_optimizations.py6
-rw-r--r--modules/sd_models.py19
-rw-r--r--modules/sd_samplers_kdiffusion.py40
-rw-r--r--modules/sd_unet.py92
-rw-r--r--modules/shared.py10
-rw-r--r--modules/shared_items.py11
-rw-r--r--modules/textual_inversion/textual_inversion.py34
-rw-r--r--modules/ui.py14
-rw-r--r--modules/ui_common.py9
-rw-r--r--modules/ui_extensions.py15
-rw-r--r--modules/ui_tempdir.py15
-rw-r--r--modules/upscaler.py4
-rw-r--r--requirements.txt42
-rw-r--r--requirements_versions.txt39
-rw-r--r--script.js63
-rw-r--r--scripts/prompts_from_file.py6
-rw-r--r--scripts/xyz_grid.py6
-rw-r--r--style.css18
-rw-r--r--webui-user.sh1
-rw-r--r--webui.py38
-rwxr-xr-xwebui.sh9
63 files changed, 1289 insertions, 333 deletions
diff --git a/.eslintrc.js b/.eslintrc.js
index 218f5609..f33aca09 100644
--- a/.eslintrc.js
+++ b/.eslintrc.js
@@ -50,13 +50,14 @@ module.exports = {
globals: {
//script.js
gradioApp: "readonly",
+ executeCallbacks: "readonly",
+ onAfterUiUpdate: "readonly",
+ onOptionsChanged: "readonly",
onUiLoaded: "readonly",
onUiUpdate: "readonly",
- onOptionsChanged: "readonly",
uiCurrentTab: "writable",
- uiElementIsVisible: "readonly",
uiElementInSight: "readonly",
- executeCallbacks: "readonly",
+ uiElementIsVisible: "readonly",
//ui.js
opts: "writable",
all_gallery_buttons: "readonly",
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index 3a8b9953..9cc16d01 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -43,8 +43,8 @@ body:
- type: input
id: commit
attributes:
- label: Commit where the problem happens
- description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
+ label: Version or Commit where the problem happens
+ description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
validations:
required: true
- type: dropdown
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e46d707a..6c6ab5e3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,4 +1,4 @@
-## Upcoming 1.3.0
+## 1.3.0
### Features:
* add UI to edit defaults
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index c4da79f3..95f1669d 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,9 +1,8 @@
import os
-import sys
-import traceback
from basicsr.utils.download_util import load_file_from_url
+from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks
@@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler):
try:
return LDSR(model, yaml)
-
except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path):
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index 81c5101b..27a86e13 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -10,7 +10,7 @@ from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
diff --git a/extensions-builtin/LDSR/vqvae_quantize.py b/extensions-builtin/LDSR/vqvae_quantize.py
new file mode 100644
index 00000000..dd14b8fd
--- /dev/null
+++ b/extensions-builtin/LDSR/vqvae_quantize.py
@@ -0,0 +1,147 @@
+# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
+# where the license is as follows:
+#
+# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+# OR OTHER DEALINGS IN THE SOFTWARE./
+
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
+ assert return_logits is False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 45d9297b..dd1b822e 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,6 +1,5 @@
import os.path
import sys
-import traceback
import PIL.Image
import numpy as np
@@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net
+
+from modules.errors import print_error
from modules.shared import opts
@@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
new file mode 100644
index 00000000..f555960d
--- /dev/null
+++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
@@ -0,0 +1,431 @@
+// Main
+
+// Helper functions
+// Get active tab
+function getActiveTab(elements, all = false) {
+ const tabs = elements.img2imgTabs.querySelectorAll("button");
+
+ if (all) return tabs;
+
+ for (let tab of tabs) {
+ if (tab.classList.contains("selected")) {
+ return tab;
+ }
+ }
+}
+
+onUiLoaded(async() => {
+ const hotkeysConfig = {
+ resetZoom: "KeyR",
+ fitToScreen: "KeyS",
+ moveKey: "KeyF",
+ overlap: "KeyO"
+ };
+
+ let isMoving = false;
+ let mouseX, mouseY;
+
+ const elementIDs = {
+ sketch: "#img2img_sketch",
+ inpaint: "#img2maskimg",
+ inpaintSketch: "#inpaint_sketch",
+ img2imgTabs: "#mode_img2img .tab-nav"
+ };
+
+ async function getElements() {
+ const elements = await Promise.all(
+ Object.values(elementIDs).map(id => document.querySelector(id))
+ );
+ return Object.fromEntries(
+ Object.keys(elementIDs).map((key, index) => [key, elements[index]])
+ );
+ }
+
+ const elements = await getElements();
+
+ function applyZoomAndPan(targetElement, elemId) {
+ targetElement.style.transformOrigin = "0 0";
+ let [zoomLevel, panX, panY] = [1, 0, 0];
+ let fullScreenMode = false;
+
+ // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui.
+ function fixCanvas() {
+ const activeTab = getActiveTab(elements).textContent.trim();
+
+ if (activeTab !== "img2img") {
+ const img = targetElement.querySelector(`${elemId} img`);
+
+ if (img && img.style.display !== "none") {
+ img.style.display = "none";
+ img.style.visibility = "hidden";
+ }
+ }
+ }
+
+ // Reset the zoom level and pan position of the target element to their initial values
+ function resetZoom() {
+ zoomLevel = 1;
+ panX = 0;
+ panY = 0;
+
+ fixCanvas();
+ targetElement.style.transform = `scale(${zoomLevel}) translate(${panX}px, ${panY}px)`;
+
+ const canvas = gradioApp().querySelector(
+ `${elemId} canvas[key="interface"]`
+ );
+
+ toggleOverlap("off");
+ fullScreenMode = false;
+
+ if (
+ canvas &&
+ parseFloat(canvas.style.width) > 865 &&
+ parseFloat(targetElement.style.width) > 865
+ ) {
+ fitToElement();
+ return;
+ }
+
+ targetElement.style.width = "";
+ if (canvas) {
+ targetElement.style.height = canvas.style.height;
+ }
+ }
+
+ // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
+ function toggleOverlap(forced = "") {
+ const zIndex1 = "0";
+ const zIndex2 = "998";
+
+ targetElement.style.zIndex =
+ targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
+
+ if (forced === "off") {
+ targetElement.style.zIndex = zIndex1;
+ } else if (forced === "on") {
+ targetElement.style.zIndex = zIndex2;
+ }
+ }
+
+ // Adjust the brush size based on the deltaY value from a mouse wheel event
+ function adjustBrushSize(
+ elemId,
+ deltaY,
+ withoutValue = false,
+ percentage = 5
+ ) {
+ const input =
+ gradioApp().querySelector(
+ `${elemId} input[aria-label='Brush radius']`
+ ) ||
+ gradioApp().querySelector(
+ `${elemId} button[aria-label="Use brush"]`
+ );
+
+ if (input) {
+ input.click();
+ if (!withoutValue) {
+ const maxValue =
+ parseFloat(input.getAttribute("max")) || 100;
+ const changeAmount = maxValue * (percentage / 100);
+ const newValue =
+ parseFloat(input.value) +
+ (deltaY > 0 ? -changeAmount : changeAmount);
+ input.value = Math.min(Math.max(newValue, 0), maxValue);
+ input.dispatchEvent(new Event("change"));
+ }
+ }
+ }
+
+ // Reset zoom when uploading a new image
+ const fileInput = gradioApp().querySelector(
+ `${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
+ );
+ fileInput.addEventListener("click", resetZoom);
+
+ // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
+ function updateZoom(newZoomLevel, mouseX, mouseY) {
+ newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
+ panX += mouseX - (mouseX * newZoomLevel) / zoomLevel;
+ panY += mouseY - (mouseY * newZoomLevel) / zoomLevel;
+
+ targetElement.style.transformOrigin = "0 0";
+ targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${newZoomLevel})`;
+
+ toggleOverlap("on");
+ return newZoomLevel;
+ }
+
+ // Change the zoom level based on user interaction
+ function changeZoomLevel(operation, e) {
+ if (e.shiftKey) {
+ e.preventDefault();
+
+ let zoomPosX, zoomPosY;
+ let delta = 0.2;
+ if (zoomLevel > 7) {
+ delta = 0.9;
+ } else if (zoomLevel > 2) {
+ delta = 0.6;
+ }
+
+ zoomPosX = e.clientX;
+ zoomPosY = e.clientY;
+
+ fullScreenMode = false;
+ zoomLevel = updateZoom(
+ zoomLevel + (operation === "+" ? delta : -delta),
+ zoomPosX - targetElement.getBoundingClientRect().left,
+ zoomPosY - targetElement.getBoundingClientRect().top
+ );
+ }
+ }
+
+ /**
+ * This function fits the target element to the screen by calculating
+ * the required scale and offsets. It also updates the global variables
+ * zoomLevel, panX, and panY to reflect the new state.
+ */
+
+ function fitToElement() {
+ //Reset Zoom
+ targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
+
+ // Get element and screen dimensions
+ const elementWidth = targetElement.offsetWidth;
+ const elementHeight = targetElement.offsetHeight;
+ const parentElement = targetElement.parentElement;
+ const screenWidth = parentElement.clientWidth;
+ const screenHeight = parentElement.clientHeight;
+
+ // Get element's coordinates relative to the parent element
+ const elementRect = targetElement.getBoundingClientRect();
+ const parentRect = parentElement.getBoundingClientRect();
+ const elementX = elementRect.x - parentRect.x;
+
+ // Calculate scale and offsets
+ const scaleX = screenWidth / elementWidth;
+ const scaleY = screenHeight / elementHeight;
+ const scale = Math.min(scaleX, scaleY);
+
+ const transformOrigin =
+ window.getComputedStyle(targetElement).transformOrigin;
+ const [originX, originY] = transformOrigin.split(" ");
+ const originXValue = parseFloat(originX);
+ const originYValue = parseFloat(originY);
+
+ const offsetX =
+ (screenWidth - elementWidth * scale) / 2 -
+ originXValue * (1 - scale);
+ const offsetY =
+ (screenHeight - elementHeight * scale) / 2.5 -
+ originYValue * (1 - scale);
+
+ // Apply scale and offsets to the element
+ targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
+
+ // Update global variables
+ zoomLevel = scale;
+ panX = offsetX;
+ panY = offsetY;
+
+ fullScreenMode = false;
+ toggleOverlap("off");
+ }
+
+ /**
+ * This function fits the target element to the screen by calculating
+ * the required scale and offsets. It also updates the global variables
+ * zoomLevel, panX, and panY to reflect the new state.
+ */
+
+ // Fullscreen mode
+ function fitToScreen() {
+ const canvas = gradioApp().querySelector(
+ `${elemId} canvas[key="interface"]`
+ );
+
+ if (!canvas) return;
+
+ if (canvas.offsetWidth > 862) {
+ targetElement.style.width = canvas.offsetWidth + "px";
+ }
+
+ if (fullScreenMode) {
+ resetZoom();
+ fullScreenMode = false;
+ return;
+ }
+
+ //Reset Zoom
+ targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
+
+ // Get scrollbar width to right-align the image
+ const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth;
+
+ // Get element and screen dimensions
+ const elementWidth = targetElement.offsetWidth;
+ const elementHeight = targetElement.offsetHeight;
+ const screenWidth = window.innerWidth - scrollbarWidth;
+ const screenHeight = window.innerHeight;
+
+ // Get element's coordinates relative to the page
+ const elementRect = targetElement.getBoundingClientRect();
+ const elementY = elementRect.y;
+ const elementX = elementRect.x;
+
+ // Calculate scale and offsets
+ const scaleX = screenWidth / elementWidth;
+ const scaleY = screenHeight / elementHeight;
+ const scale = Math.min(scaleX, scaleY);
+
+ // Get the current transformOrigin
+ const computedStyle = window.getComputedStyle(targetElement);
+ const transformOrigin = computedStyle.transformOrigin;
+ const [originX, originY] = transformOrigin.split(" ");
+ const originXValue = parseFloat(originX);
+ const originYValue = parseFloat(originY);
+
+ // Calculate offsets with respect to the transformOrigin
+ const offsetX =
+ (screenWidth - elementWidth * scale) / 2 -
+ elementX -
+ originXValue * (1 - scale);
+ const offsetY =
+ (screenHeight - elementHeight * scale) / 2 -
+ elementY -
+ originYValue * (1 - scale);
+
+ // Apply scale and offsets to the element
+ targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
+
+ // Update global variables
+ zoomLevel = scale;
+ panX = offsetX;
+ panY = offsetY;
+
+ fullScreenMode = true;
+ toggleOverlap("on");
+ }
+
+ // Handle keydown events
+ function handleKeyDown(event) {
+ const hotkeyActions = {
+ [hotkeysConfig.resetZoom]: resetZoom,
+ [hotkeysConfig.overlap]: toggleOverlap,
+ [hotkeysConfig.fitToScreen]: fitToScreen
+ // [hotkeysConfig.moveKey] : moveCanvas,
+ };
+
+ const action = hotkeyActions[event.code];
+ if (action) {
+ event.preventDefault();
+ action(event);
+ }
+ }
+
+ // Get Mouse position
+ function getMousePosition(e) {
+ mouseX = e.offsetX;
+ mouseY = e.offsetY;
+ }
+
+ targetElement.addEventListener("mousemove", getMousePosition);
+
+ // Handle events only inside the targetElement
+ let isKeyDownHandlerAttached = false;
+
+ function handleMouseMove() {
+ if (!isKeyDownHandlerAttached) {
+ document.addEventListener("keydown", handleKeyDown);
+ isKeyDownHandlerAttached = true;
+ }
+ }
+
+ function handleMouseLeave() {
+ if (isKeyDownHandlerAttached) {
+ document.removeEventListener("keydown", handleKeyDown);
+ isKeyDownHandlerAttached = false;
+ }
+ }
+
+ // Add mouse event handlers
+ targetElement.addEventListener("mousemove", handleMouseMove);
+ targetElement.addEventListener("mouseleave", handleMouseLeave);
+
+ // Reset zoom when click on another tab
+ elements.img2imgTabs.addEventListener("click", resetZoom);
+ elements.img2imgTabs.addEventListener("click", () => {
+ // targetElement.style.width = "";
+ if (parseInt(targetElement.style.width) > 865) {
+ setTimeout(fitToElement, 0);
+ }
+ });
+
+ targetElement.addEventListener("wheel", e => {
+ // change zoom level
+ const operation = e.deltaY > 0 ? "-" : "+";
+ changeZoomLevel(operation, e);
+
+ // Handle brush size adjustment with ctrl key pressed
+ if (e.ctrlKey || e.metaKey) {
+ e.preventDefault();
+
+ // Increase or decrease brush size based on scroll direction
+ adjustBrushSize(elemId, e.deltaY);
+ }
+ });
+
+ /**
+ * Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
+ * @param {MouseEvent} e - The mouse event.
+ */
+ function handleMoveKeyDown(e) {
+ if (e.code === hotkeysConfig.moveKey) {
+ if (!e.ctrlKey && !e.metaKey) {
+ isMoving = true;
+ }
+ }
+ }
+
+ function handleMoveKeyUp(e) {
+ if (e.code === hotkeysConfig.moveKey) {
+ isMoving = false;
+ }
+ }
+
+ document.addEventListener("keydown", handleMoveKeyDown);
+ document.addEventListener("keyup", handleMoveKeyUp);
+
+ // Detect zoom level and update the pan speed.
+ function updatePanPosition(movementX, movementY) {
+ let panSpeed = 1.5;
+
+ if (zoomLevel > 8) {
+ panSpeed = 2.5;
+ }
+
+ panX = panX + movementX * panSpeed;
+ panY = panY + movementY * panSpeed;
+
+ targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${zoomLevel})`;
+ toggleOverlap("on");
+ }
+
+ function handleMoveByKey(e) {
+ if (isMoving) {
+ updatePanPosition(e.movementX, e.movementY);
+ targetElement.style.pointerEvents = "none";
+ } else {
+ targetElement.style.pointerEvents = "auto";
+ }
+ }
+
+ gradioApp().addEventListener("mousemove", handleMoveByKey);
+ }
+
+ applyZoomAndPan(elements.sketch, elementIDs.sketch);
+ applyZoomAndPan(elements.inpaint, elementIDs.inpaint);
+ applyZoomAndPan(elements.inpaintSketch, elementIDs.inpaintSketch);
+});
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
index 1c08a1a9..2cf2d571 100644
--- a/javascript/aspectRatioOverlay.js
+++ b/javascript/aspectRatioOverlay.js
@@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) {
}
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
if (arPreviewRect) {
arPreviewRect.style.display = 'none';
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index f14af1d4..d60a10c4 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -167,6 +167,4 @@ var addContextMenuEventListener = initResponse[2];
})();
//End example Context Menu Items
-onUiUpdate(function() {
- addContextMenuEventListener();
-});
+onAfterUiUpdate(addContextMenuEventListener);
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
index a877f8a5..7c0fd221 100644
--- a/javascript/generationParams.js
+++ b/javascript/generationParams.js
@@ -1,7 +1,7 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img");
}
diff --git a/javascript/hints.js b/javascript/hints.js
index 46f342cb..05ae5f22 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -116,17 +116,25 @@ var titles = {
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
};
-function updateTooltipForSpan(span) {
- if (span.title) return; // already has a title
+function updateTooltip(element) {
+ if (element.title) return; // already has a title
- let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
+ let text = element.textContent;
+ let tooltip = localization[titles[text]] || titles[text];
if (!tooltip) {
- tooltip = localization[titles[span.value]] || titles[span.value];
+ let value = element.value;
+ if (value) tooltip = localization[titles[value]] || titles[value];
}
if (!tooltip) {
- for (const c of span.classList) {
+ // Gradio dropdown options have `data-value`.
+ let dataValue = element.dataset.value;
+ if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
+ }
+
+ if (!tooltip) {
+ for (const c of element.classList) {
if (c in titles) {
tooltip = localization[titles[c]] || titles[c];
break;
@@ -135,34 +143,53 @@ function updateTooltipForSpan(span) {
}
if (tooltip) {
- span.title = tooltip;
+ element.title = tooltip;
}
}
-function updateTooltipForSelect(select) {
- if (select.onchange != null) return;
+// Nodes to check for adding tooltips.
+const tooltipCheckNodes = new Set();
+// Timer for debouncing tooltip check.
+let tooltipCheckTimer = null;
- select.onchange = function() {
- select.title = localization[titles[select.value]] || titles[select.value] || "";
- };
+function processTooltipCheckNodes() {
+ for (const node of tooltipCheckNodes) {
+ updateTooltip(node);
+ }
+ tooltipCheckNodes.clear();
}
-var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
-
-onUiUpdate(function(m) {
- m.forEach(function(record) {
- record.addedNodes.forEach(function(node) {
- if (observedTooltipElements[node.tagName]) {
- updateTooltipForSpan(node);
- }
- if (node.tagName == "SELECT") {
- updateTooltipForSelect(node);
+onUiUpdate(function(mutationRecords) {
+ for (const record of mutationRecords) {
+ if (record.type === "childList" && record.target.classList.contains("options")) {
+ // This smells like a Gradio dropdown menu having changed,
+ // so let's enqueue an update for the input element that shows the current value.
+ let wrap = record.target.parentNode;
+ let input = wrap?.querySelector("input");
+ if (input) {
+ input.title = ""; // So we'll even have a chance to update it.
+ tooltipCheckNodes.add(input);
}
-
- if (node.querySelectorAll) {
- node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
- node.querySelectorAll('select').forEach(updateTooltipForSelect);
+ }
+ for (const node of record.addedNodes) {
+ if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
+ if (!node.title) {
+ if (
+ node.tagName === "SPAN" ||
+ node.tagName === "BUTTON" ||
+ node.tagName === "P" ||
+ node.tagName === "INPUT" ||
+ (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
+ ) {
+ tooltipCheckNodes.add(node);
+ }
+ }
+ node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
}
- });
- });
+ }
+ }
+ if (tooltipCheckNodes.size) {
+ clearTimeout(tooltipCheckTimer);
+ tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
+ }
});
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js
index 3c9b8a6f..900c56f3 100644
--- a/javascript/imageMaskFix.js
+++ b/javascript/imageMaskFix.js
@@ -39,5 +39,5 @@ function imageMaskResize() {
});
}
-onUiUpdate(imageMaskResize);
+onAfterUiUpdate(imageMaskResize);
window.addEventListener('resize', imageMaskResize);
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 78e24eb9..677e95c1 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -170,7 +170,7 @@ function modalTileImageToggle(event) {
event.stopPropagation();
}
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox);
diff --git a/javascript/imageviewerGamepad.js b/javascript/imageviewerGamepad.js
index 31d226de..a22c7e6e 100644
--- a/javascript/imageviewerGamepad.js
+++ b/javascript/imageviewerGamepad.js
@@ -1,7 +1,9 @@
+let gamepads = [];
+
window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index;
let isWaiting = false;
- setInterval(async() => {
+ gamepads[index] = setInterval(async() => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0];
@@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
}, 10);
});
+window.addEventListener('gamepaddisconnected', (e) => {
+ clearInterval(gamepads[e.gamepad.index]);
+});
+
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
diff --git a/javascript/notification.js b/javascript/notification.js
index a68a76f2..76c5715d 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -4,7 +4,7 @@ let lastHeadImg = null;
let notificationButton = null;
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
if (notificationButton == null) {
notificationButton = gradioApp().getElementById('request_notifications');
diff --git a/javascript/ui.js b/javascript/ui.js
index 800a2ae6..d70a681b 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -249,7 +249,7 @@ function confirm_clear_prompt(prompt, negative_prompt) {
var opts = {};
-onUiUpdate(function() {
+onAfterUiUpdate(function() {
if (Object.keys(opts).length != 0) return;
var json_elem = gradioApp().getElementById('settings_json');
diff --git a/modules/api/api.py b/modules/api/api.py
index eee99bbb..fbd616a3 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -16,6 +16,7 @@ from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api import models
+from modules.errors import print_error
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
@@ -23,6 +24,7 @@ from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -108,7 +110,6 @@ def api_middleware(app: FastAPI):
from rich.console import Console
console = Console()
except Exception:
- import traceback
rich_available = False
@app.middleware("http")
@@ -139,11 +140,12 @@ def api_middleware(app: FastAPI):
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
- print(f"API error: {request.method}: {request.url} {err}")
+ message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
+ print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- traceback.print_exc()
+ print_error(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
@@ -189,6 +191,7 @@ class Api:
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
@@ -541,6 +544,9 @@ class Api:
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ def get_sd_vaes(self):
+ return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -700,4 +706,4 @@ class Api:
def launch(self, server_name, port):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
diff --git a/modules/api/models.py b/modules/api/models.py
index 1ff2fb33..47fdede2 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -249,6 +249,10 @@ class SDModelItem(BaseModel):
filename: str = Field(title="Filename")
config: Optional[str] = Field(title="Config file")
+class SDVaeItem(BaseModel):
+ model_name: str = Field(title="Model Name")
+ filename: str = Field(title="Filename")
+
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 447bb764..dba2a9b4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,10 +1,9 @@
import html
-import sys
import threading
-import traceback
import time
from modules import shared, progress
+from modules.errors import print_error
queue_lock = threading.Lock()
@@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
try:
res = list(func(*args, **kwargs))
except Exception as e:
- # When printing out our debug argument list, do not print out more than a MB of text
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {args} {kwargs}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
+ # When printing out our debug argument list,
+ # do not print out more than a 100 KB of text
+ max_debug_str_len = 131072
+ message = "Error completing request"
+ arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
+ if len(arg_str) > max_debug_str_len:
+ arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+ print_error(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
@@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res)
return f
-
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index 3eeb84d5..0974056d 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
-parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index ececdbae..76143e9f 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
import cv2
import torch
@@ -8,6 +6,7 @@ import torch
import modules.face_restoration
import modules.shared
from modules import shared, devices, modelloader
+from modules.errors import print_error
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -105,8 +104,8 @@ def setup_model(dirname):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ except Exception:
+ print_error('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -135,7 +134,6 @@ def setup_model(dirname):
shared.face_restorers.append(codeformer)
except Exception:
- print("Error setting up CodeFormer:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index db65bcdb..faeaf28b 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
"""
import os
-import sys
-import traceback
import json
import time
import tqdm
@@ -14,6 +12,7 @@ from collections import OrderedDict
import git
from modules import shared, extensions
+from modules.errors import print_error
from modules.paths_internal import script_path, config_states_dir
@@ -53,8 +52,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -134,8 +132,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -143,8 +140,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/errors.py b/modules/errors.py
index f6b80dbb..41d8dc93 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -1,7 +1,23 @@
import sys
+import textwrap
import traceback
+def print_error(
+ message: str,
+ *,
+ exc_info: bool = False,
+) -> None:
+ """
+ Print an error message to stderr, with optional traceback.
+ """
+ for line in message.splitlines():
+ print("***", line, file=sys.stderr)
+ if exc_info:
+ print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
+ print("---")
+
+
def print_error_explanation(message):
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
@@ -12,9 +28,13 @@ def print_error_explanation(message):
print('=' * max_len, file=sys.stderr)
-def display(e: Exception, task):
+def display(e: Exception, task, *, full_traceback=False):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ te = traceback.TracebackException.from_exception(e)
+ if full_traceback:
+ # include frames leading up to the try-catch block
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
+ print(*te.format(), sep="", file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
diff --git a/modules/extensions.py b/modules/extensions.py
index 624832a0..92f93ad9 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,11 +1,9 @@
import os
-import sys
import threading
-import traceback
-
-import git
from modules import shared
+from modules.errors import print_error
+from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
@@ -54,10 +52,9 @@ class Extension:
repo = None
try:
if os.path.exists(os.path.join(self.path, ".git")):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
except Exception:
- print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -72,8 +69,8 @@ class Extension:
self.commit_hash = commit.hexsha
self.version = self.commit_hash[:8]
- except Exception as ex:
- print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
+ except Exception:
+ print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
self.have_info_from_repo = True
@@ -94,7 +91,7 @@ class Extension:
return res
def check_updates(self):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
@@ -116,7 +113,7 @@ class Extension:
self.status = "latest"
def fetch_and_reset_hard(self, commit='origin'):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index d5f0a49b..071bd9ea 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -35,7 +35,7 @@ def reset():
def quote(text):
- if ',' not in str(text) and '\n' not in str(text):
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
return text
return json.dumps(text, ensure_ascii=False)
@@ -306,6 +306,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "RNG" not in res:
res["RNG"] = "GPU"
+ if "Schedule type" not in res:
+ res["Schedule type"] = "Automatic"
+
+ if "Schedule max sigma" not in res:
+ res["Schedule max sigma"] = 0
+
+ if "Schedule min sigma" not in res:
+ res["Schedule min sigma"] = 0
+
+ if "Schedule rho" not in res:
+ res["Schedule rho"] = 0
+
return res
@@ -318,6 +330,10 @@ infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
+ ('Schedule type', 'k_sched_type'),
+ ('Schedule max sigma', 'sigma_max'),
+ ('Schedule min sigma', 'sigma_min'),
+ ('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index 0131dea4..d2f647fe 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,12 +1,11 @@
import os
-import sys
-import traceback
import facexlib
import gfpgan
import modules.face_restoration
from modules import paths, shared, devices, modelloader
+from modules.errors import print_error
model_dir = "GFPGAN"
user_path = None
@@ -112,5 +111,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print("Error setting up GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py
new file mode 100644
index 00000000..e537c1df
--- /dev/null
+++ b/modules/gitpython_hack.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import io
+import subprocess
+
+import git
+
+
+class Git(git.Git):
+ """
+ Git subclassed to never use persistent processes.
+ """
+
+ def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
+ raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
+
+ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
+ ret = subprocess.check_output(
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
+ input=self._prepare_ref(ref),
+ cwd=self._working_dir,
+ timeout=2,
+ )
+ return self._parse_object_header(ret)
+
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+ # Not really streaming, per se; this buffers the entire object in memory.
+ # Shouldn't be a problem for our use case, since we're only using this for
+ # object headers (commit objects).
+ ret = subprocess.check_output(
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
+ input=self._prepare_ref(ref),
+ cwd=self._working_dir,
+ timeout=30,
+ )
+ bio = io.BytesIO(ret)
+ hexsha, typename, size = self._parse_object_header(bio.readline())
+ return (hexsha, typename, size, self.CatFileContentStream(size, bio))
+
+
+class Repo(git.Repo):
+ GitCommandWrapperType = Git
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 570b5603..fcc1ef20 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -2,8 +2,6 @@ import datetime
import glob
import html
import os
-import sys
-import traceback
import inspect
import modules.textual_inversion.dataset
@@ -12,6 +10,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules.errors import print_error
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -325,17 +324,14 @@ def load_hypernetwork(name):
if path is None:
return None
- hypernetwork = Hypernetwork()
-
try:
+ hypernetwork = Hypernetwork()
hypernetwork.load(path)
+ return hypernetwork
except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading hypernetwork {path}", exc_info=True)
return None
- return hypernetwork
-
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
@@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/images.py b/modules/images.py
index 4e8cd993..09f728df 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,6 +1,4 @@
import datetime
-import sys
-import traceback
import pytz
import io
@@ -18,9 +16,12 @@ import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
+from modules.errors import print_error
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
+import modules.sd_vae as sd_vae
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -336,8 +337,20 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
+ def get_vae_filename(self): #get the name of the VAE file.
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
+ else:
+ return split_file_name[0]
+
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.image.width,
@@ -354,19 +367,23 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
- 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+ 'batch_size': lambda self: self.p.batch_size,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'vae_filename': lambda self: self.get_vae_filename(),
+
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt, image):
+ def __init__(self, p, seed, prompt, image, zip=False):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
+ self.zip = zip
def hasprompt(self, *args):
lower = self.prompt.lower()
@@ -446,8 +463,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -493,9 +509,12 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
existing_pnginfo['parameters'] = geninfo
if extension.lower() == '.png':
- pnginfo_data = PngImagePlugin.PngInfo()
- for k, v in (existing_pnginfo or {}).items():
- pnginfo_data.add_text(k, str(v))
+ if opts.enable_pnginfo:
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in (existing_pnginfo or {}).items():
+ pnginfo_data.add_text(k, str(v))
+ else:
+ pnginfo_data = None
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
@@ -665,9 +684,10 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity']:
+ items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
@@ -678,8 +698,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/img2img.py b/modules/img2img.py
index d704bf90..4c12c2c5 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -92,7 +92,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 111b1322..d36e1a5a 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,6 +1,5 @@
import os
import sys
-import traceback
from collections import namedtuple
from pathlib import Path
import re
@@ -12,6 +11,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules.errors import print_error
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -216,8 +216,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print("Error interrogating", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error interrogating", exc_info=True)
res += "<error>"
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 35a52310..0bf4cb7e 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -8,6 +8,7 @@ import json
from functools import lru_cache
from modules import cmd_args
+from modules.errors import print_error
from modules.paths_internal import script_path, extensions_dir
args, _ = cmd_args.parser.parse_known_args()
@@ -188,7 +189,7 @@ def run_extension_installer(extension_dir):
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
- print(e, file=sys.stderr)
+ print_error(str(e))
def list_extensions(settings_file):
@@ -198,8 +199,8 @@ def list_extensions(settings_file):
if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
- except Exception as e:
- print(e, file=sys.stderr)
+ except Exception:
+ print_error("Could not load settings", exc_info=True)
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
@@ -223,19 +224,17 @@ def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
- taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
- taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -286,7 +285,6 @@ def prepare_environment():
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
- git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
diff --git a/modules/localization.py b/modules/localization.py
index ee9c65e7..9a1df343 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,8 +1,7 @@
import json
import os
-import sys
-import traceback
+from modules.errors import print_error
localizations = {}
@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/paths.py b/modules/paths.py
index 5f6474c0..5171df4f 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
diff --git a/modules/processing.py b/modules/processing.py
index 29a3743f..f628d88b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,4 +1,5 @@
import json
+import logging
import math
import os
import sys
@@ -13,7 +14,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -23,7 +24,6 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -321,14 +321,13 @@ class StableDiffusionProcessing:
have been used before. The second element is where the previously
computed result is stored.
"""
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
return cache[1]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps)
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
return cache[1]
def setup_conds(self):
@@ -674,6 +673,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 99983678..c8d0c64f 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,12 +1,11 @@
import os
-import sys
-import traceback
import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
+from modules.errors import print_error
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader
@@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
return info
- except Exception as e:
- print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ except Exception:
+ print_error("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _):
@@ -135,5 +132,4 @@ def get_realesrgan_models(scaler):
]
return models
except Exception:
- print("Error making Real-ESRGAN models list:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/safe.py b/modules/safe.py
index e8f50774..b596f565 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -2,8 +2,6 @@
import pickle
import collections
-import sys
-import traceback
import torch
import numpy
@@ -11,6 +9,8 @@ import _codecs
import zipfile
import re
+from modules.errors import print_error
+
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
@@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ print_error(
+ f"Error verifying pickled file from {filename}\n"
+ "-----> !!!! The file is most likely corrupted !!!! <-----\n"
+ "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
+ exc_info=True,
+ )
return None
-
except Exception:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ print_error(
+ f"Error verifying pickled file from {filename}\n"
+ f"The file may be malicious, so the program is not going to read it.\n"
+ f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
+ exc_info=True,
+ )
return None
return unsafe_torch_load(filename, *args, **kwargs)
@@ -190,4 +193,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
-
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 40f388a5..6aa9c3b6 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,16 +1,15 @@
-import sys
-import traceback
-from collections import namedtuple
import inspect
+from collections import namedtuple
from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
+from modules.errors import print_error
+
def report_exception(c, job):
- print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
@@ -111,6 +110,7 @@ callback_map = dict(
callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
+ callbacks_list_unets=[],
)
@@ -271,6 +271,18 @@ def list_optimizers_callback():
return res
+def list_unets_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_unets']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_unets')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -430,3 +442,10 @@ def on_list_optimizers(callback):
to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)
+
+
+def on_list_unets(callback):
+ """register a function to be called when UI is making a list of alternative options for unet.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
+
+ add_callback(callback_map['callbacks_list_unets'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 57b15862..26efffcb 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,8 +1,8 @@
import os
-import sys
-import traceback
import importlib.util
+from modules.errors import print_error
+
def load_module(path):
module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
@@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print(f"Error running preload() for {preload_script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index c902804b..a7168fd1 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,12 +1,12 @@
import os
import re
import sys
-import traceback
from collections import namedtuple
import gradio as gr
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
+from modules.errors import print_error
AlwaysVisible = object()
@@ -264,8 +264,7 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
@@ -280,11 +279,9 @@ def load_scripts():
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
- res = func(*args, **kwargs)
- return res
+ return func(*args, **kwargs)
except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -450,8 +447,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -459,8 +455,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running before_process_batch: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -468,8 +463,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -477,8 +471,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -486,8 +479,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -495,24 +487,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 08d31080..487dfd60 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,11 +43,16 @@ def list_optimizers():
optimizers.extend(new_optimizers)
-def apply_optimizations():
+def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
+ if len(optimizers) == 0:
+ # a script can access the model very early, and optimizations would not be filled by then
+ current_optimizer = None
+ return ''
+
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
@@ -55,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo()
current_optimizer = None
- selection = shared.opts.cross_attention_optimization
+ selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
@@ -67,11 +72,13 @@ def apply_optimizations():
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
- print(f"Applying optimization: {matching_optimizer.name}")
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
+ print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
+ print("Disabling attention optimization")
return ''
@@ -149,6 +156,13 @@ class StableDiffusionModelHijack:
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def apply_optimizations(self, option=None):
+ try:
+ self.optimization_method = apply_optimizations(option)
+ except Exception as e:
+ errors.display(e, "applying cross attention optimization")
+ undo_optimizations()
+
def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
@@ -168,11 +182,7 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
- try:
- self.optimization_method = apply_optimizations()
- except Exception as e:
- errors.display(e, "applying cross attention optimization")
- undo_optimizations()
+ self.apply_optimizations()
self.clip = m.cond_stage_model
@@ -185,6 +195,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
+
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -206,6 +221,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 2ec0b049..fd186fa2 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,5 @@
from __future__ import annotations
import math
-import sys
-import traceback
import psutil
import torch
@@ -11,6 +9,7 @@ from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices, sub_quadratic_attention
+from modules.errors import print_error
from modules.hypernetworks import hypernetwork
import ldm.modules.attention
@@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Cannot import xformers", exc_info=True)
def get_available_vram():
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 91b3eb11..232eb9c4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -164,6 +164,7 @@ def model_hash(filename):
def select_checkpoint():
+ """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
@@ -171,14 +172,14 @@ def select_checkpoint():
return checkpoint_info
if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
+ error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
+ raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
@@ -423,7 +424,7 @@ class SdModelData:
try:
load_model()
except Exception as e:
- errors.display(e, "loading stable diffusion model")
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
@@ -532,6 +533,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index dcec9e0e..f8a0c7ba 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -44,6 +44,14 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+ 'Automatic': None,
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -265,6 +273,13 @@ class KDiffusionSampler:
try:
return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
@@ -304,6 +319,31 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif opts.k_sched_type != "Automatic":
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+ sigmas_kwargs = {
+ 'sigma_min': sigma_min,
+ 'sigma_max': sigma_max,
+ }
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+ sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["Schedule rho"] = opts.rho
+
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
new file mode 100644
index 00000000..6d708ad2
--- /dev/null
+++ b/modules/sd_unet.py
@@ -0,0 +1,92 @@
+import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+ new_unets = script_callbacks.list_unets_callback()
+
+ unet_options.clear()
+ unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+ option = option or shared.opts.sd_unet
+
+ if option == "None":
+ return None
+
+ if option == "Automatic":
+ name = shared.sd_model.sd_checkpoint_info.model_name
+
+ options = [x for x in unet_options if x.model_name == name]
+
+ option = options[0].label if options else "None"
+
+ return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+ global current_unet_option
+ global current_unet
+
+ new_option = get_unet_option(option)
+ if new_option == current_unet_option:
+ return
+
+ if current_unet is not None:
+ print(f"Dectivating unet: {current_unet.option.label}")
+ current_unet.deactivate()
+
+ current_unet_option = new_option
+ if current_unet_option is None:
+ current_unet = None
+
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ shared.sd_model.model.diffusion_model.to(devices.device)
+
+ return
+
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
+ devices.torch_gc()
+
+ current_unet = current_unet_option.create_unet()
+ current_unet.option = current_unet_option
+ print(f"Activating unet: {current_unet.option.label}")
+ current_unet.activate()
+
+
+class SdUnetOption:
+ model_name = None
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+ label = None
+ """name of the unet in UI"""
+
+ def create_unet(self):
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+ raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+ def forward(self, x, timesteps, context, *args, **kwargs):
+ raise NotImplementedError()
+
+ def activate(self):
+ pass
+
+ def deactivate(self):
+ pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+
diff --git a/modules/shared.py b/modules/shared.py
index 0897f937..acec7f18 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -314,6 +314,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -403,6 +404,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -414,12 +416,12 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
- "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
}))
options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
- "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
@@ -517,6 +519,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 2a8713c8..7f306a06 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -29,3 +29,14 @@ def cross_attention_optimizations():
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+def sd_unet_items():
+ import modules.sd_unet
+
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
+
+
+def refresh_unet_list():
+ import modules.sd_unet
+
+ modules.sd_unet.list_unets()
+
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d489ed1e..b3dcb140 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
from collections import namedtuple
import torch
@@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
+from modules.errors import print_error
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
@@ -120,16 +119,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ return self.register_embedding_by_name(embedding, model, embedding.name)
+ def register_embedding_by_name(self, embedding, model, name):
+ ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+ if name in self.word_embeddings:
+ # remove old one from the lookup list
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+ else:
+ lookup = self.ids_lookup[first_id]
+ if embedding is not None:
+ lookup += [(ids, embedding)]
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+ if embedding is None:
+ # unregister embedding with specified name
+ if name in self.word_embeddings:
+ del self.word_embeddings[name]
+ if len(self.ids_lookup[first_id])==0:
+ del self.ids_lookup[first_id]
+ return None
+ self.word_embeddings[name] = embedding
return embedding
def get_expected_shape(self):
@@ -207,8 +219,7 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
@@ -632,8 +643,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
+ print_error("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/ui.py b/modules/ui.py
index 001b9792..fb6b2498 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -2,7 +2,6 @@ import json
import mimetypes
import os
import sys
-import traceback
from functools import reduce
import warnings
@@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules.errors import print_error
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
@@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
+ if gen_info_string:
+ print_error(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -505,10 +504,10 @@ def create_ui():
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
- hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
- hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -1753,8 +1752,7 @@ def create_ui():
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 27ab3ebb..5a9204a4 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -50,9 +50,10 @@ def save_files(js_data, images, do_make_zip, index):
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
+ only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
+ only_one = True
images = [images[index]]
start_index = index
@@ -70,6 +71,7 @@ def save_files(js_data, images, do_make_zip, index):
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
+ p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
@@ -83,7 +85,10 @@ def save_files(js_data, images, do_make_zip, index):
# Make Zip
if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
+ zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
+ namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
+ zip_filepath = os.path.join(path, f"{zip_filename}.zip")
from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 515ec262..e2ee9d72 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,10 +1,8 @@
import json
import os.path
-import sys
import threading
import time
from datetime import datetime
-import traceback
import git
@@ -14,6 +12,7 @@ import shutil
import errno
from modules import extensions, shared, paths, config_states
+from modules.errors import print_error
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print(f"Error getting updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
@@ -113,8 +111,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print(f"Error checking updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
@@ -490,8 +487,14 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
def preload_extensions_git_metadata():
+ t0 = time.time()
for extension in extensions.extensions:
extension.read_info_from_repo()
+ print(
+ f"preload_extensions_git_metadata for "
+ f"{len(extensions.extensions)} extensions took "
+ f"{time.time() - t0:.2f}s"
+ )
def create_ui():
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index f05049e1..9fc7d764 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -3,7 +3,7 @@ import tempfile
from collections import namedtuple
from pathlib import Path
-import gradio as gr
+import gradio.components
from PIL import PngImagePlugin
@@ -31,13 +31,16 @@ def check_tmp_file(gradio, filename):
return False
-def save_pil_to_file(pil_image, dir=None):
+def save_pil_to_file(self, pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
+ filename = already_saved_as
- file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
- return file_obj
+ if not shared.opts.save_images_add_number:
+ filename += f'?{os.path.getmtime(already_saved_as)}'
+
+ return filename
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
@@ -51,11 +54,11 @@ def save_pil_to_file(pil_image, dir=None):
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
+ return file_obj.name
# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
+gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed():
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 7b1046d6..3c82861d 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,8 +53,8 @@ class Upscaler:
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = int(img.width * scale)
- dest_h = int(img.height * scale)
+ dest_w = round((img.width * scale - 4) / 8) * 8
+ dest_h = round((img.height * scale - 4) / 8) * 8
for _ in range(3):
shape = (img.width, img.height)
diff --git a/requirements.txt b/requirements.txt
index 34e4520d..3142085e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,32 +1,32 @@
-astunparse
-blendmodes
+GitPython
+Pillow
accelerate
+
basicsr
+blendmodes
+clean-fid
+einops
gfpgan
-gradio==3.31.0
+gradio==3.32.0
+inflection
+jsonmerge
+kornia
+lark
numpy
omegaconf
-opencv-contrib-python
-requests
+
piexif
-Pillow
-pytorch_lightning==1.7.7
+psutil
+pytorch_lightning
realesrgan
+requests
+resize-right
+
+safetensors
scikit-image>=0.19
-timm==0.4.12
-transformers==4.25.1
+timm
+tomesd
torch
-einops
-jsonmerge
-clean-fid
-resize-right
torchdiffeq
-kornia
-lark
-inflection
-GitPython
torchsde
-safetensors
-psutil
-rich
-tomesd
+transformers==4.25.1
diff --git a/requirements_versions.txt b/requirements_versions.txt
index de501fda..f71b9d6c 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,29 +1,30 @@
-blendmodes==2022
-transformers==4.25.1
+GitPython==3.1.30
+Pillow==9.5.0
accelerate==0.18.0
basicsr==1.4.2
+blendmodes==2022
+clean-fid==0.1.35
+einops==0.4.1
+fastapi==0.94.0
gfpgan==1.3.8
-gradio==3.31.0
+gradio==3.32.0
+httpcore<=0.15
+inflection==0.5.1
+jsonmerge==1.8.0
+kornia==0.6.7
+lark==1.1.2
numpy==1.23.5
-Pillow==9.5.0
-realesrgan==0.3.0
-torch
omegaconf==2.2.3
+piexif==1.1.3
+psutil~=5.9.5
pytorch_lightning==1.9.4
+realesrgan==0.3.0
+resize-right==0.0.2
+safetensors==0.3.1
scikit-image==0.20.0
timm==0.6.7
-piexif==1.1.3
-einops==0.4.1
-jsonmerge==1.8.0
-clean-fid==0.1.35
-resize-right==0.0.2
+tomesd==0.1.2
+torch
torchdiffeq==0.2.3
-kornia==0.6.7
-lark==1.1.2
-inflection==0.5.1
-GitPython==3.1.30
torchsde==0.2.5
-safetensors==0.3.1
-httpcore<=0.15
-fastapi==0.94.0
-tomesd==0.1.2
+transformers==4.25.1
diff --git a/script.js b/script.js
index f7612779..de9d7e22 100644
--- a/script.js
+++ b/script.js
@@ -19,35 +19,79 @@ function get_uiCurrentTabContent() {
}
var uiUpdateCallbacks = [];
+var uiAfterUpdateCallbacks = [];
var uiLoadedCallbacks = [];
var uiTabChangeCallbacks = [];
var optionsChangedCallbacks = [];
+var uiAfterUpdateTimeout = null;
var uiCurrentTab = null;
+/**
+ * Register callback to be called at each UI update.
+ * The callback receives an array of MutationRecords as an argument.
+ */
function onUiUpdate(callback) {
uiUpdateCallbacks.push(callback);
}
+
+/**
+ * Register callback to be called soon after UI updates.
+ * The callback receives no arguments.
+ *
+ * This is preferred over `onUiUpdate` if you don't need
+ * access to the MutationRecords, as your function will
+ * not be called quite as often.
+ */
+function onAfterUiUpdate(callback) {
+ uiAfterUpdateCallbacks.push(callback);
+}
+
+/**
+ * Register callback to be called when the UI is loaded.
+ * The callback receives no arguments.
+ */
function onUiLoaded(callback) {
uiLoadedCallbacks.push(callback);
}
+
+/**
+ * Register callback to be called when the UI tab is changed.
+ * The callback receives no arguments.
+ */
function onUiTabChange(callback) {
uiTabChangeCallbacks.push(callback);
}
+
+/**
+ * Register callback to be called when the options are changed.
+ * The callback receives no arguments.
+ * @param callback
+ */
function onOptionsChanged(callback) {
optionsChangedCallbacks.push(callback);
}
-function runCallback(x, m) {
- try {
- x(m);
- } catch (e) {
- (console.error || console.log).call(console, e.message, e);
+function executeCallbacks(queue, arg) {
+ for (const callback of queue) {
+ try {
+ callback(arg);
+ } catch (e) {
+ console.error("error running callback", callback, ":", e);
+ }
}
}
-function executeCallbacks(queue, m) {
- queue.forEach(function(x) {
- runCallback(x, m);
- });
+
+/**
+ * Schedule the execution of the callbacks registered with onAfterUiUpdate.
+ * The callbacks are executed after a short while, unless another call to this function
+ * is made before that time. IOW, the callbacks are executed only once, even
+ * when there are multiple mutations observed.
+ */
+function scheduleAfterUiUpdateCallbacks() {
+ clearTimeout(uiAfterUpdateTimeout);
+ uiAfterUpdateTimeout = setTimeout(function() {
+ executeCallbacks(uiAfterUpdateCallbacks);
+ }, 200);
}
var executedOnLoaded = false;
@@ -60,6 +104,7 @@ document.addEventListener("DOMContentLoaded", function() {
}
executeCallbacks(uiUpdateCallbacks, m);
+ scheduleAfterUiUpdateCallbacks();
const newTab = get_uiCurrentTab();
if (newTab && (newTab !== uiCurrentTab)) {
uiCurrentTab = newTab;
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index b918a764..4dc24615 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,13 +1,12 @@
import copy
import random
-import sys
-import traceback
import shlex
import modules.scripts as scripts
import gradio as gr
from modules import sd_samplers
+from modules.errors import print_error
from modules.processing import Processed, process_images
from modules.shared import state
@@ -136,8 +135,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
- print(f"Error parsing line {line} as commandline:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ print_error(f"Error parsing line {line} as commandline", exc_info=True)
args = {"prompt": line}
else:
args = {"prompt": line}
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index da820b39..7821cc65 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
@@ -220,6 +220,10 @@ axis_options = [
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
+ AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
+ AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
+ AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
diff --git a/style.css b/style.css
index ba12723a..571f4cf4 100644
--- a/style.css
+++ b/style.css
@@ -756,13 +756,22 @@ footer {
.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
display: none;
position: absolute;
- right: 0;
color: white;
+ right: 0;
+}
+.extra-network-cards .card .metadata-button {
text-shadow: 2px 2px 3px black;
padding: 0.25em;
font-size: 22pt;
width: 1.5em;
}
+.extra-network-thumbs .card .metadata-button {
+ text-shadow: 1px 1px 2px black;
+ padding: 0;
+ font-size: 16pt;
+ width: 1em;
+ top: -0.25em;
+}
.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
display: inline-block;
}
@@ -787,6 +796,13 @@ footer {
position: relative;
}
+.extra-network-thumbs .card .preview{
+ position: absolute;
+ object-fit: cover;
+ width: 100%;
+ height:100%;
+}
+
.extra-network-thumbs .card:hover .additional a {
display: inline-block;
}
diff --git a/webui-user.sh b/webui-user.sh
index 49a426ff..70306c60 100644
--- a/webui-user.sh
+++ b/webui-user.sh
@@ -36,7 +36,6 @@
# Fixed git commits
#export STABLE_DIFFUSION_COMMIT_HASH=""
-#export TAMING_TRANSFORMERS_COMMIT_HASH=""
#export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH=""
diff --git a/webui.py b/webui.py
index 07c70c46..3df2cd1a 100644
--- a/webui.py
+++ b/webui.py
@@ -58,6 +58,7 @@ import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
+import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
@@ -134,7 +135,7 @@ there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
- expected_xformers_version = "0.0.17"
+ expected_xformers_version = "0.0.20"
if shared.xformers_available:
import xformers
@@ -291,9 +292,23 @@ def initialize_rest(*, reload_script_modules=False):
modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
- # load model in parallel to other startup stuff
- # (when reloading, this does nothing)
- Thread(target=lambda: shared.sd_model).start()
+ modules.sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
+
+ def load_model():
+ """
+ Accesses shared.sd_model property to load model.
+ After it's available, if it has been loaded before this access by some extension,
+ its optimization may be None because the list of optimizaers has neet been filled
+ by that time, so we apply optimization again.
+ """
+
+ shared.sd_model # noqa: B018
+
+ if modules.sd_hijack.current_optimizer is None:
+ modules.sd_hijack.apply_optimizations()
+
+ Thread(target=load_model).start()
Thread(target=devices.first_time_calculation).start()
@@ -372,17 +387,6 @@ def webui():
gradio_auth_creds = list(get_gradio_auth_creds()) or None
- # this restores the missing /docs endpoint
- if launch_api and not hasattr(FastAPI, 'original_setup'):
- # TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
- def fastapi_setup(self):
- self.docs_url = "/docs"
- self.redoc_url = "/redoc"
- self.original_setup()
-
- FastAPI.original_setup = FastAPI.setup
- FastAPI.setup = fastapi_setup
-
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
@@ -395,6 +399,10 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
+ app_kwargs={
+ "docs_url": "/docs",
+ "redoc_url": "/redoc",
+ },
)
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
diff --git a/webui.sh b/webui.sh
index ab52ac3b..607557b1 100755
--- a/webui.sh
+++ b/webui.sh
@@ -124,9 +124,12 @@ case "$gpu_info" in
*)
;;
esac
-if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+if ! echo "$gpu_info" | grep -q "NVIDIA";
then
- export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+ then
+ export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ fi
fi
for preq in "${GIT}" "${python_cmd}"
@@ -190,7 +193,7 @@ fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
- TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
+ TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"