diff options
43 files changed, 1165 insertions, 210 deletions
diff --git a/.eslintrc.js b/.eslintrc.js index 944cc869..f33aca09 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -50,13 +50,14 @@ module.exports = { globals: { //script.js gradioApp: "readonly", + executeCallbacks: "readonly", + onAfterUiUpdate: "readonly", + onOptionsChanged: "readonly", onUiLoaded: "readonly", onUiUpdate: "readonly", - onOptionsChanged: "readonly", uiCurrentTab: "writable", - uiElementIsVisible: "readonly", uiElementInSight: "readonly", - executeCallbacks: "readonly", + uiElementIsVisible: "readonly", //ui.js opts: "writable", all_gallery_buttons: "readonly", @@ -84,5 +85,7 @@ module.exports = { // imageviewer.js modalPrevImage: "readonly", modalNextImage: "readonly", + // token-counters.js + setupTokenCounters: "readonly", } }; diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 81c5101b..27a86e13 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -10,7 +10,7 @@ from contextlib import contextmanager from torch.optim.lr_scheduler import LambdaLR from ldm.modules.ema import LitEma -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from vqvae_quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config diff --git a/extensions-builtin/LDSR/vqvae_quantize.py b/extensions-builtin/LDSR/vqvae_quantize.py new file mode 100644 index 00000000..dd14b8fd --- /dev/null +++ b/extensions-builtin/LDSR/vqvae_quantize.py @@ -0,0 +1,147 @@ +# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py, +# where the license is as follows: +# +# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE./ + +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js new file mode 100644 index 00000000..f555960d --- /dev/null +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -0,0 +1,431 @@ +// Main + +// Helper functions +// Get active tab +function getActiveTab(elements, all = false) { + const tabs = elements.img2imgTabs.querySelectorAll("button"); + + if (all) return tabs; + + for (let tab of tabs) { + if (tab.classList.contains("selected")) { + return tab; + } + } +} + +onUiLoaded(async() => { + const hotkeysConfig = { + resetZoom: "KeyR", + fitToScreen: "KeyS", + moveKey: "KeyF", + overlap: "KeyO" + }; + + let isMoving = false; + let mouseX, mouseY; + + const elementIDs = { + sketch: "#img2img_sketch", + inpaint: "#img2maskimg", + inpaintSketch: "#inpaint_sketch", + img2imgTabs: "#mode_img2img .tab-nav" + }; + + async function getElements() { + const elements = await Promise.all( + Object.values(elementIDs).map(id => document.querySelector(id)) + ); + return Object.fromEntries( + Object.keys(elementIDs).map((key, index) => [key, elements[index]]) + ); + } + + const elements = await getElements(); + + function applyZoomAndPan(targetElement, elemId) { + targetElement.style.transformOrigin = "0 0"; + let [zoomLevel, panX, panY] = [1, 0, 0]; + let fullScreenMode = false; + + // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui. + function fixCanvas() { + const activeTab = getActiveTab(elements).textContent.trim(); + + if (activeTab !== "img2img") { + const img = targetElement.querySelector(`${elemId} img`); + + if (img && img.style.display !== "none") { + img.style.display = "none"; + img.style.visibility = "hidden"; + } + } + } + + // Reset the zoom level and pan position of the target element to their initial values + function resetZoom() { + zoomLevel = 1; + panX = 0; + panY = 0; + + fixCanvas(); + targetElement.style.transform = `scale(${zoomLevel}) translate(${panX}px, ${panY}px)`; + + const canvas = gradioApp().querySelector( + `${elemId} canvas[key="interface"]` + ); + + toggleOverlap("off"); + fullScreenMode = false; + + if ( + canvas && + parseFloat(canvas.style.width) > 865 && + parseFloat(targetElement.style.width) > 865 + ) { + fitToElement(); + return; + } + + targetElement.style.width = ""; + if (canvas) { + targetElement.style.height = canvas.style.height; + } + } + + // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements + function toggleOverlap(forced = "") { + const zIndex1 = "0"; + const zIndex2 = "998"; + + targetElement.style.zIndex = + targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1; + + if (forced === "off") { + targetElement.style.zIndex = zIndex1; + } else if (forced === "on") { + targetElement.style.zIndex = zIndex2; + } + } + + // Adjust the brush size based on the deltaY value from a mouse wheel event + function adjustBrushSize( + elemId, + deltaY, + withoutValue = false, + percentage = 5 + ) { + const input = + gradioApp().querySelector( + `${elemId} input[aria-label='Brush radius']` + ) || + gradioApp().querySelector( + `${elemId} button[aria-label="Use brush"]` + ); + + if (input) { + input.click(); + if (!withoutValue) { + const maxValue = + parseFloat(input.getAttribute("max")) || 100; + const changeAmount = maxValue * (percentage / 100); + const newValue = + parseFloat(input.value) + + (deltaY > 0 ? -changeAmount : changeAmount); + input.value = Math.min(Math.max(newValue, 0), maxValue); + input.dispatchEvent(new Event("change")); + } + } + } + + // Reset zoom when uploading a new image + const fileInput = gradioApp().querySelector( + `${elemId} input[type="file"][accept="image/*"].svelte-116rqfv` + ); + fileInput.addEventListener("click", resetZoom); + + // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables + function updateZoom(newZoomLevel, mouseX, mouseY) { + newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15)); + panX += mouseX - (mouseX * newZoomLevel) / zoomLevel; + panY += mouseY - (mouseY * newZoomLevel) / zoomLevel; + + targetElement.style.transformOrigin = "0 0"; + targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${newZoomLevel})`; + + toggleOverlap("on"); + return newZoomLevel; + } + + // Change the zoom level based on user interaction + function changeZoomLevel(operation, e) { + if (e.shiftKey) { + e.preventDefault(); + + let zoomPosX, zoomPosY; + let delta = 0.2; + if (zoomLevel > 7) { + delta = 0.9; + } else if (zoomLevel > 2) { + delta = 0.6; + } + + zoomPosX = e.clientX; + zoomPosY = e.clientY; + + fullScreenMode = false; + zoomLevel = updateZoom( + zoomLevel + (operation === "+" ? delta : -delta), + zoomPosX - targetElement.getBoundingClientRect().left, + zoomPosY - targetElement.getBoundingClientRect().top + ); + } + } + + /** + * This function fits the target element to the screen by calculating + * the required scale and offsets. It also updates the global variables + * zoomLevel, panX, and panY to reflect the new state. + */ + + function fitToElement() { + //Reset Zoom + targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + + // Get element and screen dimensions + const elementWidth = targetElement.offsetWidth; + const elementHeight = targetElement.offsetHeight; + const parentElement = targetElement.parentElement; + const screenWidth = parentElement.clientWidth; + const screenHeight = parentElement.clientHeight; + + // Get element's coordinates relative to the parent element + const elementRect = targetElement.getBoundingClientRect(); + const parentRect = parentElement.getBoundingClientRect(); + const elementX = elementRect.x - parentRect.x; + + // Calculate scale and offsets + const scaleX = screenWidth / elementWidth; + const scaleY = screenHeight / elementHeight; + const scale = Math.min(scaleX, scaleY); + + const transformOrigin = + window.getComputedStyle(targetElement).transformOrigin; + const [originX, originY] = transformOrigin.split(" "); + const originXValue = parseFloat(originX); + const originYValue = parseFloat(originY); + + const offsetX = + (screenWidth - elementWidth * scale) / 2 - + originXValue * (1 - scale); + const offsetY = + (screenHeight - elementHeight * scale) / 2.5 - + originYValue * (1 - scale); + + // Apply scale and offsets to the element + targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; + + // Update global variables + zoomLevel = scale; + panX = offsetX; + panY = offsetY; + + fullScreenMode = false; + toggleOverlap("off"); + } + + /** + * This function fits the target element to the screen by calculating + * the required scale and offsets. It also updates the global variables + * zoomLevel, panX, and panY to reflect the new state. + */ + + // Fullscreen mode + function fitToScreen() { + const canvas = gradioApp().querySelector( + `${elemId} canvas[key="interface"]` + ); + + if (!canvas) return; + + if (canvas.offsetWidth > 862) { + targetElement.style.width = canvas.offsetWidth + "px"; + } + + if (fullScreenMode) { + resetZoom(); + fullScreenMode = false; + return; + } + + //Reset Zoom + targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + + // Get scrollbar width to right-align the image + const scrollbarWidth = window.innerWidth - document.documentElement.clientWidth; + + // Get element and screen dimensions + const elementWidth = targetElement.offsetWidth; + const elementHeight = targetElement.offsetHeight; + const screenWidth = window.innerWidth - scrollbarWidth; + const screenHeight = window.innerHeight; + + // Get element's coordinates relative to the page + const elementRect = targetElement.getBoundingClientRect(); + const elementY = elementRect.y; + const elementX = elementRect.x; + + // Calculate scale and offsets + const scaleX = screenWidth / elementWidth; + const scaleY = screenHeight / elementHeight; + const scale = Math.min(scaleX, scaleY); + + // Get the current transformOrigin + const computedStyle = window.getComputedStyle(targetElement); + const transformOrigin = computedStyle.transformOrigin; + const [originX, originY] = transformOrigin.split(" "); + const originXValue = parseFloat(originX); + const originYValue = parseFloat(originY); + + // Calculate offsets with respect to the transformOrigin + const offsetX = + (screenWidth - elementWidth * scale) / 2 - + elementX - + originXValue * (1 - scale); + const offsetY = + (screenHeight - elementHeight * scale) / 2 - + elementY - + originYValue * (1 - scale); + + // Apply scale and offsets to the element + targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; + + // Update global variables + zoomLevel = scale; + panX = offsetX; + panY = offsetY; + + fullScreenMode = true; + toggleOverlap("on"); + } + + // Handle keydown events + function handleKeyDown(event) { + const hotkeyActions = { + [hotkeysConfig.resetZoom]: resetZoom, + [hotkeysConfig.overlap]: toggleOverlap, + [hotkeysConfig.fitToScreen]: fitToScreen + // [hotkeysConfig.moveKey] : moveCanvas, + }; + + const action = hotkeyActions[event.code]; + if (action) { + event.preventDefault(); + action(event); + } + } + + // Get Mouse position + function getMousePosition(e) { + mouseX = e.offsetX; + mouseY = e.offsetY; + } + + targetElement.addEventListener("mousemove", getMousePosition); + + // Handle events only inside the targetElement + let isKeyDownHandlerAttached = false; + + function handleMouseMove() { + if (!isKeyDownHandlerAttached) { + document.addEventListener("keydown", handleKeyDown); + isKeyDownHandlerAttached = true; + } + } + + function handleMouseLeave() { + if (isKeyDownHandlerAttached) { + document.removeEventListener("keydown", handleKeyDown); + isKeyDownHandlerAttached = false; + } + } + + // Add mouse event handlers + targetElement.addEventListener("mousemove", handleMouseMove); + targetElement.addEventListener("mouseleave", handleMouseLeave); + + // Reset zoom when click on another tab + elements.img2imgTabs.addEventListener("click", resetZoom); + elements.img2imgTabs.addEventListener("click", () => { + // targetElement.style.width = ""; + if (parseInt(targetElement.style.width) > 865) { + setTimeout(fitToElement, 0); + } + }); + + targetElement.addEventListener("wheel", e => { + // change zoom level + const operation = e.deltaY > 0 ? "-" : "+"; + changeZoomLevel(operation, e); + + // Handle brush size adjustment with ctrl key pressed + if (e.ctrlKey || e.metaKey) { + e.preventDefault(); + + // Increase or decrease brush size based on scroll direction + adjustBrushSize(elemId, e.deltaY); + } + }); + + /** + * Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element. + * @param {MouseEvent} e - The mouse event. + */ + function handleMoveKeyDown(e) { + if (e.code === hotkeysConfig.moveKey) { + if (!e.ctrlKey && !e.metaKey) { + isMoving = true; + } + } + } + + function handleMoveKeyUp(e) { + if (e.code === hotkeysConfig.moveKey) { + isMoving = false; + } + } + + document.addEventListener("keydown", handleMoveKeyDown); + document.addEventListener("keyup", handleMoveKeyUp); + + // Detect zoom level and update the pan speed. + function updatePanPosition(movementX, movementY) { + let panSpeed = 1.5; + + if (zoomLevel > 8) { + panSpeed = 2.5; + } + + panX = panX + movementX * panSpeed; + panY = panY + movementY * panSpeed; + + targetElement.style.transform = `translate(${panX}px, ${panY}px) scale(${zoomLevel})`; + toggleOverlap("on"); + } + + function handleMoveByKey(e) { + if (isMoving) { + updatePanPosition(e.movementX, e.movementY); + targetElement.style.pointerEvents = "none"; + } else { + targetElement.style.pointerEvents = "auto"; + } + } + + gradioApp().addEventListener("mousemove", handleMoveByKey); + } + + applyZoomAndPan(elements.sketch, elementIDs.sketch); + applyZoomAndPan(elements.inpaint, elementIDs.inpaint); + applyZoomAndPan(elements.inpaintSketch, elementIDs.inpaintSketch); +}); diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 1c08a1a9..2cf2d571 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) { } -onUiUpdate(function() { +onAfterUiUpdate(function() { var arPreviewRect = gradioApp().querySelector('#imageARPreview'); if (arPreviewRect) { arPreviewRect.style.display = 'none'; diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index f14af1d4..d60a10c4 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -167,6 +167,4 @@ var addContextMenuEventListener = initResponse[2]; })(); //End example Context Menu Items -onUiUpdate(function() { - addContextMenuEventListener(); -}); +onAfterUiUpdate(addContextMenuEventListener); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 77a24a07..5803daea 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) { } } +function eventHasFiles(e) { + if (!e.dataTransfer || !e.dataTransfer.files) return false; + if (e.dataTransfer.files.length > 0) return true; + if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true; + + return false; +} + +function dragDropTargetIsPrompt(target) { + if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true; + if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true; + return false; +} + window.document.addEventListener('dragover', e => { const target = e.composedPath()[0]; - const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { - return; - } + if (!eventHasFiles(e)) return; + + var targetImage = target.closest('[data-testid="image"]'); + if (!dragDropTargetIsPrompt(target) && !targetImage) return; + e.stopPropagation(); e.preventDefault(); e.dataTransfer.dropEffect = 'copy'; @@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => { window.document.addEventListener('drop', e => { const target = e.composedPath()[0]; - if (target.placeholder.indexOf("Prompt") == -1) { - return; + if (!eventHasFiles(e)) return; + + if (dragDropTargetIsPrompt(target)) { + e.stopPropagation(); + e.preventDefault(); + + let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; + + const imgParent = gradioApp().getElementById(prompt_target); + const files = e.dataTransfer.files; + const fileInput = imgParent.querySelector('input[type="file"]'); + if (fileInput) { + fileInput.files = files; + fileInput.dispatchEvent(new Event('change')); + } } - const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap) { + + var targetImage = target.closest('[data-testid="image"]'); + if (targetImage) { + e.stopPropagation(); + e.preventDefault(); + const files = e.dataTransfer.files; + dropReplaceImage(targetImage, files); return; } - e.stopPropagation(); - e.preventDefault(); - const files = e.dataTransfer.files; - dropReplaceImage(imgWrap, files); }); window.addEventListener('paste', e => { diff --git a/javascript/generationParams.js b/javascript/generationParams.js index a877f8a5..7c0fd221 100644 --- a/javascript/generationParams.js +++ b/javascript/generationParams.js @@ -1,7 +1,7 @@ // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes let txt2img_gallery, img2img_gallery, modal = undefined; -onUiUpdate(function() { +onAfterUiUpdate(function() { if (!txt2img_gallery) { txt2img_gallery = attachGalleryListeners("txt2img"); } diff --git a/javascript/hints.js b/javascript/hints.js index 46f342cb..05ae5f22 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -116,17 +116,25 @@ var titles = { "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." }; -function updateTooltipForSpan(span) { - if (span.title) return; // already has a title +function updateTooltip(element) { + if (element.title) return; // already has a title - let tooltip = localization[titles[span.textContent]] || titles[span.textContent]; + let text = element.textContent; + let tooltip = localization[titles[text]] || titles[text]; if (!tooltip) { - tooltip = localization[titles[span.value]] || titles[span.value]; + let value = element.value; + if (value) tooltip = localization[titles[value]] || titles[value]; } if (!tooltip) { - for (const c of span.classList) { + // Gradio dropdown options have `data-value`. + let dataValue = element.dataset.value; + if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue]; + } + + if (!tooltip) { + for (const c of element.classList) { if (c in titles) { tooltip = localization[titles[c]] || titles[c]; break; @@ -135,34 +143,53 @@ function updateTooltipForSpan(span) { } if (tooltip) { - span.title = tooltip; + element.title = tooltip; } } -function updateTooltipForSelect(select) { - if (select.onchange != null) return; +// Nodes to check for adding tooltips. +const tooltipCheckNodes = new Set(); +// Timer for debouncing tooltip check. +let tooltipCheckTimer = null; - select.onchange = function() { - select.title = localization[titles[select.value]] || titles[select.value] || ""; - }; +function processTooltipCheckNodes() { + for (const node of tooltipCheckNodes) { + updateTooltip(node); + } + tooltipCheckNodes.clear(); } -var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1}; - -onUiUpdate(function(m) { - m.forEach(function(record) { - record.addedNodes.forEach(function(node) { - if (observedTooltipElements[node.tagName]) { - updateTooltipForSpan(node); - } - if (node.tagName == "SELECT") { - updateTooltipForSelect(node); +onUiUpdate(function(mutationRecords) { + for (const record of mutationRecords) { + if (record.type === "childList" && record.target.classList.contains("options")) { + // This smells like a Gradio dropdown menu having changed, + // so let's enqueue an update for the input element that shows the current value. + let wrap = record.target.parentNode; + let input = wrap?.querySelector("input"); + if (input) { + input.title = ""; // So we'll even have a chance to update it. + tooltipCheckNodes.add(input); } - - if (node.querySelectorAll) { - node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan); - node.querySelectorAll('select').forEach(updateTooltipForSelect); + } + for (const node of record.addedNodes) { + if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) { + if (!node.title) { + if ( + node.tagName === "SPAN" || + node.tagName === "BUTTON" || + node.tagName === "P" || + node.tagName === "INPUT" || + (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item + ) { + tooltipCheckNodes.add(node); + } + } + node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n)); } - }); - }); + } + } + if (tooltipCheckNodes.size) { + clearTimeout(tooltipCheckTimer); + tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000); + } }); diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js index 3c9b8a6f..900c56f3 100644 --- a/javascript/imageMaskFix.js +++ b/javascript/imageMaskFix.js @@ -39,5 +39,5 @@ function imageMaskResize() { }); } -onUiUpdate(imageMaskResize); +onAfterUiUpdate(imageMaskResize); window.addEventListener('resize', imageMaskResize); diff --git a/javascript/imageParams.js b/javascript/imageParams.js deleted file mode 100644 index 057e2d39..00000000 --- a/javascript/imageParams.js +++ /dev/null @@ -1,18 +0,0 @@ -window.onload = (function() { - window.addEventListener('drop', e => { - const target = e.composedPath()[0]; - if (target.placeholder.indexOf("Prompt") == -1) return; - - let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; - - e.stopPropagation(); - e.preventDefault(); - const imgParent = gradioApp().getElementById(prompt_target); - const files = e.dataTransfer.files; - const fileInput = imgParent.querySelector('input[type="file"]'); - if (fileInput) { - fileInput.files = files; - fileInput.dispatchEvent(new Event('change')); - } - }); -}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 78e24eb9..677e95c1 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -170,7 +170,7 @@ function modalTileImageToggle(event) { event.stopPropagation(); } -onUiUpdate(function() { +onAfterUiUpdate(function() { var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img'); if (fullImg_preview != null) { fullImg_preview.forEach(setupImageForLightbox); diff --git a/javascript/imageviewerGamepad.js b/javascript/imageviewerGamepad.js index 31d226de..a22c7e6e 100644 --- a/javascript/imageviewerGamepad.js +++ b/javascript/imageviewerGamepad.js @@ -1,7 +1,9 @@ +let gamepads = []; + window.addEventListener('gamepadconnected', (e) => { const index = e.gamepad.index; let isWaiting = false; - setInterval(async() => { + gamepads[index] = setInterval(async() => { if (!opts.js_modal_lightbox_gamepad || isWaiting) return; const gamepad = navigator.getGamepads()[index]; const xValue = gamepad.axes[0]; @@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => { }, 10); }); +window.addEventListener('gamepaddisconnected', (e) => { + clearInterval(gamepads[e.gamepad.index]); +}); + /* Primarily for vr controller type pointer devices. I use the wheel event because there's currently no way to do it properly with web xr. diff --git a/javascript/notification.js b/javascript/notification.js index a68a76f2..76c5715d 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -4,7 +4,7 @@ let lastHeadImg = null; let notificationButton = null; -onUiUpdate(function() { +onAfterUiUpdate(function() { if (notificationButton == null) { notificationButton = gradioApp().getElementById('request_notifications'); diff --git a/javascript/token-counters.js b/javascript/token-counters.js new file mode 100644 index 00000000..9d81a723 --- /dev/null +++ b/javascript/token-counters.js @@ -0,0 +1,83 @@ +let promptTokenCountDebounceTime = 800; +let promptTokenCountTimeouts = {}; +var promptTokenCountUpdateFunctions = {}; + +function update_txt2img_tokens(...args) { + // Called from Gradio + update_token_counter("txt2img_token_button"); + if (args.length == 2) { + return args[0]; + } + return args; +} + +function update_img2img_tokens(...args) { + // Called from Gradio + update_token_counter("img2img_token_button"); + if (args.length == 2) { + return args[0]; + } + return args; +} + +function update_token_counter(button_id) { + if (opts.disable_token_counters) { + return; + } + if (promptTokenCountTimeouts[button_id]) { + clearTimeout(promptTokenCountTimeouts[button_id]); + } + promptTokenCountTimeouts[button_id] = setTimeout( + () => gradioApp().getElementById(button_id)?.click(), + promptTokenCountDebounceTime, + ); +} + + +function recalculatePromptTokens(name) { + promptTokenCountUpdateFunctions[name]?.(); +} + +function recalculate_prompts_txt2img() { + // Called from Gradio + recalculatePromptTokens('txt2img_prompt'); + recalculatePromptTokens('txt2img_neg_prompt'); + return Array.from(arguments); +} + +function recalculate_prompts_img2img() { + // Called from Gradio + recalculatePromptTokens('img2img_prompt'); + recalculatePromptTokens('img2img_neg_prompt'); + return Array.from(arguments); +} + +function setupTokenCounting(id, id_counter, id_button) { + var prompt = gradioApp().getElementById(id); + var counter = gradioApp().getElementById(id_counter); + var textarea = gradioApp().querySelector(`#${id} > label > textarea`); + + if (opts.disable_token_counters) { + counter.style.display = "none"; + return; + } + + if (counter.parentElement == prompt.parentElement) { + return; + } + + prompt.parentElement.insertBefore(counter, prompt); + prompt.parentElement.style.position = "relative"; + + promptTokenCountUpdateFunctions[id] = function() { + update_token_counter(id_button); + }; + textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]); +} + +function setupTokenCounters() { + setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); + setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); + setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); + setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); +} diff --git a/javascript/ui.js b/javascript/ui.js index 648a5290..d70a681b 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) { } -var promptTokecountUpdateFuncs = {}; - -function recalculatePromptTokens(name) { - if (promptTokecountUpdateFuncs[name]) { - promptTokecountUpdateFuncs[name](); - } -} - -function recalculate_prompts_txt2img() { - recalculatePromptTokens('txt2img_prompt'); - recalculatePromptTokens('txt2img_neg_prompt'); - return Array.from(arguments); -} - -function recalculate_prompts_img2img() { - recalculatePromptTokens('img2img_prompt'); - recalculatePromptTokens('img2img_neg_prompt'); - return Array.from(arguments); -} - - var opts = {}; -onUiUpdate(function() { +onAfterUiUpdate(function() { if (Object.keys(opts).length != 0) return; var json_elem = gradioApp().getElementById('settings_json'); @@ -302,28 +281,7 @@ onUiUpdate(function() { json_elem.parentElement.style.display = "none"; - function registerTextarea(id, id_counter, id_button) { - var prompt = gradioApp().getElementById(id); - var counter = gradioApp().getElementById(id_counter); - var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - - if (counter.parentElement == prompt.parentElement) { - return; - } - - prompt.parentElement.insertBefore(counter, prompt); - prompt.parentElement.style.position = "relative"; - - promptTokecountUpdateFuncs[id] = function() { - update_token_counter(id_button); - }; - textarea.addEventListener("input", promptTokecountUpdateFuncs[id]); - } - - registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); - registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); - registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); - registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); + setupTokenCounters(); var show_all_pages = gradioApp().getElementById('settings_show_all_pages'); var settings_tabs = gradioApp().querySelector('#settings div'); @@ -354,33 +312,6 @@ onOptionsChanged(function() { }); let txt2img_textarea, img2img_textarea = undefined; -let wait_time = 800; -let token_timeouts = {}; - -function update_txt2img_tokens(...args) { - update_token_counter("txt2img_token_button"); - if (args.length == 2) { - return args[0]; - } - return args; -} - -function update_img2img_tokens(...args) { - update_token_counter( - "img2img_token_button" - ); - if (args.length == 2) { - return args[0]; - } - return args; -} - -function update_token_counter(button_id) { - if (token_timeouts[button_id]) { - clearTimeout(token_timeouts[button_id]); - } - token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); -} function restart_reload() { document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb..be28c59a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -189,6 +190,7 @@ class Api: self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) @@ -541,6 +543,9 @@ class Api: def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + def get_sd_vaes(self): + return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -700,4 +705,4 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) diff --git a/modules/api/models.py b/modules/api/models.py index 1ff2fb33..47fdede2 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -249,6 +249,10 @@ class SDModelItem(BaseModel): filename: str = Field(title="Filename") config: Optional[str] = Field(title="Config file") +class SDVaeItem(BaseModel): + model_name: str = Field(title="Model Name") + filename: str = Field(title="Filename") + class HypernetworkItem(BaseModel): name: str = Field(title="Name") path: Optional[str] = Field(title="Path") diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 3eeb84d5..0974056d 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
-parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
diff --git a/modules/devices.py b/modules/devices.py index d8a34a0f..1ed6ffdc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,5 +1,7 @@ import sys import contextlib +from functools import lru_cache + import torch from modules import errors @@ -154,3 +156,19 @@ def test_for_nans(x, where): message += " Use --disable-nan-check commandline argument to disable this check." raise NansException(message) + + +@lru_cache +def first_time_calculation(): + """ + just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and + spends about 2.7 seconds doing that, at least wih NVidia. + """ + + x = torch.zeros((1, 1)).to(device, dtype) + linear = torch.nn.Linear(1, 1).to(device, dtype) + linear(x) + + x = torch.zeros((1, 1, 3, 3)).to(device, dtype) + conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) + conv2d(x) diff --git a/modules/errors.py b/modules/errors.py index f6b80dbb..da4694f8 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -12,9 +12,13 @@ def print_error_explanation(message): print('=' * max_len, file=sys.stderr)
-def display(e: Exception, task):
+def display(e: Exception, task, *, full_traceback=False):
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ te = traceback.TracebackException.from_exception(e)
+ if full_traceback:
+ # include frames leading up to the try-catch block
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
+ print(*te.format(), sep="", file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d5f0a49b..071bd9ea 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -35,7 +35,7 @@ def reset(): def quote(text):
- if ',' not in str(text) and '\n' not in str(text):
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
return text
return json.dumps(text, ensure_ascii=False)
@@ -306,6 +306,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "RNG" not in res:
res["RNG"] = "GPU"
+ if "Schedule type" not in res:
+ res["Schedule type"] = "Automatic"
+
+ if "Schedule max sigma" not in res:
+ res["Schedule max sigma"] = 0
+
+ if "Schedule min sigma" not in res:
+ res["Schedule min sigma"] = 0
+
+ if "Schedule rho" not in res:
+ res["Schedule rho"] = 0
+
return res
@@ -318,6 +330,10 @@ infotext_to_setting_name_mapping = [ ('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
+ ('Schedule type', 'k_sched_type'),
+ ('Schedule max sigma', 'sigma_max'),
+ ('Schedule min sigma', 'sigma_min'),
+ ('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
diff --git a/modules/images.py b/modules/images.py index 4e8cd993..e21e554c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -21,6 +21,8 @@ from modules import sd_samplers, shared, script_callbacks, errors from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
+import modules.sd_vae as sd_vae
+
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -336,8 +338,20 @@ def sanitize_filename_part(text, replace_spaces=True): class FilenameGenerator:
+ def get_vae_filename(self): #get the name of the VAE file.
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
+ else:
+ return split_file_name[0]
+
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.image.width,
@@ -354,19 +368,23 @@ class FilenameGenerator: 'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
- 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+ 'batch_size': lambda self: self.p.batch_size,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'vae_filename': lambda self: self.get_vae_filename(),
+
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt, image):
+ def __init__(self, p, seed, prompt, image, zip=False):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
+ self.zip = zip
def hasprompt(self, *args):
lower = self.prompt.lower()
@@ -665,9 +683,10 @@ def read_info_from_image(image): items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
+ for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity']:
+ items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
diff --git a/modules/img2img.py b/modules/img2img.py index d704bf90..4c12c2c5 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -92,7 +92,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 35a52310..ca089674 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -229,13 +229,11 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
- taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
- taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -286,7 +284,6 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
- git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
diff --git a/modules/paths.py b/modules/paths.py index 5f6474c0..5171df4f 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
diff --git a/modules/processing.py b/modules/processing.py index 29a3743f..395c851f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -321,14 +321,13 @@ class StableDiffusionProcessing: have been used before. The second element is where the previously
computed result is stored.
"""
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]:
return cache[1]
with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps)
- cache[0] = (required_prompts, steps)
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info)
return cache[1]
def setup_conds(self):
@@ -674,6 +673,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 40f388a5..d2728e12 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -111,6 +111,7 @@ callback_map = dict( callbacks_before_ui=[],
callbacks_on_reload=[],
callbacks_list_optimizers=[],
+ callbacks_list_unets=[],
)
@@ -271,6 +272,18 @@ def list_optimizers_callback(): return res
+def list_unets_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_unets']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_unets')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -430,3 +443,10 @@ def on_list_optimizers(callback): to it."""
add_callback(callback_map['callbacks_list_optimizers'], callback)
+
+
+def on_list_unets(callback):
+ """register a function to be called when UI is making a list of alternative options for unet.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
+
+ add_callback(callback_map['callbacks_list_unets'], callback)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f93df0a6..487dfd60 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,7 +3,7 @@ from torch.nn.functional import silu from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,7 +43,7 @@ def list_optimizers(): optimizers.extend(new_optimizers)
-def apply_optimizations():
+def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
@@ -60,7 +60,7 @@ def apply_optimizations(): current_optimizer.undo()
current_optimizer = None
- selection = shared.opts.cross_attention_optimization
+ selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
@@ -72,12 +72,13 @@ def apply_optimizations(): matching_optimizer = optimizers[0]
if matching_optimizer is not None:
- print(f"Applying optimization: {matching_optimizer.name}... ", end='')
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
+ print("Disabling attention optimization")
return ''
@@ -155,9 +156,9 @@ class StableDiffusionModelHijack: def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
- def apply_optimizations(self):
+ def apply_optimizations(self, option=None):
try:
- self.optimization_method = apply_optimizations()
+ self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
@@ -194,6 +195,11 @@ class StableDiffusionModelHijack: self.layers = flatten(m)
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
+
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -215,6 +221,8 @@ class StableDiffusionModelHijack: self.layers = None
self.clip = None
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
diff --git a/modules/sd_models.py b/modules/sd_models.py index b1afbaa7..232eb9c4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -164,6 +164,7 @@ def model_hash(filename): def select_checkpoint():
+ """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
@@ -171,14 +172,14 @@ def select_checkpoint(): return checkpoint_info
if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
+ error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
+ raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
@@ -423,7 +424,7 @@ class SdModelData: try:
load_model()
except Exception as e:
- errors.display(e, "loading stable diffusion model")
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
@@ -508,6 +509,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks")
+ with devices.autocast(), torch.no_grad():
+ sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+
+ timer.record("calculate empty prompt")
+
print(f"Model loaded in {timer.summary()}.")
return sd_model
@@ -527,6 +533,8 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 59982fc9..e9ba2c61 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -44,6 +44,14 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+ 'Automatic': None,
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -125,6 +133,16 @@ class CFGDenoiser(torch.nn.Module): x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
+ # TODO add infotext entry
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ elif num_repeats > 0:
+ uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
@@ -255,6 +273,13 @@ class KDiffusionSampler: try:
return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
@@ -294,6 +319,31 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif opts.k_sched_type != "Automatic":
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+ sigmas_kwargs = {
+ 'sigma_min': sigma_min,
+ 'sigma_max': sigma_max,
+ }
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+ sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["Schedule rho"] = opts.rho
+
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
diff --git a/modules/sd_unet.py b/modules/sd_unet.py new file mode 100644 index 00000000..6d708ad2 --- /dev/null +++ b/modules/sd_unet.py @@ -0,0 +1,92 @@ +import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+ new_unets = script_callbacks.list_unets_callback()
+
+ unet_options.clear()
+ unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+ option = option or shared.opts.sd_unet
+
+ if option == "None":
+ return None
+
+ if option == "Automatic":
+ name = shared.sd_model.sd_checkpoint_info.model_name
+
+ options = [x for x in unet_options if x.model_name == name]
+
+ option = options[0].label if options else "None"
+
+ return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+ global current_unet_option
+ global current_unet
+
+ new_option = get_unet_option(option)
+ if new_option == current_unet_option:
+ return
+
+ if current_unet is not None:
+ print(f"Dectivating unet: {current_unet.option.label}")
+ current_unet.deactivate()
+
+ current_unet_option = new_option
+ if current_unet_option is None:
+ current_unet = None
+
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ shared.sd_model.model.diffusion_model.to(devices.device)
+
+ return
+
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
+ devices.torch_gc()
+
+ current_unet = current_unet_option.create_unet()
+ current_unet.option = current_unet_option
+ print(f"Activating unet: {current_unet.option.label}")
+ current_unet.activate()
+
+
+class SdUnetOption:
+ model_name = None
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+ label = None
+ """name of the unet in UI"""
+
+ def create_unet(self):
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+ raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+ def forward(self, x, timesteps, context, *args, **kwargs):
+ raise NotImplementedError()
+
+ def activate(self):
+ pass
+
+ def deactivate(self):
+ pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+
diff --git a/modules/shared.py b/modules/shared.py index b9f28036..acec7f18 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -314,6 +314,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
@@ -403,6 +404,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -419,10 +421,11 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('optimizations', "Optimizations"), {
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
- "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -487,6 +490,7 @@ options_templates.update(options_section(('ui', "User interface"), { "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order").needs_restart(),
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
@@ -515,6 +519,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
diff --git a/modules/shared_items.py b/modules/shared_items.py index 2a8713c8..7f306a06 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -29,3 +29,14 @@ def cross_attention_optimizations(): return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+def sd_unet_items():
+ import modules.sd_unet
+
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
+
+
+def refresh_unet_list():
+ import modules.sd_unet
+
+ modules.sd_unet.list_unets()
+
diff --git a/modules/ui.py b/modules/ui.py index e62182da..6189ceeb 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -505,10 +505,10 @@ def create_ui(): with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
- hr_prompt = gr.Textbox(label="Prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
- hr_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -616,7 +616,8 @@ def create_ui(): outputs=[
txt2img_prompt,
txt_prompt_img
- ]
+ ],
+ show_progress=False,
)
enable_hr.change(
@@ -902,7 +903,8 @@ def create_ui(): outputs=[
img2img_prompt,
img2img_prompt_img
- ]
+ ],
+ show_progress=False,
)
img2img_args = dict(
diff --git a/modules/ui_common.py b/modules/ui_common.py index 27ab3ebb..5a9204a4 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -50,9 +50,10 @@ def save_files(js_data, images, do_make_zip, index): save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
+ only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
+ only_one = True
images = [images[index]]
start_index = index
@@ -70,6 +71,7 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
+ p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
@@ -83,7 +85,10 @@ def save_files(js_data, images, do_make_zip, index): # Make Zip
if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
+ zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
+ namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
+ zip_filepath = os.path.join(path, f"{zip_filename}.zip")
from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
diff --git a/requirements.txt b/requirements.txt index 34e4520d..a464447b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ blendmodes accelerate
basicsr
gfpgan
-gradio==3.31.0
+gradio==3.32.0
numpy
omegaconf
opencv-contrib-python
diff --git a/requirements_versions.txt b/requirements_versions.txt index de501fda..31b179a9 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -3,7 +3,7 @@ transformers==4.25.1 accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
-gradio==3.31.0
+gradio==3.32.0
numpy==1.23.5
Pillow==9.5.0
realesrgan==0.3.0
@@ -19,35 +19,79 @@ function get_uiCurrentTabContent() { } var uiUpdateCallbacks = []; +var uiAfterUpdateCallbacks = []; var uiLoadedCallbacks = []; var uiTabChangeCallbacks = []; var optionsChangedCallbacks = []; +var uiAfterUpdateTimeout = null; var uiCurrentTab = null; +/** + * Register callback to be called at each UI update. + * The callback receives an array of MutationRecords as an argument. + */ function onUiUpdate(callback) { uiUpdateCallbacks.push(callback); } + +/** + * Register callback to be called soon after UI updates. + * The callback receives no arguments. + * + * This is preferred over `onUiUpdate` if you don't need + * access to the MutationRecords, as your function will + * not be called quite as often. + */ +function onAfterUiUpdate(callback) { + uiAfterUpdateCallbacks.push(callback); +} + +/** + * Register callback to be called when the UI is loaded. + * The callback receives no arguments. + */ function onUiLoaded(callback) { uiLoadedCallbacks.push(callback); } + +/** + * Register callback to be called when the UI tab is changed. + * The callback receives no arguments. + */ function onUiTabChange(callback) { uiTabChangeCallbacks.push(callback); } + +/** + * Register callback to be called when the options are changed. + * The callback receives no arguments. + * @param callback + */ function onOptionsChanged(callback) { optionsChangedCallbacks.push(callback); } -function runCallback(x, m) { - try { - x(m); - } catch (e) { - (console.error || console.log).call(console, e.message, e); +function executeCallbacks(queue, arg) { + for (const callback of queue) { + try { + callback(arg); + } catch (e) { + console.error("error running callback", callback, ":", e); + } } } -function executeCallbacks(queue, m) { - queue.forEach(function(x) { - runCallback(x, m); - }); + +/** + * Schedule the execution of the callbacks registered with onAfterUiUpdate. + * The callbacks are executed after a short while, unless another call to this function + * is made before that time. IOW, the callbacks are executed only once, even + * when there are multiple mutations observed. + */ +function scheduleAfterUiUpdateCallbacks() { + clearTimeout(uiAfterUpdateTimeout); + uiAfterUpdateTimeout = setTimeout(function() { + executeCallbacks(uiAfterUpdateCallbacks); + }, 200); } var executedOnLoaded = false; @@ -60,6 +104,7 @@ document.addEventListener("DOMContentLoaded", function() { } executeCallbacks(uiUpdateCallbacks, m); + scheduleAfterUiUpdateCallbacks(); const newTab = get_uiCurrentTab(); if (newTab && (newTab !== uiCurrentTab)) { uiCurrentTab = newTab; diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index da820b39..7821cc65 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts
import gradio as gr
-from modules import images, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, state
import modules.shared as shared
@@ -220,6 +220,10 @@ axis_options = [ AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
+ AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
+ AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
+ AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
diff --git a/webui-user.sh b/webui-user.sh index 49a426ff..70306c60 100644 --- a/webui-user.sh +++ b/webui-user.sh @@ -36,7 +36,6 @@ # Fixed git commits #export STABLE_DIFFUSION_COMMIT_HASH="" -#export TAMING_TRANSFORMERS_COMMIT_HASH="" #export CODEFORMER_COMMIT_HASH="" #export BLIP_COMMIT_HASH="" @@ -20,7 +20,7 @@ import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors # noqa: F401
+from modules import paths, timer, import_hook, errors, devices # noqa: F401
startup_timer = timer.Timer()
@@ -58,6 +58,7 @@ import modules.sd_hijack import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
+import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
@@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False): modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
+ modules.sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
+
def load_model():
"""
Accesses shared.sd_model property to load model.
@@ -306,6 +310,8 @@ def initialize_rest(*, reload_script_modules=False): Thread(target=load_model).start()
+ Thread(target=devices.first_time_calculation).start()
+
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
@@ -381,17 +387,6 @@ def webui(): gradio_auth_creds = list(get_gradio_auth_creds()) or None
- # this restores the missing /docs endpoint
- if launch_api and not hasattr(FastAPI, 'original_setup'):
- # TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
- def fastapi_setup(self):
- self.docs_url = "/docs"
- self.redoc_url = "/redoc"
- self.original_setup()
-
- FastAPI.original_setup = FastAPI.setup
- FastAPI.setup = fastapi_setup
-
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
@@ -404,6 +399,10 @@ def webui(): inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
+ app_kwargs={
+ "docs_url": "/docs",
+ "redoc_url": "/redoc",
+ },
)
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
@@ -124,9 +124,12 @@ case "$gpu_info" in *) ;; esac -if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +if ! echo "$gpu_info" | grep -q "NVIDIA"; then - export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" + if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] + then + export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" + fi fi for preq in "${GIT}" "${python_cmd}" @@ -190,7 +193,7 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)" + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" if [[ ! -z "${TCMALLOC}" ]]; then echo "Using TCMalloc: ${TCMALLOC}" export LD_PRELOAD="${TCMALLOC}" |