aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py3
-rw-r--r--javascript/edit-attention.js5
-rw-r--r--javascript/extensions.js2
-rw-r--r--javascript/hints.js1
-rw-r--r--javascript/imageviewer.js10
-rw-r--r--javascript/localization.js6
-rw-r--r--javascript/progressbar.js39
-rw-r--r--javascript/ui.js24
-rw-r--r--modules/extras.py150
-rw-r--r--modules/hypernetworks/hypernetwork.py7
-rw-r--r--modules/images.py5
-rw-r--r--modules/img2img.py4
-rw-r--r--modules/processing.py11
-rw-r--r--modules/progress.py11
-rw-r--r--modules/realesrgan_model.py12
-rw-r--r--modules/sd_hijack.py8
-rw-r--r--modules/sd_hijack_checkpoint.py38
-rw-r--r--modules/sd_vae.py16
-rw-r--r--modules/shared.py5
-rw-r--r--modules/styles.py12
-rw-r--r--modules/textual_inversion/textual_inversion.py6
-rw-r--r--modules/txt2img.py4
-rw-r--r--modules/ui.py185
-rw-r--r--modules/upscaler.py1
-rw-r--r--requirements.txt2
-rw-r--r--requirements_versions.txt2
-rw-r--r--scripts/prompts_from_file.py2
-rw-r--r--style.css114
-rw-r--r--webui.py3
-rwxr-xr-xwebui.sh2
30 files changed, 473 insertions, 217 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 0ad49f4e..bc11cc6e 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,7 +1,6 @@
import os
import gc
import time
-import warnings
import numpy as np
import torch
@@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack
-warnings.filterwarnings("ignore", category=UserWarning)
-
cached_ldsr_model: torch.nn.Module = None
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index b947cbec..ccc8344a 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -69,7 +69,6 @@ addEventListener('keydown', (event) => {
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
}
- // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its
- // internal Svelte data binding remains in sync.
- target.dispatchEvent(new Event("input", { bubbles: true }));
+
+ updateInput(target)
});
diff --git a/javascript/extensions.js b/javascript/extensions.js
index 59179ca6..ac6e35b9 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -29,7 +29,7 @@ function install_extension_from_index(button, url){
textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
- textarea.dispatchEvent(new Event("input", { bubbles: true }))
+ updateInput(textarea)
gradioApp().querySelector('#install_extension_button').click()
}
diff --git a/javascript/hints.js b/javascript/hints.js
index fa5e5ae8..e746e20d 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -92,6 +92,7 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
+ "No interpolation": "Result = A",
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 1f29ad7b..aac2ee82 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -148,7 +148,15 @@ function showGalleryImage() {
if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer'
e.style.userSelect='none'
- e.addEventListener('mousedown', function (evt) {
+
+ var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+
+ // For Firefox, listening on click first switches to the next image and only then shows the lightbox.
+ // If you know how to fix this without switching to the mousedown event, please do.
+ // For other browsers the event is click, to make it possible to drag the picture.
+ var event = isFirefox ? 'mousedown' : 'click'
+
+ e.addEventListener(event, function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
evt.preventDefault()
diff --git a/javascript/localization.js b/javascript/localization.js
index f92d2d24..1a5a1dbb 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -10,10 +10,8 @@ ignore_ids_for_localization={
modelmerger_tertiary_model_name: 'OPTION',
train_embedding: 'OPTION',
train_hypernetwork: 'OPTION',
- txt2img_style_index: 'OPTION',
- txt2img_style2_index: 'OPTION',
- img2img_style_index: 'OPTION',
- img2img_style2_index: 'OPTION',
+ txt2img_styles: 'OPTION',
+ img2img_styles: 'OPTION',
setting_random_artist_categories: 'SPAN',
setting_face_restoration_model: 'SPAN',
setting_realesrgan_enabled_models: 'SPAN',
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index da6709bc..18c771a2 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -106,6 +106,19 @@ function formatTime(secs){
}
}
+function setTitle(progress){
+ var title = 'Stable Diffusion'
+
+ if(opts.show_progress_in_title && progress){
+ title = '[' + progress.trim() + '] ' + title;
+ }
+
+ if(document.title != title){
+ document.title = title;
+ }
+}
+
+
function randomId(){
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
}
@@ -117,7 +130,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
var dateStart = new Date()
var wasEverActive = false
var parentProgressbar = progressbarContainer.parentNode
- var parentGallery = gallery.parentNode
+ var parentGallery = gallery ? gallery.parentNode : null
var divProgress = document.createElement('div')
divProgress.className='progressDiv'
@@ -128,13 +141,16 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.appendChild(divInner)
parentProgressbar.insertBefore(divProgress, progressbarContainer)
- var livePreview = document.createElement('div')
- livePreview.className='livePreview'
- parentGallery.insertBefore(livePreview, gallery)
+ if(parentGallery){
+ var livePreview = document.createElement('div')
+ livePreview.className='livePreview'
+ parentGallery.insertBefore(livePreview, gallery)
+ }
var removeProgressBar = function(){
+ setTitle("")
parentProgressbar.removeChild(divProgress)
- parentGallery.removeChild(livePreview)
+ if(parentGallery) parentGallery.removeChild(livePreview)
atEnd()
}
@@ -154,6 +170,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
progressText = ""
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
+ divInner.style.background = res.progress ? "" : "transparent"
if(res.progress > 0){
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
@@ -161,8 +178,13 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
if(res.eta){
progressText += " ETA: " + formatTime(res.eta)
- } else if(res.textinfo){
- progressText += " " + res.textinfo
+ }
+
+
+ setTitle(progressText)
+
+ if(res.textinfo && res.textinfo.indexOf("\n") == -1){
+ progressText = res.textinfo + " " + progressText
}
divInner.textContent = progressText
@@ -182,8 +204,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
}
- if(res.live_preview){
-
+ if(res.live_preview && gallery){
var rect = gallery.getBoundingClientRect()
if(rect.width){
livePreview.style.width = rect.width + "px"
diff --git a/javascript/ui.js b/javascript/ui.js
index ecf97cb3..37788a3e 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -109,6 +109,13 @@ function get_extras_tab_index(){
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
}
+function get_img2img_tab_index() {
+ let res = args_to_array(arguments)
+ res.splice(-2)
+ res[0] = get_tab_index('mode_img2img')
+ return res
+}
+
function create_submit_args(args){
res = []
for(var i=0;i<args.length;i++){
@@ -165,6 +172,15 @@ function submit_img2img(){
return res
}
+function modelmerger(){
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
+
+ var res = create_submit_args(arguments)
+ res[0] = id
+ return res
+}
+
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:')
@@ -278,3 +294,11 @@ function restart_reload(){
return []
}
+
+// Simulate an `input` DOM event for a Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
+// will only be visible on the web page and will not be sent to python.
+function updateInput(target){
+ let e = new Event("input", { bubbles: true })
+ Object.defineProperty(e, "target", {value: target})
+ target.dispatchEvent(e);
+}
diff --git a/modules/extras.py b/modules/extras.py
index 22668fcd..d03f976e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
-from modules import processing, shared, images, devices, sd_models, sd_samplers
+from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -251,7 +251,8 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- return sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models.find_checkpoint_config(x) if x else None
+ return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@@ -274,10 +275,25 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+
+ return tensor
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [*[gr.update() for _ in range(4)], message]
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -287,49 +303,89 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- primary_model_info = sd_models.checkpoints_list[primary_model_name]
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
- result_is_inpainting_model = False
+ def filename_weighed_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_differnece():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (None, weighted_sum),
- "Add difference": (get_difference, add_difference),
+ "Weighted sum": (filename_weighed_sum, None, weighted_sum),
+ "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
}
- theta_func1, theta_func2 = theta_funcs[interp_method]
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
- if theta_func1 and not tertiary_model_info:
- shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
- shared.state.end()
- return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ if not primary_model_name:
+ return fail("Failed: Merging requires a primary model.")
+
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+
+ if theta_func2 and not secondary_model_name:
+ return fail("Failed: Merging requires a secondary model.")
+
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
+
+ if theta_func1 and not tertiary_model_name:
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
- shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
+
+ result_is_inpainting_model = False
+
+ if theta_func2:
+ shared.state.textinfo = f"Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
if theta_func1:
+ shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.textinfo = 'Merging B and C'
+ shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
+
+ shared.state.sampling_step += 1
del theta_2
+ shared.state.nextjob()
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
-
- chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
+ shared.state.textinfo = 'Merging A and B'
+ shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
- if 'model' in key and key in theta_1:
+ if theta_1 and 'model' in key and key in theta_1:
if key in chckpoint_dict_skip_on_merge:
continue
@@ -337,7 +393,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
a = theta_0[key]
b = theta_1[key]
- shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -352,36 +407,39 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
else:
theta_0[key] = theta_func2(a, b, multiplier)
- if save_as_half:
- theta_0[key] = theta_0[key].half()
+ theta_0[key] = to_half(theta_0[key], save_as_half)
- # I believe this part should be discarded, but I'll leave it for now until I am sure
- for key in theta_1.keys():
- if 'model' in key and key not in theta_0:
+ shared.state.sampling_step += 1
- if key in chckpoint_dict_skip_on_merge:
- continue
-
- theta_0[key] = theta_1[key]
- if save_as_half:
- theta_0[key] = theta_0[key].half()
del theta_1
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
- filename = \
- primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
- secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
- interp_method.replace(" ", "_") + \
- '-merged.' + \
- ("inpainting." if result_is_inpainting_model else "") + \
- checkpoint_format
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
+
+ del vae_dict
+
+ if save_as_half and not theta_func2:
+ for key in theta_0.keys():
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
- shared.state.textinfo = f"Saving to {output_modelname}..."
+ shared.state.nextjob()
+ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -394,8 +452,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
- print("Checkpoint saved.")
- shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
shared.state.end()
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c963fc40..74e78582 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -12,7 +12,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
hypernetwork.eval()
#report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
diff --git a/modules/images.py b/modules/images.py
index c3a5fc8b..3b1c5f34 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -605,8 +605,9 @@ def read_info_from_image(image):
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
- items['exif comment'] = exif_comment
- geninfo = exif_comment
+ if exif_comment:
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
diff --git a/modules/img2img.py b/modules/img2img.py
index f4a03c57..2168c8e2 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
@@ -101,7 +101,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
diff --git a/modules/processing.py b/modules/processing.py
index 9c3673de..a3e9f709 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -538,10 +538,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.process(p)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
infotexts = []
output_images = []
@@ -572,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
if state.job_count == -1:
state.job_count = p.n_iter
@@ -857,7 +857,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch to DDIM
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
diff --git a/modules/progress.py b/modules/progress.py
index 3327b883..c69ecf3d 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -67,10 +67,13 @@ def progressapi(req: ProgressRequest):
progress = 0
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+ job_count, job_no = shared.state.job_count, shared.state.job_no
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
+
+ if job_count > 0:
+ progress += job_no / job_count
+ if sampling_steps > 0 and job_count > 0:
+ progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 3ac0b97a..47f70251 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler):
return img
info = self.load_model(path)
- if not os.path.exists(info.data_path):
+ if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
- model_path=info.data_path,
+ model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
@@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler):
def load_model(self, path):
try:
- info = None
- for scaler in self.scalers:
- if scaler.data_path == path:
- info = scaler
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
if info is None:
print(f"Unable to find model info: {path}")
return None
- model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- info.data_path = model_file
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b0d95af..870eba88 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -69,12 +69,6 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-def fix_checkpoint():
- ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
-
-
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -106,8 +100,6 @@ class StableDiffusionModelHijack:
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
-
- fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py
index 5712972f..2604d969 100644
--- a/modules/sd_hijack_checkpoint.py
+++ b/modules/sd_hijack_checkpoint.py
@@ -1,10 +1,46 @@
from torch.utils.checkpoint import checkpoint
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.openaimodel
+
+
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
+
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
+
def ResBlock_forward(self, x, emb):
- return checkpoint(self._forward, x, emb) \ No newline at end of file
+ return checkpoint(self._forward, x, emb)
+
+
+stored = []
+
+
+def add():
+ if len(stored) != 0:
+ return
+
+ stored.extend([
+ ldm.modules.attention.BasicTransformerBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
+ ])
+
+ ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
+
+
+def remove():
+ if len(stored) == 0:
+ return
+
+ ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
+ ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
+ ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
+
+ stored.clear()
+
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index b2af2ce7..4ce238b8 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -72,6 +72,13 @@ def refresh_vae_list():
os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
]
+ if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):
+ paths += [
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),
+ os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),
+ ]
+
candidates = []
for path in paths:
candidates += glob.iglob(path, recursive=True)
@@ -113,6 +120,12 @@ def resolve_vae(checkpoint_file):
return None, None
+def load_vae_dict(filename, map_location):
+ vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ return vae_dict_1
+
+
def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file
# save_settings = False
@@ -130,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)
- vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
diff --git a/modules/shared.py b/modules/shared.py
index a708f23c..29b28bff 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -20,12 +20,14 @@ from modules.paths import models_path, script_path, sd_path
demo = None
+sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -369,6 +371,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
}))
options_templates.update(options_section(('system', "System"), {
+ "show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
diff --git a/modules/styles.py b/modules/styles.py
index ce6e71ca..990d5623 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles):
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
- self.styles = {"None": self.no_style}
+ self.styles = {}
+ self.path = path
- if not os.path.exists(path):
+ self.reload()
+
+ def reload(self):
+ self.styles.clear()
+
+ if not os.path.exists(self.path):
return
- with open(path, "r", encoding="utf-8-sig", newline='') as file:
+ with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7e4a6d24..5a7be422 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -15,7 +15,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
diff --git a/modules/txt2img.py b/modules/txt2img.py
index ca5d4550..e945fd69 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,13 +8,13 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
diff --git a/modules/ui.py b/modules/ui.py
index bbce9acd..af416d5f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -11,6 +11,7 @@ import tempfile
import time
import traceback
from functools import partial, reduce
+import warnings
import gradio as gr
import gradio.routes
@@ -19,7 +20,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
@@ -180,7 +183,7 @@ def add_style(name: str, prompt: str, negative_prompt: str):
# reserialize all styles every time we save them
shared.prompt_styles.save_styles(shared.styles_filename)
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
@@ -197,16 +200,38 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+def apply_styles(prompt, prompt_neg, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
+
+
+def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
+ if mode in {0, 1, 3, 4}:
+ return [interrogation_function(ii_singles[mode]), None]
+ elif mode == 2:
+ return [interrogation_function(ii_singles[mode]["image"]), None]
+ elif mode == 5:
+ assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+ images = shared.listfiles(ii_input_dir)
+ print(f"Will process {len(images)} images.")
+ if ii_output_dir != "":
+ os.makedirs(ii_output_dir, exist_ok=True)
+ else:
+ ii_output_dir = ii_input_dir
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+ for image in images:
+ img = Image.open(image)
+ filename = os.path.basename(image)
+ left, _ = os.path.splitext(filename)
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
+
+ return [gr_show(True), None]
def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
return gr_show(True) if prompt is None else prompt
@@ -374,13 +399,10 @@ def create_toprow(is_img2img):
)
with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
+ create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
-
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(*args, **kwargs):
@@ -420,17 +442,16 @@ def apply_setting(key, value):
return value
-def update_generation_info(args):
- generation_info, html_info, img_index = args
+def update_generation_info(generation_info, html_info, img_index):
try:
generation_info = json.loads(generation_info)
if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
+ return html_info, gr.update()
+ return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
except Exception:
pass
# if the json parse or anything else fails, just return the old html_info
- return html_info
+ return html_info, gr.update()
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
@@ -511,10 +532,9 @@ Requested path was: {f}
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
+ _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
)
save.click(
@@ -529,7 +549,8 @@ Requested path was: {f}
outputs=[
download_files,
html_log,
- ]
+ ],
+ show_progress=False,
)
save_zip.click(
@@ -588,13 +609,13 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -617,7 +638,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
+ with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@@ -625,12 +646,12 @@ def create_ui():
elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
- with FormRow(elem_id="txt2img_hires_fix_row2"):
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
@@ -674,8 +695,7 @@ def create_ui():
dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
+ txt2img_prompt_styles,
steps,
sampler_index,
restore_faces,
@@ -770,12 +790,12 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
+ with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
copy_image_destinations = {}
@@ -943,8 +963,7 @@ def create_ui():
dummy_component,
img2img_prompt,
img2img_negative_prompt,
- img2img_prompt_style,
- img2img_prompt_style2,
+ img2img_prompt_styles,
init_img,
sketch,
init_img_with_mask,
@@ -983,23 +1002,37 @@ def create_ui():
show_progress=False,
)
+ interrogate_args = dict(
+ _js="get_img2img_tab_index",
+ inputs=[
+ dummy_component,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ init_img_inpaint,
+ ],
+ outputs=[img2img_prompt, dummy_component],
+ show_progress=False,
+ )
+
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
img2img_interrogate.click(
- fn=interrogate,
- inputs=[init_img],
- outputs=[img2img_prompt],
+ fn=lambda *args : process_interrogate(interrogate, *args),
+ **interrogate_args,
)
img2img_deepbooru.click(
- fn=interrogate_deepbooru,
- inputs=[init_img],
- outputs=[img2img_prompt],
+ fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
+ **interrogate_args,
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+ style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
@@ -1009,15 +1042,15 @@ def create_ui():
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
+ outputs=[txt2img_prompt_styles, img2img_prompt_styles],
)
- for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
+ for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click(
fn=apply_styles,
_js=js_func,
- inputs=[prompt, negative_prompt, style1, style2],
- outputs=[prompt, negative_prompt, style1, style2],
+ inputs=[prompt, negative_prompt, styles],
+ outputs=[prompt, negative_prompt, styles],
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
@@ -1048,7 +1081,7 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
+ with gr.Column(variant='compact'):
with gr.Tabs(elem_id="mode_extras"):
with gr.TabItem('Single Image', elem_id="extras_single_tab"):
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
@@ -1149,10 +1182,10 @@ def create_ui():
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
+ with gr.Column(variant='compact'):
+ gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
- with FormRow():
+ with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1164,18 +1197,27 @@ def create_ui():
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+ with FormRow():
+ with gr.Column():
+ config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+ with gr.Column():
+ with FormRow():
+ bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+ create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
- with gr.Column(variant='panel'):
- submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ with gr.Row():
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+ with gr.Group(elem_id="modelmerger_results_panel"):
+ modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
@@ -1576,6 +1618,7 @@ def create_ui():
previous_section = None
current_tab = None
+ current_row = None
with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
@@ -1584,10 +1627,14 @@ def create_ui():
elem_id, text = item.section
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
+ gr.Group()
current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
previous_section = item.section
@@ -1602,6 +1649,7 @@ def create_ui():
components.append(component)
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
with gr.TabItem("Actions"):
@@ -1683,7 +1731,7 @@ def create_ui():
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings"):
+ with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
@@ -1739,12 +1787,15 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
+ modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click(
- fn=modelmerger,
+ fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
inputs=[
+ dummy_component,
primary_model_name,
secondary_model_name,
tertiary_model_name,
@@ -1754,13 +1805,14 @@ def create_ui():
custom_name,
checkpoint_format,
config_source,
+ bake_in_vae,
],
outputs=[
- submit_result,
primary_model_name,
secondary_model_name,
tertiary_model_name,
component_dict['sd_model_checkpoint'],
+ modelmerger_result,
]
)
@@ -1792,7 +1844,10 @@ def create_ui():
if saved_value is None:
ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
if init_field is not None:
@@ -1820,7 +1875,13 @@ def create_ui():
apply_field(x, 'value')
if type(x) == gr.Dropdown:
- apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
+ def check_dropdown(val):
+ if x.multiselect:
+ return all([value in x.choices for value in val])
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
diff --git a/modules/upscaler.py b/modules/upscaler.py
index 231680cb..a5bf5acb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -95,6 +95,7 @@ class UpscalerData:
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
+ self.local_data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
diff --git a/requirements.txt b/requirements.txt
index e1dbf8e5..ef5e3472 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ fairscale==0.4.4
fonts
font-roboto
gfpgan
-gradio==3.15.0
+gradio==3.16.2
invisible-watermark
numpy
omegaconf
diff --git a/requirements_versions.txt b/requirements_versions.txt
index d2899292..f97ad765 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -3,7 +3,7 @@ transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
gfpgan==1.3.8
-gradio==3.15.0
+gradio==3.16.2
numpy==1.23.3
Pillow==9.4.0
realesrgan==0.3.0
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index f3e711d7..76dc5778 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -116,7 +116,7 @@ class Script(scripts.Script):
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
- file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file"))
+ file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
diff --git a/style.css b/style.css
index 97f9402a..c10e32a1 100644
--- a/style.css
+++ b/style.css
@@ -20,7 +20,7 @@
padding-right: 0.25em;
margin: 0.1em 0;
opacity: 0%;
- cursor: default;
+ cursor: default;
}
.output-html p {margin: 0 0.5em;}
@@ -114,6 +114,7 @@
min-width: unset !important;
flex-grow: 0 !important;
padding: 0.4em 0;
+ gap: 0;
}
#roll_col > button {
@@ -141,10 +142,15 @@
min-width: 8em !important;
}
-#txt2img_style_index, #txt2img_style2_index, #img2img_style_index, #img2img_style2_index{
+#txt2img_styles, #img2img_styles{
margin-top: 1em;
}
+#txt2img_styles ul, #img2img_styles ul{
+ max-height: 35em;
+ z-index: 2000;
+}
+
.gr-form{
background: transparent;
}
@@ -154,10 +160,14 @@
margin-bottom: 0;
}
-#toprow div{
+#toprow div.gr-box, #toprow div.gr-form{
border: none;
gap: 0;
background: transparent;
+ box-shadow: none;
+}
+#toprow div{
+ gap: 0;
}
#resize_mode{
@@ -221,7 +231,10 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
background-color: rgb(31, 41, 55);
- box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55);
+ box-shadow: none;
+ border: 1px solid rgba(128, 128, 128, 0.1);
+ border-radius: 6px;
+ padding: 0.1em 0.5em;
}
#txt2img_column_batch, #img2img_column_batch{
@@ -286,30 +299,15 @@ input[type="range"]{
}
/* more gradio's garbage cleanup */
-.min-h-\[4rem\] {
- min-height: unset !important;
-}
-
-#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
- position: absolute;
- z-index: 1000;
- right: 0;
- padding-left: 5px;
- padding-right: 5px;
- display: block;
-}
-
-#txt2img_progress_row, #img2img_progress_row{
- margin-bottom: 10px;
- margin-top: -18px;
-}
+.min-h-\[4rem\] { min-height: unset !important; }
+.min-h-\[6rem\] { min-height: unset !important; }
.progressDiv{
position: absolute;
height: 20px;
top: -20px;
background: #b4c0cc;
- border-radius: 8px !important;
+ border-radius: 3px !important;
}
.dark .progressDiv{
@@ -325,9 +323,10 @@ input[type="range"]{
line-height: 20px;
padding: 0 8px 0 0;
text-align: right;
- border-radius: 8px;
+ border-radius: 3px;
overflow: visible;
white-space: nowrap;
+ padding: 0 0.5em;
}
.livePreview{
@@ -392,7 +391,7 @@ input[type="range"]{
grid-area: tile;
}
-.modalClose,
+.modalClose,
.modalZoom,
.modalTileImage {
color: white;
@@ -531,30 +530,21 @@ input[type="range"]{
gap: 0.4em;
}
-#quicksettings > div{
- border: none;
- background: none;
- flex: unset;
- gap: 1em;
-}
-
-#quicksettings > div > div{
- max-width: 32em;
+#quicksettings > div, #quicksettings > fieldset{
+ max-width: 24em;
min-width: 24em;
padding: 0;
+ border: none;
+ box-shadow: none;
+ background: none;
}
-#quicksettings > div > div > div > div > label > span {
+#quicksettings > div > div > div > label > span {
position: relative;
margin-right: 9em;
margin-bottom: -1em;
}
-#quicksettings > div > div > label > span {
- position: relative;
- margin-bottom: -1em;
-}
-
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
@@ -644,9 +634,23 @@ canvas[key="mask"] {
max-width: 2.5em;
min-width: 2.5em !important;
height: 2.4em;
- margin: 0.55em 0.7em 0.55em 0;
+ margin: 1.6em 0.7em 0.55em 0;
}
+#tab_modelmerger .gr-button-tool{
+ margin: 0.6em 0em 0.55em 0;
+}
+
+#modelmerger_results_container{
+ margin-top: 1em;
+ overflow: visible;
+}
+
+#modelmerger_models{
+ gap: 0;
+}
+
+
#quicksettings .gr-button-tool{
margin: 0;
}
@@ -689,7 +693,10 @@ footer {
font-weight: bold;
}
-#txt2img_checkboxes > div > div{
+#txt2img_checkboxes, #img2img_checkboxes{
+ margin-bottom: 0.5em;
+}
+#txt2img_checkboxes > div, #img2img_checkboxes > div{
flex: 0;
white-space: nowrap;
min-width: auto;
@@ -699,6 +706,29 @@ footer {
opacity: 0.5;
}
+.gr-compact {
+ border: none;
+}
+
+.dark .gr-compact{
+ background-color: rgb(31 41 55 / var(--tw-bg-opacity));
+ margin-left: 0.8em;
+}
+
+.gr-compact > *{
+ margin-top: 0.5em !important;
+}
+
+.gr-compact .gr-block, .gr-compact .gr-form{
+ border: none;
+ box-shadow: none;
+}
+
+.gr-compact .gr-box{
+ border-radius: .5rem !important;
+ border-width: 1px !important;
+}
+
#mode_img2img > div > div{
gap: 0 !important;
}
@@ -794,4 +824,4 @@ Then, you will need to add the RTL counterpart only if needed in the rtl section
right: unset;
left: 0.5em;
}
-} \ No newline at end of file
+}
diff --git a/webui.py b/webui.py
index 4624fe18..865a7300 100644
--- a/webui.py
+++ b/webui.py
@@ -158,7 +158,7 @@ def webui():
shared.demo = modules.ui.create_ui()
- app, local_url, share_url = shared.demo.queue(default_enabled=False).launch(
+ app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
server_port=cmd_opts.port,
@@ -188,7 +188,6 @@ def webui():
create_api(app)
modules.script_callbacks.app_started_callback(shared.demo, app)
- modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
print('Restarting UI...')
diff --git a/webui.sh b/webui.sh
index 6e07778f..1edf921d 100755
--- a/webui.sh
+++ b/webui.sh
@@ -165,7 +165,7 @@ else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
- gpu_info=$(lspci | grep VGA)
+ gpu_info=$(lspci 2>/dev/null | grep VGA)
if echo "$gpu_info" | grep -q "AMD"
then
if [[ -z "${TORCH_COMMAND}" ]]