diff options
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | html/licenses.html | 26 | ||||
-rw-r--r-- | javascript/extraNetworks.js | 31 | ||||
-rw-r--r-- | modules/config_states.py | 2 | ||||
-rw-r--r-- | modules/extensions.py | 19 | ||||
-rw-r--r-- | modules/extras.py | 12 | ||||
-rw-r--r-- | modules/processing.py | 20 | ||||
-rw-r--r-- | modules/sd_models.py | 1 | ||||
-rw-r--r-- | modules/sd_samplers.py | 8 | ||||
-rw-r--r-- | modules/sd_samplers_common.py | 35 | ||||
-rw-r--r-- | modules/sd_samplers_compvis.py | 8 | ||||
-rw-r--r-- | modules/sd_samplers_kdiffusion.py | 22 | ||||
-rw-r--r-- | modules/sd_vae_taesd.py | 88 | ||||
-rw-r--r-- | modules/shared.py | 4 | ||||
-rw-r--r-- | modules/ui.py | 10 | ||||
-rw-r--r-- | modules/ui_extensions.py | 26 | ||||
-rw-r--r-- | modules/ui_extra_networks.py | 19 | ||||
-rw-r--r-- | style.css | 2 | ||||
-rw-r--r-- | webui.py | 38 |
19 files changed, 281 insertions, 91 deletions
@@ -158,5 +158,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
+- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/html/licenses.html b/html/licenses.html index bc995aa0..ef6f2c0a 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
+<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
</pre>
\ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c85bc79a..4d9a522c 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -68,18 +68,27 @@ var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g; function tryToRemoveExtraNetworkFromPrompt(textarea, text){
var m = text.match(re_extranet)
- if(! m) return false
-
- var partToSearch = m[1]
var replaced = false
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
- m = found.match(re_extranet);
- if(m[1] == partToSearch){
- replaced = true;
- return ""
- }
- return found;
- })
+ var newTextareaText
+ if(m) {
+ var partToSearch = m[1]
+ newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
+ m = found.match(re_extranet);
+ if(m[1] == partToSearch){
+ replaced = true;
+ return ""
+ }
+ return found;
+ })
+ } else {
+ newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found){
+ if(found == text) {
+ replaced = true;
+ return ""
+ }
+ return found;
+ })
+ }
if(replaced){
textarea.value = newTextareaText
diff --git a/modules/config_states.py b/modules/config_states.py index 75da862a..db65bcdb 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -83,6 +83,8 @@ def get_extension_config(): ext_config = {} for ext in extensions.extensions: + ext.read_info_from_repo() + entry = { "name": ext.name, "path": ext.path, diff --git a/modules/extensions.py b/modules/extensions.py index bc2c0450..359a7aa5 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,8 +1,8 @@ import os
import sys
+import threading
import traceback
-import time
import git
from modules import shared
@@ -24,6 +24,8 @@ def active(): class Extension:
+ lock = threading.Lock()
+
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
self.path = path
@@ -42,8 +44,13 @@ class Extension: if self.is_builtin or self.have_info_from_repo:
return
- self.have_info_from_repo = True
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+ self.do_read_info_from_repo()
+
+ def do_read_info_from_repo(self):
repo = None
try:
if os.path.exists(os.path.join(self.path, ".git")):
@@ -58,18 +65,18 @@ class Extension: try:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
- head = repo.head.commit
self.commit_date = repo.head.commit.committed_date
- ts = time.asctime(time.gmtime(self.commit_date))
if repo.active_branch:
self.branch = repo.active_branch.name
- self.commit_hash = head.hexsha
- self.version = f'{self.commit_hash[:8]} ({ts})'
+ self.commit_hash = repo.head.commit.hexsha
+ self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
except Exception as ex:
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
self.remote = None
+ self.have_info_from_repo = True
+
def list_files(self, subdir, extension):
from modules import scripts
diff --git a/modules/extras.py b/modules/extras.py index bdf9b3b7..830b53aa 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -242,9 +242,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+ metadata = None
if save_metadata:
+ metadata = {"format": "pt"}
+
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -262,15 +264,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ }
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+ sd_merge_models = {}
+
def add_model_metadata(checkpoint_info):
checkpoint_info.calculate_shorthash()
- metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ sd_merge_models[checkpoint_info.sha256] = {
"name": checkpoint_info.name,
"legacy_hash": checkpoint_info.hash,
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
}
- metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+ sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
add_model_metadata(primary_model_info)
if secondary_model_info:
@@ -278,7 +282,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if tertiary_model_info:
add_model_metadata(tertiary_model_info)
- metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+ metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
diff --git a/modules/processing.py b/modules/processing.py index 94fe2625..cd63b9a6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -316,6 +316,7 @@ class Processed: self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -480,6 +481,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
+ uses_ensd = opts.eta_noise_seed_delta != 0
+ if uses_ensd:
+ uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
+
generation_params = {
"Steps": p.steps,
"Sampler": p.sampler_name,
@@ -496,17 +501,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
- "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
"Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
}
- generation_params.update(p.extra_generation_params)
-
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
@@ -678,12 +682,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- step_multiplier = 1
- if not shared.opts.dont_fix_second_order_samplers_schedule:
- try:
- step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
- except Exception:
- pass
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+ step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
diff --git a/modules/sd_models.py b/modules/sd_models.py index dddbc6e1..e612be10 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -540,7 +540,6 @@ def reload_model_weights(sd_model=None, info=None): if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 4f1bf21d..f22aad8f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -14,12 +14,18 @@ samplers_for_img2img = [] samplers_map = {}
-def create_sampler(name, model):
+def find_sampler_config(name):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
+ return config
+
+
+def create_sampler(name, model):
+ config = find_sampler_config(name)
+
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index bc074238..763829f1 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def single_sample_to_image(sample, approximation=None):
@@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None): approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
+ x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+ elif approximation == 3:
+ x_sample = sample * 1.5
+ x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
+
return Image.fromarray(x_sample)
@@ -58,6 +62,25 @@ def store_latent(decoded): shared.state.assign_current_image(sample_to_image(decoded))
+def is_sampler_using_eta_noise_seed_delta(p):
+ """returns whether sampler from config will use eta noise seed delta for image creation"""
+
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+
+ eta = p.eta
+
+ if eta is None and p.sampler is not None:
+ eta = p.sampler.eta
+
+ if eta is None and sampler_config is not None:
+ eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
+
+ if eta == 0:
+ return False
+
+ return sampler_config.options.get("uses_ensd", False)
+
+
class InterruptedException(BaseException):
pass
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index b1ee3be7..bdae8b40 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]
@@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler: self.update_step(x)
def initialize(self, p):
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ if self.is_ddim:
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ else:
+ self.eta = 0.0
+
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 61f23ad7..552c6c64 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -11,23 +11,23 @@ from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
]
samplers_data_k_diffusion = [
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py new file mode 100644 index 00000000..5e8496e8 --- /dev/null +++ b/modules/sd_vae_taesd.py @@ -0,0 +1,88 @@ +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) + +https://github.com/madebyollin/taesd +""" +import os +import torch +import torch.nn as nn + +from modules import devices, paths_internal + +sd_vae_taesd = None + + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + + +class Clamp(nn.Module): + @staticmethod + def forward(x): + return torch.tanh(x / 3) * 3 + + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + + +def decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + + +class TAESD(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.decoder = decoder() + self.decoder.load_state_dict( + torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + +def download_model(model_path): + model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' + + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading TAESD decoder to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + +def model(): + global sd_vae_taesd + + if sd_vae_taesd is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") + download_model(model_path) + + if os.path.exists(model_path): + sd_vae_taesd = TAESD(model_path) + sd_vae_taesd.eval() + sd_vae_taesd.to(devices.device, devices.dtype) + else: + raise FileNotFoundError('TAESD model not found') + + return sd_vae_taesd.decoder diff --git a/modules/shared.py b/modules/shared.py index 07f18b1b..165509ea 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -448,7 +448,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
- "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}).info("Full = slow but pretty; Approx NN = fast but low quality; Approx cheap = super fast but terrible otherwise"),
+ "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))
@@ -458,8 +458,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
+ 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
diff --git a/modules/ui.py b/modules/ui.py index ff25c4ce..8e51e782 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1841,15 +1841,15 @@ def versions_html(): return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
- •
+ • 
python: <span title="{sys.version}">{python_version}</span>
- •
+ • 
torch: {getattr(torch, '__long_version__',torch.__version__)}
- •
+ • 
xformers: {xformers_version}
- •
+ • 
gradio: {gr.__version__}
- •
+ • 
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index af497733..d7a0f685 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,6 +1,7 @@ import json
import os.path
import sys
+import threading
import time
from datetime import datetime
import traceback
@@ -140,7 +141,9 @@ def extension_table(): <tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
- <th><abbr title="Extension version">Version</abbr></th>
+ <th>Branch</th>
+ <th>Version</th>
+ <th>Date</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
@@ -148,6 +151,7 @@ def extension_table(): """
for ext in extensions.extensions:
+ ext: extensions.Extension
ext.read_info_from_repo()
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
@@ -169,7 +173,9 @@ def extension_table(): <tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
+ <td>{ext.branch}</td>
<td>{version_link}</td>
+ <td>{time.asctime(time.gmtime(ext.commit_date))}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -484,11 +490,18 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" return code, list(tags)
+def preload_extensions_git_metadata():
+ for extension in extensions.extensions:
+ extension.read_info_from_repo()
+
+
def create_ui():
import modules.ui
config_states.list_config_states()
+ threading.Thread(target=preload_extensions_git_metadata).start()
+
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions"):
with gr.TabItem("Installed", id="installed"):
@@ -508,7 +521,8 @@ def create_ui(): </span>
"""
info = gr.HTML(html)
- extensions_table = gr.HTML(lambda: extension_table())
+ extensions_table = gr.HTML('Loading...')
+ ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
apply.click(
fn=apply_and_restart,
@@ -579,9 +593,9 @@ def create_ui(): install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
- fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
inputs=[install_dirname, install_url, install_branch],
- outputs=[extensions_table, install_result],
+ outputs=[install_url, extensions_table, install_result],
)
with gr.TabItem("Backup/Restore"):
@@ -595,7 +609,8 @@ def create_ui(): config_save_button = gr.Button(value="Save Current Config")
config_states_info = gr.HTML("")
- config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+ config_states_table = gr.HTML("Loading...")
+ ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
@@ -608,4 +623,5 @@ def create_ui(): outputs=[config_states_table],
)
+
return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 0baccf56..752cf2b8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -268,7 +268,7 @@ def create_ui(container, button, tabname): with gr.Tab(page.title, id=page_id):
elem_id = f"{tabname}_{page_id}_cards_html"
- page_elem = gr.HTML('', elem_id=elem_id)
+ page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
@@ -282,13 +282,24 @@ def create_ui(container, button, tabname): def toggle_visibility(is_visible):
is_visible = not is_visible
- if is_visible and not ui.pages_contents:
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ def fill_tabs(is_empty):
+ """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
+
+ if not ui.pages_contents:
refresh()
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
+ if is_empty:
+ return True, *ui.pages_contents
+
+ return True, *[gr.update() for _ in ui.pages_contents]
state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
+
+ state_empty = gr.State(value=True)
+ button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
def refresh():
for pg in ui.stored_extra_pages:
@@ -363,12 +363,10 @@ div#extras_scale_to_tab div.form{ /* settings */
#quicksettings {
- width: fit-content;
align-items: end;
}
#quicksettings > div, #quicksettings > fieldset{
- max-width: 24em;
min-width: 24em;
padding: 0;
border: none;
@@ -144,16 +144,11 @@ Use --skip-version-check commandline argument to disable this check. """.strip())
-def initialize():
- fix_asyncio_event_loop_policy()
-
- check_versions()
-
- extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
- startup_timer.record("list extensions")
-
+def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
@@ -166,6 +161,18 @@ def initialize(): elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
+
+def initialize():
+ fix_asyncio_event_loop_policy()
+
+ check_versions()
+
+ extensions.list_extensions()
+ localization.list_localizations(cmd_opts.localizations_dir)
+ startup_timer.record("list extensions")
+
+ restore_config_state_file()
+
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -370,18 +377,7 @@ def webui(): extensions.list_extensions()
startup_timer.record("list extensions")
- config_state_file = shared.opts.restore_config_state_file
- shared.opts.restore_config_state_file = ""
- shared.opts.save(shared.config_filename)
-
- if os.path.isfile(config_state_file):
- print(f"*** About to restore extension state from file: {config_state_file}")
- with open(config_state_file, "r", encoding="utf-8") as f:
- config_state = json.load(f)
- config_states.restore_extension_config(config_state)
- startup_timer.record("restore extension config")
- elif config_state_file:
- print(f"!!! Config state backup not found: {config_state_file}")
+ restore_config_state_file()
localization.list_localizations(cmd_opts.localizations_dir)
|