From bbce167305091b34795284cabca7ab2fd56469b6 Mon Sep 17 00:00:00 2001
From: lenankamp <31517075+lenankamp@users.noreply.github.com>
Date: Tue, 16 May 2023 14:37:45 -0400
Subject: Recursive batch img2img.py

Searches subdirectories and performs img2img batch processing; also limits
inputs to jpg, webp, and png. Then saves to the output directory with
relative paths.
---
 modules/img2img.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

(limited to 'modules')

diff --git a/modules/img2img.py b/modules/img2img.py
index 9fc3a698..ad5f2e73 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -20,7 +20,13 @@ import modules.scripts
 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
     processing.fix_seed(p)
 
-    images = shared.listfiles(input_dir)
+# recursive batch, as written limits potential inputs to common image formats, may be better to just check if isfile for general use
+images = []
+    for root, directories, files in os.walk(input_dir):
+        for filename in files:
+            filepath = os.path.join(root, filename)
+            if filepath.endswith(".jpg") or filepath.endswith(".jpeg") or filepath.endswith(".png") or filepath.endswith(".webp"):
+                images.append(filepath)
 
     is_inpaint_batch = False
     if inpaint_mask_dir:
--
cgit v1.2.1

From ff6acd35d0807a4e0c3ee86cdb1520a4a3a11cdd Mon Sep 17 00:00:00 2001
From: lenankamp <31517075+lenankamp@users.noreply.github.com>
Date: Fri, 19 May 2023 03:20:19 -0400
Subject: Update img2img.py

Hopefully corrected the whitespace issue
---
 modules/img2img.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/img2img.py b/modules/img2img.py
index ad5f2e73..d1872bed 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -20,14 +20,14 @@ import modules.scripts
 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
     processing.fix_seed(p)
 
-# recursive batch, as written limits potential inputs to common image formats, may be better to just check if isfile for general use
-images = []
+    images = []
     for root, directories, files in os.walk(input_dir):
         for filename in files:
             filepath = os.path.join(root, filename)
             if filepath.endswith(".jpg") or filepath.endswith(".jpeg") or filepath.endswith(".png") or filepath.endswith(".webp"):
                 images.append(filepath)
+
     is_inpaint_batch = False
     if inpaint_mask_dir:
         inpaint_masks = shared.listfiles(inpaint_mask_dir)
--
cgit v1.2.1
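For reference, the collection logic these two patches converge on reduces to a small standalone pattern. A minimal sketch (the directory names are placeholders; note that str.endswith also accepts a tuple, which would shorten the chained checks in the patch):

import os

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")

def collect_images(input_dir):
    # Walk input_dir recursively and keep only common image formats,
    # mirroring the os.walk loop added above.
    images = []
    for root, _dirs, files in os.walk(input_dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            if filepath.endswith(IMAGE_EXTENSIONS):
                images.append(filepath)
    return images

# Rebuilding each file's relative path under an output directory,
# as process_batch does when saving results:
for image in collect_images("inputs"):
    relpath = os.path.dirname(os.path.relpath(image, "inputs"))
    os.makedirs(os.path.join("outputs", relpath), exist_ok=True)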
From 43bdaa2f0eda79c685792b06a2bd84c65806a48f Mon Sep 17 00:00:00 2001
From: missionfloyd
Date: Thu, 25 May 2023 18:49:28 -0600
Subject: Make ctrl+left/right optional
---
 modules/shared.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules')

diff --git a/modules/shared.py b/modules/shared.py
index b3508883..280c4451 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -405,6 +405,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
     "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
+    "keyedit_move": OptionInfo(False, "Ctrl+left/right moves prompt elements"),
     "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
     "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
--
cgit v1.2.1

From ba70a220e3176153ba2a559acb9e5aa692dce7ca Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Mon, 5 Jun 2023 22:20:29 +0300
Subject: Remove a bunch of unused/vestigial code

As found by Vulture and some eyes
---
 modules/api/api.py                         |  7 -------
 modules/api/models.py                      |  4 ----
 modules/codeformer_model.py                |  4 ----
 modules/devices.py                         |  7 -------
 modules/generation_parameters_copypaste.py | 29 -----------------------------
 modules/hypernetworks/hypernetwork.py      | 24 ------------------------
 modules/paths.py                           | 14 --------------
 7 files changed, 89 deletions(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..41cd7eca 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -32,13 +32,6 @@ import piexif
 import piexif.helper
 
-def upscaler_to_index(name: str):
-    try:
-        return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
-
-
 def script_name_to_index(name, scripts):
     try:
         return [script.title().lower() for script in scripts].index(name.lower())

diff --git a/modules/api/models.py b/modules/api/models.py
index b3a745f0..b5683071 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
     prompt: Optional[str] = Field(title="Prompt")
     negative_prompt: Optional[str] = Field(title="Negative Prompt")
 
-class ArtistItem(BaseModel):
-    name: str = Field(title="Name")
-    score: float = Field(title="Score")
-    category: str = Field(title="Category")
 
 class EmbeddingItem(BaseModel):
     step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")

diff --git
a/modules/codeformer_model.py b/modules/codeformer_model.py index 4260b016..a01fe63d 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -15,7 +15,6 @@ model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' -have_codeformer = False codeformer = None @@ -125,9 +124,6 @@ def setup_model(dirname): return restored_img - global have_codeformer - have_codeformer = True - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) diff --git a/modules/devices.py b/modules/devices.py index 1ed6ffdc..620ed1a6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,13 +15,6 @@ def has_mps() -> bool: else: return mac_specific.has_mps -def extract_device_id(args, name): - for x in range(len(args)): - if name in args[x]: - return args[x + 1] - - return None - def get_cuda_device_string(): from modules import shared diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 1d02ffae..699b1a81 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -174,31 +174,6 @@ def send_image_and_dimensions(x): return img, w, h - -def find_hypernetwork_key(hypernet_name, hypernet_hash=None): - """Determines the config parameter name to use for the hypernet based on the parameters in the infotext. - - Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config - parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to. - - If the infotext has no hash, then a hypernet with the same name will be selected instead. 
- """ - hypernet_name = hypernet_name.lower() - if hypernet_hash is not None: - # Try to match the hash in the name - for hypernet_key in shared.hypernetworks.keys(): - result = re_hypernet_hash.search(hypernet_key) - if result is not None and result[1] == hypernet_hash: - return hypernet_key - else: - # Fall back to a hypernet with the same name - for hypernet_key in shared.hypernetworks.keys(): - if hypernet_key.lower().startswith(hypernet_name): - return hypernet_key - - return None - - def restore_old_hires_fix_params(res): """for infotexts that specify old First pass size parameter, convert it into width, height, and hr scale""" @@ -329,10 +304,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model return res -settings_map = {} - - - infotext_to_setting_name_mapping = [ ('Clip skip', 'CLIP_stop_at_last_layers', ), ('Conditional mask weight', 'inpainting_mask_weight'), diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5d12b449..51941c11 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -353,17 +353,6 @@ def load_hypernetworks(names, multipliers=None): shared.loaded_hypernetworks.append(hypernetwork) -def find_closest_hypernetwork_name(search: str): - if not search: - return None - search = search.lower() - applicable = [name for name in shared.hypernetworks if search in name.lower()] - if not applicable: - return None - applicable = sorted(applicable, key=lambda name: len(name)) - return applicable[0] - - def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) @@ -446,18 +435,6 @@ def statistics(data): return total_information, recent_information -def report_statistics(loss_info:dict): - keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x])) - for key in keys: - try: - print("Loss statistics for file " + key) - info, recent = statistics(list(loss_info[key])) - print(info) - print(recent) - except Exception as e: - print(e) - - def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) @@ -770,7 +747,6 @@ Last saved image: {html.escape(last_saved_image)}
     pbar.leave = False
     pbar.close()
     hypernetwork.eval()
-    #report_statistics(loss_dict)
 
     sd_hijack_checkpoint.remove()

diff --git a/modules/paths.py b/modules/paths.py
index 5171df4f..bada804e 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
     else:
         sys.path.append(d)
     paths[what] = d
-
-
-class Prioritize:
-    def __init__(self, name):
-        self.name = name
-        self.path = None
-
-    def __enter__(self):
-        self.path = sys.path.copy()
-        sys.path = [paths[self.name]] + sys.path
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        sys.path = self.path
-        self.path = None
--
cgit v1.2.1

From 8ca34ad6d8cc2502403b3b96bb811366bc13c076 Mon Sep 17 00:00:00 2001
From: Su Wei
Date: Fri, 9 Jun 2023 13:14:20 +0800
Subject: add model exists status check to modules/api/api.py, /sdapi/v1/options [POST]
---
 modules/api/api.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index eee99bbb..56b7858d 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights,checkpoint_alisases
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
@@ -515,6 +515,11 @@ class Api:
 
     def set_config(self, req: Dict[str, Any]):
         for k, v in req.items():
+            if k == "sd_model_checkpoint":
+                checkpoint_info = checkpoint_alisases.get(v, None)
+                if checkpoint_info is None:
+                    print(f"model [{v}] not found, skip config saving process")
+                    return
             shared.opts.set(k, v)
 
         shared.opts.save(shared.config_filename)
--
cgit v1.2.1
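From the client's side, the guard above changes what an unknown checkpoint name does when posted to the options endpoint: it is now skipped instead of being written into the saved config. A rough usage sketch (the URL and checkpoint title are placeholder values):

import requests

# /sdapi/v1/options is the endpoint patched above; with this change an
# unrecognized sd_model_checkpoint no longer ends up in the config file.
resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/options",
    json={"sd_model_checkpoint": "some-model.safetensors"},
)
print(resp.status_code)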
From 9142be0a0d8cea37cf1ae86c17fc7dcb161d9a42 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sat, 10 Jun 2023 23:36:34 +0900
Subject: quit restart
---
 modules/api/api.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index 2e49526e..317c809f 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -208,6 +208,8 @@ class Api:
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+        self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
+        self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -715,3 +717,10 @@ class Api:
     def launch(self, server_name, port):
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+    def quit_webui(self):
+        restart.stop_program()
+
+    def restart_webui(self):
+        if restart.is_restartable():
+            restart.restart_program()
--
cgit v1.2.1

From 7e2d39a2d158d1426321686b05d31a3ea694a99e Mon Sep 17 00:00:00 2001
From: Su Wei
Date: Mon, 12 Jun 2023 15:22:49 +0800
Subject: update model checkpoint switch code
---
 modules/api/api.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index 56b7858d..7d7dfe9a 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -514,12 +514,11 @@ class Api:
         return options
 
     def set_config(self, req: Dict[str, Any]):
+        checkpoint_key="sd_model_checkpoint"
+        if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases:
+            raise RuntimeError(f"model {req[checkpoint_key]!r} not found")
+
         for k, v in req.items():
-            if k == "sd_model_checkpoint":
-                checkpoint_info = checkpoint_alisases.get(v, None)
-                if checkpoint_info is None:
-                    print(f"model [{v}] not found, skip config saving process")
-                    return
             shared.opts.set(k, v)
 
         shared.opts.save(shared.config_filename)
--
cgit v1.2.1

From b9664ab6154818680ee25920e229b808a3cdec68 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Mon, 12 Jun 2023 18:15:27 +0900
Subject: move _stop route to api
---
 modules/api/api.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index 317c809f..cb1cde78 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -208,8 +208,11 @@ class Api:
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
         self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
-        self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
-        self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
+
+        if shared.cmd_opts.add_stop_route:
+            self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
+            self.add_api_route("/_stop", self.stop_route, methods=["POST"])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -724,3 +727,7 @@ class Api:
     def restart_webui(self):
         if restart.is_restartable():
             restart.restart_program()
+
+    def stop_route(request):
+        shared.state.server_command = "stop"
+        return Response("Stopping.")
--
cgit v1.2.1

From 89352a2f52c6be51318192cedd86c8a342966a49 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Mon, 29 May 2023 09:34:26 +0300
Subject: Move `load_file_from_url` to modelloader
---
 modules/esrgan_model.py     |  4 +---
 modules/modelloader.py      | 29 ++++++++++++++++++++++++++---
 modules/realesrgan_model.py |  4 ++--
 3 files changed, 29 insertions(+), 8 deletions(-)

(limited to 'modules')

diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 2fced999..f1a98c07 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -3,7 +3,6 @@ import os
 import numpy as np
 import torch
 from PIL import
Image -from basicsr.utils.download_util import load_file_from_url import modules.esrgan_model_arch as arch from modules import modelloader, images, devices @@ -152,11 +151,10 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if "http" in path: - filename = load_file_from_url( + filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, file_name=f"{self.model_name}.pth", - progress=True, ) else: filename = path diff --git a/modules/modelloader.py b/modules/modelloader.py index be23071a..a69c8a4f 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. 
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2d27b321..0d9c2e48 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -2,7 +2,6 @@ import os import numpy as np from PIL import Image -from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer from modules.upscaler import Upscaler, UpscalerData @@ -10,6 +9,7 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -71,7 +71,7 @@ class UpscalerRealESRGAN(Upscaler): return None if info.local_data_path.startswith("http"): - info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) + info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) return info except Exception: -- cgit v1.2.1 From 0afbc0c2355ead3a0ce7149a6d678f1f2e2fbfee Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:41:36 +0300 Subject: Fix up `if "http" in ...:` to be more sensible startswiths --- modules/esrgan_model.py | 4 ++-- modules/gfpgan_model.py | 2 +- modules/modelloader.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index f1a98c07..0666a2c2 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -133,7 +133,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scalers.append(scaler_data) for file in model_paths: - if "http" in file: + if file.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(file) @@ -150,7 +150,7 @@ class UpscalerESRGAN(Upscaler): return img def load_model(self, path: str): - if "http" in path: + if path.startswith("http"): filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index e239a09d..804fb53d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -25,7 +25,7 @@ def gfpgann(): return None models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") - if len(models) == 1 and "http" in models[0]: + if len(models) == 1 and models[0].startswith("http"): model_file = models[0] elif len(models) != 0: latest_file = max(models, key=os.path.getctime) diff --git a/modules/modelloader.py b/modules/modelloader.py index a69c8a4f..b2f0bb71 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) -- cgit v1.2.1 From e3a973a68df3cfe13039dae33d19cf2c02a741e0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:45:07 +0300 Subject: Add TODO comments to sus model loads --- modules/esrgan_model.py 
| 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 0666a2c2..a20e8d91 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -151,6 +151,7 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if path.startswith("http"): + # TODO: this doesn't use `path` at all? filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, -- cgit v1.2.1 From bf67a5dcf44c3dbd88d1913478d4e02477915f33 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 10:38:51 +0300 Subject: Upscaler.load_model: don't return None, just use exceptions --- modules/esrgan_model.py | 14 ++++++-------- modules/realesrgan_model.py | 33 +++++++++++++++------------------ 2 files changed, 21 insertions(+), 26 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a20e8d91..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -6,9 +6,8 @@ from PIL import Image import modules.esrgan_model_arch as arch from modules import modelloader, images, devices -from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts - +from modules.upscaler import Upscaler, UpscalerData def mod2normal(state_dict): @@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data) def do_upscale(self, img, selected_model): - model = self.load_model(selected_model) - if model is None: + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) return img model.to(devices.device_esrgan) img = esrgan_upscale(model, img) @@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler): ) else: filename = path - if not os.path.exists(filename) or filename is None: - print(f"Unable to load {self.model_path} from {filename}") - return None state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 0d9c2e48..0700b853 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors - class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): if not self.enable: return img - info = self.load_model(path) - if not os.path.exists(info.local_data_path): - print(f"Unable to load RealESRGAN model: {info.name}") + try: + info = self.load_model(path) + except Exception: + errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img upsampler = RealESRGANer( @@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): return image def load_model(self, path): - try: - info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) - - if info is None: - print(f"Unable to find model info: {path}") - return None - - if info.local_data_path.startswith("http"): - info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) - - return info - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) - return None + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + 
scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") def load_models(self, _): return get_realesrgan_models(self) -- cgit v1.2.1 From d8071647760a2213aaf33a533addb4d84ba86816 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 13 Jun 2023 13:07:39 +0300 Subject: textual_inversion/logging.py: clean up duplicate key from sets (and sort them) (Ruff B033) --- modules/textual_inversion/logging.py | 48 +++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 734a4b6f..45823eb1 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,11 +2,51 @@ import datetime import json import os -saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} -saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} -saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_shared = { + "batch_size", + "clip_grad_mode", + "clip_grad_value", + "create_image_every", + "data_root", + "gradient_step", + "initial_step", + "latent_sampling_method", + "learn_rate", + "log_directory", + "model_hash", + "model_name", + "num_of_dataset_images", + "steps", + "template_file", + "training_height", + "training_width", +} +saved_params_ti = { + "embedding_name", + "num_vectors_per_token", + "save_embedding_every", + "save_image_with_stored_embedding", +} +saved_params_hypernet = { + "activation_func", + "add_layer_norm", + "hypernetwork_name", + "layer_structure", + "save_hypernetwork_every", + "use_dropout", + "weight_init", +} saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet -saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} +saved_params_previews = { + "preview_cfg_scale", + "preview_height", + "preview_negative_prompt", + "preview_prompt", + "preview_sampler_index", + "preview_seed", + "preview_steps", + "preview_width", +} def save_settings_to_file(log_directory, all_params): -- cgit v1.2.1 From 5be6c026f55760039b3ebb284cc2ce85586be4ac Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 14 Jun 2023 18:51:47 +0900 Subject: rename routes --- modules/api/api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index cb1cde78..5ea1d21c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -210,9 +210,9 @@ class Api: self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) if shared.cmd_opts.add_stop_route: - self.add_api_route("/sdapi/v1/quit-webui", self.quit_webui, 
methods=["POST"])
-            self.add_api_route("/sdapi/v1/restart-webui", self.restart_webui, methods=["POST"])
-            self.add_api_route("/_stop", self.stop_route, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -721,13 +721,13 @@ class Api:
         self.app.include_router(self.router)
         uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
 
-    def quit_webui(self):
+    def kill_webui(self):
         restart.stop_program()
 
     def restart_webui(self):
         if restart.is_restartable():
             restart.restart_program()
 
-    def stop_route(request):
+    def terminate_webui(request):
         shared.state.server_command = "stop"
         return Response("Stopping.")
--
cgit v1.2.1

From 49fb2a337661d1b9a80de8ff35a640083fa98d2f Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 14 Jun 2023 19:52:12 +0900
Subject: response 501 if not able to restart
---
 modules/api/api.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index 5ea1d21c..4dc48a03 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -727,6 +727,7 @@ class Api:
     def restart_webui(self):
         if restart.is_restartable():
             restart.restart_program()
+        return Response(status_code=501)
 
     def terminate_webui(request):
         shared.state.server_command = "stop"
--
cgit v1.2.1

From 6091c4e4aa32b674f8ec755e6bd58989f09b08c5 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Wed, 14 Jun 2023 19:53:08 +0900
Subject: terminate -> stop
---
 modules/api/api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/api/api.py b/modules/api/api.py
index 4dc48a03..80d45ca7 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -212,7 +212,7 @@ class Api:
         if shared.cmd_opts.add_stop_route:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
             self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
-            self.add_api_route("/sdapi/v1/server-terminate", self.terminate_webui, methods=["POST"])
+            self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -729,6 +729,6 @@ class Api:
             restart.restart_program()
         return Response(status_code=501)
 
-    def terminate_webui(request):
+    def stop_webui(request):
         shared.state.server_command = "stop"
         return Response("Stopping.")
--
cgit v1.2.1
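After these renames the server-control surface is /sdapi/v1/server-kill, /sdapi/v1/server-restart and /sdapi/v1/server-stop, all gated behind --add-stop-route. A hedged usage sketch (the local URL is a placeholder; per the patch above, a non-restartable install answers 501):

import requests

BASE = "http://127.0.0.1:7860"  # assumes the webui was launched with --add-stop-route

r = requests.post(f"{BASE}/sdapi/v1/server-restart")
if r.status_code == 501:
    print("this installation cannot restart itself")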
From fa9d2ac2ff7cf6fbc73525190bd7fde724ec1fb3 Mon Sep 17 00:00:00 2001
From: Jared Deckard
Date: Wed, 14 Jun 2023 13:53:13 -0500
Subject: Fix gradio special args in the call queue
---
 modules/call_queue.py | 3 +++
 1 file changed, 3 insertions(+)

(limited to 'modules')

diff --git a/modules/call_queue.py b/modules/call_queue.py
index 447bb764..64ebf868 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,3 +1,4 @@
+from functools import wraps
 import html
 import sys
 import threading
@@ -20,6 +21,7 @@ def wrap_queued_call(func):
 
 
 def wrap_gradio_gpu_call(func, extra_outputs=None):
+    @wraps(func)
     def f(*args, **kwargs):
 
         # if the first argument is a string that says "task(...)", it is treated as a job id
@@ -47,6 +49,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
 
 
 def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+    @wraps(func)
     def f(*args, extra_outputs_array=extra_outputs, **kwargs):
         run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
         if run_memmon:
--
cgit v1.2.1

From 376f793bded0e7df40eafcacfd086e4e4d218bc5 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Thu, 15 Jun 2023 04:23:52 +0900
Subject: git clone show progress
---
 modules/launch_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 609a181e..97539e68 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -147,10 +147,10 @@ def git_clone(url, dir, name, commithash=None):
             return
 
         run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
-        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+        run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
         return
 
-    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+    run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
 
     if commithash is not None:
         run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
--
cgit v1.2.1

From d3c86e5178725b11a4679097f0aefb0a9fc90014 Mon Sep 17 00:00:00 2001
From: Jared Deckard
Date: Wed, 14 Jun 2023 14:03:44 -0500
Subject: Note the Gradio user in the Exif data
---
 modules/img2img.py    | 5 ++++-
 modules/processing.py | 3 +++
 modules/txt2img.py    | 6 ++++--
 3 files changed, 11 insertions(+), 3 deletions(-)

(limited to 'modules')

diff --git a/modules/img2img.py b/modules/img2img.py
index d704bf90..83bd7857 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -2,6 +2,7 @@ import os
 
 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
 
 from modules import sd_samplers
 from modules.generation_parameters_copypaste import create_override_settings_dict
@@ -78,7 +79,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
         processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int,
sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -160,6 +161,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.scripts = modules.scripts.scripts_img2img p.script_args = args + p.user = request.username + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index d22b353f..3e8d153e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -180,6 +180,8 @@ class StableDiffusionProcessing: self.uc = None self.c = None + self.user = None + @property def sd_model(self): return shared.sd_model @@ -578,6 +580,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, + "User": p.user, } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) diff --git a/modules/txt2img.py b/modules/txt2img.py index 2e7d202d..6aa79f23 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic from modules.shared import opts, cmd_opts import modules.shared as shared from modules.ui import plaintext_to_html +import gradio as gr - -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, 
prompt_styles, step p.scripts = modules.scripts.scripts_txt2img p.script_args = args + p.user = request.username + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) -- cgit v1.2.1 From 9ec2ba2d28bb0d8f01e19e2919b7bf2e3e864773 Mon Sep 17 00:00:00 2001 From: XiaoMeng Mai Date: Thu, 15 Jun 2023 22:43:09 +0800 Subject: Add github mirror for the download extension --- modules/ui_extensions.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 4379a641..6c717313 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -322,11 +322,21 @@ def normalize_git_url(url): return url -def install_extension_from_url(dirname, url, branch_name=None): +def install_extension_from_url(dirname, proxy, url, branch_name=None): check_access() assert url, 'No URL specified' + proxy_list = { + "none": "", + "ghproxy": "https://ghproxy.com/", + "hub.yzuu.cf": "https://hub.yzuu.cf/", + "hub.njuu.cf": "https://hub.njuu.cf/", + "hub.nuaa.cf": "https://hub.nuaa.cf/", + } + + url = proxy_list[proxy] + url + if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -346,12 +356,12 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) if not branch_name: # if no branch is specified, use the default branch - with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], verbose=False) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() else: - with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name, verbose=False) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() @@ -593,6 +603,12 @@ def create_ui(): ) with gr.TabItem("Install from URL", id="install_from_url"): + + install_proxy = gr.Radio( + label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", + info="If you can't access github.com, you can use a proxy to install extensions from github.com" + ) + install_url = gr.Text(label="URL for extension's git repository") install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") @@ -601,7 +617,7 @@ def create_ui(): install_button.click( fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), - inputs=[install_dirname, install_url, install_branch], + inputs=[install_dirname, install_proxy, install_url, install_branch], outputs=[install_url, extensions_table, install_result], ) -- cgit v1.2.1 From de022c4c80240a430a8099fb27a41aa505bf5b2f Mon Sep 17 00:00:00 2001 From: XiaoMeng Mai Date: Thu, 15 Jun 2023 22:59:46 +0800 Subject: Update code style --- modules/ui_extensions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 6c717313..e4423a06 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -336,7 +336,6 @@ def install_extension_from_url(dirname, proxy, url, branch_name=None): } url = proxy_list[proxy] + url - if dirname 
is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -603,9 +602,8 @@ def create_ui(): ) with gr.TabItem("Install from URL", id="install_from_url"): - install_proxy = gr.Radio( - label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", + label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", info="If you can't access github.com, you can use a proxy to install extensions from github.com" ) -- cgit v1.2.1 From 8f18e672439fa1926717df2c938e7089149f3a8b Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 15 Jun 2023 10:53:16 -0500 Subject: Add a user pattern to the filename generator --- modules/images.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 40efc96c..92b924ef 100644 --- a/modules/images.py +++ b/modules/images.py @@ -359,6 +359,7 @@ class FilenameGenerator: 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt..] 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, + 'user': lambda self: self.p.user, } default_time_format = '%Y%m%d%H%M%S' -- cgit v1.2.1 From f603275d84301b5ee952683e951dd1aad72ba615 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 15 Jun 2023 10:55:53 -0500 Subject: Add an opt-in infotext user name setting --- modules/processing.py | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 3e8d153e..a0cc8db2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -580,7 +580,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, - "User": p.user, + "User": p.user if opts.add_user_name_to_info else None, } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) diff --git a/modules/shared.py b/modules/shared.py index 271a062d..4c639a21 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -496,6 +496,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('infotext', "Infotext"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), + "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), })) -- cgit v1.2.1 From e9bd18c57bd83363d38c7409263fe87f3ed3a7f0 Mon Sep 17 00:00:00 2001 From: XiaoMeng Mai Date: Fri, 16 Jun 2023 00:09:54 +0800 Subject: Update call method --- modules/ui_extensions.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 
e4423a06..2b542e59 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -330,12 +330,16 @@ def install_extension_from_url(dirname, proxy, url, branch_name=None): proxy_list = { "none": "", "ghproxy": "https://ghproxy.com/", - "hub.yzuu.cf": "https://hub.yzuu.cf/", - "hub.njuu.cf": "https://hub.njuu.cf/", - "hub.nuaa.cf": "https://hub.nuaa.cf/", + "yzuu": "hub.yzuu.cf", + "njuu": "hub.njuu.cf", + "nuaa": "hub.nuaa.cf", } - url = proxy_list[proxy] + url + if proxy in ['yzuu', 'njuu', 'nuaa']: + url = url.replace('github.com', proxy_list[proxy]) + elif proxy == 'ghproxy': + url = proxy_list[proxy] + url + if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -603,7 +607,7 @@ def create_ui(): with gr.TabItem("Install from URL", id="install_from_url"): install_proxy = gr.Radio( - label="Install Proxy", choices=["none", "ghproxy", "hub.nuaa.cf","hub.yzuu.cf","hub.njuu.cf"], value="none", + label="Install Proxy", choices=["none", "ghproxy", "nuaa", "yzuu", "njuu"], value="none", info="If you can't access github.com, you can use a proxy to install extensions from github.com" ) -- cgit v1.2.1 From 41363e0d27bbaa0c84eebe3c7c8451075390ec4e Mon Sep 17 00:00:00 2001 From: dhwz Date: Fri, 16 Jun 2023 18:10:15 +0200 Subject: fix very slow loading speed of .safetensors files --- modules/sd_models.py | 7 +++++-- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 918f6fd6..d9ac675b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -247,8 +247,11 @@ def read_metadata_from_safetensors(filename): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location or devices.get_optimal_device_name() - pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) + if not shared.opts.disable_mmap_load_safetensors: + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) + else: + pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read()) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) diff --git a/modules/shared.py b/modules/shared.py index 91c31d55..6b0ccac1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -376,6 +376,7 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), + "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files (fixes very slow loading speed in some cases)."), })) options_templates.update(options_section(('training', "Training"), { -- cgit v1.2.1 From 373ff5a217eca33607abb692b9ebfa38abb7fe33 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Fri, 16 Jun 2023 15:17:17 -0400 Subject: :bug: Allow Script to have metaclass --- modules/scripts.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/scripts.py 
b/modules/scripts.py
index c902804b..52682fbf 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -2,6 +2,7 @@ import os
 import re
 import sys
 import traceback
+import inspect
 
 from collections import namedtuple
 import gradio as gr
@@ -238,7 +239,7 @@ def load_scripts():
 
 def register_scripts_from_module(module):
     for script_class in module.__dict__.values():
-        if type(script_class) != type:
+        if not inspect.isclass(script_class):
             continue
 
         if issubclass(script_class, Script):
--
cgit v1.2.1
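A quick illustration of why the inspect.isclass change matters (ScriptMeta is an invented example; a class built with a custom metaclass is not an instance of type, so the old comparison filtered it out):

import inspect

class ScriptMeta(type):
    pass

class MyScript(metaclass=ScriptMeta):
    pass

print(type(MyScript) != type)     # True  -> previously skipped during registration
print(inspect.isclass(MyScript))  # True  -> now registered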
From 2e1710d88edc1e1a08b01c063fa386b50e5abc30 Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Sun, 18 Jun 2023 14:07:41 +0900
Subject: update the description of --add-stop-route
---
 modules/cmd_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules')

diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index de905caa..624dcb4f 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch
 parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
 parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
 parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
-parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
+parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api')
--
cgit v1.2.1

From d2ccdcdc97f8e8b0a1a63f2031716b3866c7b53b Mon Sep 17 00:00:00 2001
From: George Gu
Date: Mon, 19 Jun 2023 10:16:18 +0800
Subject: fix: adding elem_id for img2img resize to and resize by tabs
---
 modules/ui.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'modules')

diff --git a/modules/ui.py b/modules/ui.py
index 361f596e..ce019f9c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -789,7 +789,7 @@ def create_ui():
                     selected_scale_tab = gr.State(value=0)
 
                     with gr.Tabs():
-                        with gr.Tab(label="Resize to") as tab_scale_to:
+                        with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
                             with FormRow():
                                 with gr.Column(elem_id="img2img_column_size", scale=4):
                                     width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@@ -798,7 +798,7 @@ def create_ui():
                                 res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
                                 detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
 
-                        with gr.Tab(label="Resize by") as tab_scale_by:
+                        with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
                             scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
 
                             with FormRow():
--
cgit v1.2.1

From 27e9e3f6fa41c10eb5256662ddf3643dee933810 Mon Sep 17 00:00:00 2001
From: stablegeniusdiffuser
Date: Mon, 19 Jun 2023 20:36:44 +0200
Subject: Add use_main_prompt parameter to use proper metadata for batch run grids or individual images
---
 modules/processing.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

(limited to 'modules')

diff --git a/modules/processing.py b/modules/processing.py
index 8da73884..1d97e95e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -549,7 +549,7 @@ def program_version():
     return res
 
 
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
     index = position_in_batch + iteration * p.batch_size
 
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -589,9 +589,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
 
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
+    prompt_text = p.prompt if use_main_prompt else all_prompts[index]
     negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
 
-    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -663,8 +664,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         else:
             p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
-    def infotext(iteration=0, position_in_batch=0):
-        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
 
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -824,7 +825,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 grid = images.image_grid(output_images, p.batch_size)
 
                 if opts.return_grid:
-                    text = infotext()
+                    text = infotext(use_main_prompt=True)
                     infotexts.insert(0, text)
                     if opts.enable_pnginfo:
                         grid.info["parameters"] = text
@@ -832,7 +833,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     index_of_first_image = 1
 
                 if opts.grid_save:
-                    images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+                    images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
 
         if not p.disable_extra_networks and p.extra_network_data:
             extra_networks.deactivate(p, p.extra_network_data)
--
cgit v1.2.1

From 928bd42da46683315c9f4498f6fbd5c59279da18 Mon Sep 17 00:00:00 2001
From: Ferdinand Weynschenk
Date: Tue, 20 Jun 2023 13:33:36 +0200
Subject: PNG info support at img2img batch
---
 modules/img2img.py | 38 +++++++++++++++++++++++++++++++++-----
 modules/ui.py      |  7 +++++++
 2 files changed, 40 insertions(+), 5 deletions(-)

(limited to 'modules')

diff --git a/modules/img2img.py b/modules/img2img.py
index d704bf90..88e172ff 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,8 +3,8 @@ import os
 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
 
-from modules import sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict,
parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state import modules.shared as shared @@ -13,7 +13,7 @@ from modules.ui import plaintext_to_html import modules.scripts -def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): +def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) images = shared.listfiles(input_dir) @@ -34,6 +34,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): state.job_count = len(images) * p.n_iter + prompt = p.prompt + negative_prompt = p.negative_prompt + for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" if state.skipped: @@ -59,6 +62,31 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): mask_image_path = inpaint_masks[0] mask_image = Image.open(mask_image_path) p.image_mask = mask_image + + if use_png_info: + try: + info_img = img + if png_info_dir: + info_img_path = os.path.join(png_info_dir, os.path.basename(image)) + info_img = Image.open(info_img_path) + geninfo, _ = imgutil.read_info_from_image(info_img) + parsed_parameters = parse_generation_parameters(geninfo) + if("Prompt" in png_info_props): + p.prompt = prompt + " " + parsed_parameters["Prompt"] + if("Negative prompt" in png_info_props): + p.negative_prompt = negative_prompt + " " + parsed_parameters["Negative prompt"] + if("Seed" in png_info_props): + p.seed = int(parsed_parameters["Seed"]) + if("CFG scale" in png_info_props): + p.cfg_scale = float(parsed_parameters["CFG scale"]) + if("Sampler" in png_info_props): + p.sampler_name = parsed_parameters["Sampler"] + if("Steps" in png_info_props): + p.steps = int(parsed_parameters["Steps"]) + except: + p.prompt = prompt + p.negative_prompt = negative_prompt + print(f"batch png info: using ui set prompts; failed to get png info for {image}") proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: @@ -78,7 +106,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, 
seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -169,7 +197,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) + process_batch(p, img2img_batch_use_png_info, img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index 361f596e..a79b0e6c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -751,6 +751,10 @@ def create_ui(): img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") + with gr.Accordion("PNG info"): + img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") + img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") + img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] @@ -943,6 +947,9 @@ def create_ui(): inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, + img2img_batch_use_png_info, + img2img_batch_png_info_props, + img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, -- cgit v1.2.1 From 7ad48120d45e678b6343f7d95a1f97607858009a Mon Sep 17 00:00:00 2001 From: Ferdinand Weynschenk Date: Tue, 20 Jun 2023 13:50:02 +0200 Subject: use ui params when retrieving png info fails Don't want to interrupt the process since batches can be huge. This makes more sense to me than using the previous image's parameters.
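(In outline, the fallback pattern these patches converge on looks like the following — a condensed sketch, not the repository's exact code: read_info_from_image and parse_generation_parameters are the real helpers used in the diffs, while the wrapper function and its arguments are invented for illustration.)

from PIL import Image
from modules import images as imgutil
from modules.generation_parameters_copypaste import parse_generation_parameters

def apply_png_info(p, image_path, wanted_props):
    # Snapshot the UI-supplied values so one unreadable file cannot poison later iterations.
    fallback = {"prompt": p.prompt, "negative_prompt": p.negative_prompt, "seed": p.seed}
    try:
        geninfo, _ = imgutil.read_info_from_image(Image.open(image_path))
        parsed = parse_generation_parameters(geninfo)
        if "Prompt" in wanted_props:
            p.prompt = fallback["prompt"] + " " + parsed["Prompt"]
        if "Seed" in wanted_props:
            p.seed = int(parsed["Seed"])
    except Exception as exc:
        # Restore the UI defaults and keep going; a batch may contain thousands of files.
        print(f"batch png info: failed to read parameters from {image_path}: {exc}")
        p.prompt, p.negative_prompt, p.seed = fallback["prompt"], fallback["negative_prompt"], fallback["seed"]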
--- modules/img2img.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 88e172ff..e46a6fde 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -34,8 +34,13 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp state.job_count = len(images) * p.n_iter + # extract "default" params to use in case getting png info fails prompt = p.prompt negative_prompt = p.negative_prompt + seed = p.seed + cfg_scale = p.cfg_scale + sampler_name = p.sampler_name + steps = p.steps for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" @@ -86,6 +91,10 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp except: p.prompt = prompt p.negative_prompt = negative_prompt + p.seed = seed + p.cfg_scale = cfg_scale + p.sampler_name = sampler_name + p.steps = steps print(f"batch png info: using ui set prompts; failed to get png info for {image}") proc = modules.scripts.scripts_img2img.run(p, *args) -- cgit v1.2.1 From c4c63dd5e4760c56405cef2e71abc5c3604c4578 Mon Sep 17 00:00:00 2001 From: Ferdinand Weynschenk Date: Tue, 20 Jun 2023 14:03:42 +0200 Subject: resolve linter --- modules/img2img.py | 7 ++++--- modules/ui.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index e46a6fde..f77dfd9f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -67,7 +67,7 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp mask_image_path = inpaint_masks[0] mask_image = Image.open(mask_image_path) p.image_mask = mask_image - + if use_png_info: try: info_img = img @@ -88,14 +88,15 @@ def process_batch(p, use_png_info, png_info_props, png_info_dir, input_dir, outp p.sampler_name = parsed_parameters["Sampler"] if("Steps" in png_info_props): p.steps = int(parsed_parameters["Steps"]) - except: + except Exception as e: + print(f"batch png info: using ui set prompts; failed to get png info for {image}") + print(e) p.prompt = prompt p.negative_prompt = negative_prompt p.seed = seed p.cfg_scale = cfg_scale p.sampler_name = sampler_name p.steps = steps - print(f"batch png info: using ui set prompts; failed to get png info for {image}") proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: diff --git a/modules/ui.py b/modules/ui.py index a79b0e6c..d9b21534 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -948,7 +948,7 @@ def create_ui(): inpaint_full_res_padding, inpainting_mask_invert, img2img_batch_use_png_info, - img2img_batch_png_info_props, + img2img_batch_png_info_props, img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, -- cgit v1.2.1 From dd268c48c9099c4cf308eb04590bd201c9b64253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20=28Netux=29=20Rodr=C3=ADguez?= Date: Sun, 25 Jun 2023 00:30:08 -0300 Subject: feat(extensions): add toggle all checkbox to Installed tab Small QoL addition. While there is the option to disable all extensions with the radio buttons at the top, that only acts as an added flag and doesn't really change the state of the extensions in the UI. A use case for this checkbox is to disable all extensions except for a few, which is important for debugging extensions. You could do that before, but you'd have to uncheck and recheck every extension one by one.
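(A note on the next patch: the raw HTML lines of its diff did not survive extraction, leaving bare -/+ markers. Conceptually, it adds a master checkbox to the table header whose JavaScript handler checks or unchecks every per-extension checkbox. Below is a rough Python sketch of the kind of markup extension_table() would emit — the handler name, CSS class, and helper function are assumptions for illustration, not the exact code of the patch.)

def extension_table_header(all_enabled: bool) -> str:
    # Master checkbox; toggle_all_extensions is assumed to be a JavaScript
    # function on the page that flips every row-level extension checkbox.
    checked = 'checked="checked"' if all_enabled else ''
    return f"""
        <th>
            <input class="all_extensions_toggle" type="checkbox" {checked} onchange="toggle_all_extensions(event)" />
            <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
        </th>
    """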
--- modules/ui_extensions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 4379a641..50955fab 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -138,7 +138,10 @@ def extension_table(): - + @@ -170,7 +173,7 @@ def extension_table(): code += f""" - + -- cgit v1.2.1 From 9bb1fcfad43103778406ace17e6804c67fad9c17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 08:59:35 +0300 Subject: alternate fix for catch errors when retrieving extension index #11290 --- modules/ui_extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index f3db76f2..278bf5e4 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -571,9 +571,9 @@ def create_ui(): available_extensions_table = gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]), inputs=[available_extensions_index, hide_tags, sort_column], - outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text], + outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result], ) install_extension_button.click( -- cgit v1.2.1 From 24129368f1b732be25ef486edb2cf5a6ace66737 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 09:19:04 +0300 Subject: send tensors to the correct device when loading from safetensors file with memmap disabled for #11260 --- modules/sd_models.py | 4 +++- modules/shared.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0391398a..f65f4e36 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() + if not shared.opts.disable_mmap_load_safetensors: - device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read()) + pl_sd = {k: v.to(device) for k, v in pl_sd.items()} else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) diff --git a/modules/shared.py b/modules/shared.py index 4743a428..203ee1b9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -376,7 +376,7 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), - "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for 
loading .safetensors files (fixes very slow loading speed in some cases)."), + "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), })) options_templates.update(options_section(('training', "Training"), { -- cgit v1.2.1 From d06af4e517865277d0521642c2c5513af9afd76f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 09:26:18 +0300 Subject: fix and rework #11113 --- modules/api/api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index f96056b6..1d4aeff5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -522,9 +522,9 @@ class Api: return options def set_config(self, req: Dict[str, Any]): - checkpoint_key="sd_model_checkpoint" - if checkpoint_key in req and str(req[checkpoint_key]) not in checkpoint_alisases: - raise RuntimeError(f"model {v!r} not found") + checkpoint_name = req.get("sd_model_checkpoint", None) + if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases: + raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): shared.opts.set(k, v) -- cgit v1.2.1 From da14f6a6632e67cacaeaac7441344f0848f66114 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 10:16:44 +0700 Subject: Add options to change colors in grid --- modules/images.py | 13 +++++-------- modules/shared.py | 5 ++++- 2 files changed, 9 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 1906e2ab..320008be 100644 --- a/modules/images.py +++ b/modules/images.py @@ -10,7 +10,7 @@ import re import numpy as np import piexif import piexif.helper -from PIL import Image, ImageFont, ImageDraw, PngImagePlugin +from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin import string import json import hashlib @@ -156,10 +156,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0: fontsize -= 1 fnt = get_font(fontsize) - drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") + drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=ImageColor.getcolor(opts.grid_text_color_active, 'RGB') if line.is_active else ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), anchor="mm", align="center") if not line.is_active: - drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4) + drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), width=4) draw_y += line.size[1] + line_spacing @@ -168,9 +168,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): fnt = get_font(fontsize) - color_active = (0, 0, 0) - color_inactive = (153, 153, 153) - pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4 cols = im.width // width @@ -179,7 +176,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): assert cols == len(hor_texts), f'bad number of 
horizontal texts: {len(hor_texts)}; must be {cols}' assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}' - calc_img = Image.new("RGB", (1, 1), "white") + calc_img = Image.new("RGB", (1, 1), ImageColor.getcolor(opts.grid_background, 'RGB')) calc_d = ImageDraw.Draw(calc_img) for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)): @@ -200,7 +197,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 - result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white") + result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), ImageColor.getcolor(opts.grid_background, 'RGB')) for row in range(rows): for col in range(cols): diff --git a/modules/shared.py b/modules/shared.py index 203ee1b9..4a83cca4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -413,6 +413,10 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "font": OptionInfo("", "Font for image grids that have text"), + "grid_text_color_active": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_color_inactive": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), })) options_templates.update(options_section(('optimizations', "Optimizations"), { @@ -471,7 +475,6 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), - "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), -- cgit v1.2.1 From b0ec69b360835a901a1aa57df1f7c8c9d55bf31c Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Wed, 28 Jun 2023 18:37:08 +0900 Subject: add 'before_hr callback' script callback --- modules/processing.py | 3 +++ modules/scripts.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 8da73884..35463c37 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1074,6 +1074,9 @@ class 
StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) + if self.scripts is not None: + self.scripts.before_hr(self) + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..6485f398 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -186,6 +186,11 @@ class Script: return f'script_{tabname}{title}_{item_id}' + def before_hr(self, p, *args): + """ + This function is called before the hires fix starts. + """ + pass current_basedir = paths.script_path @@ -548,6 +553,15 @@ class ScriptRunner: self.scripts[si].args_to = args_to + def before_hr(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_hr(p, *script_args) + except Exception: + errors.report(f"Error running before_hr: {script.filename}", exc_info=True) + + scripts_txt2img: ScriptRunner = None scripts_img2img: ScriptRunner = None scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None -- cgit v1.2.1
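(A minimal always-on script using the new hook might look like this — the hook signature comes from the patch above, while the script itself is a made-up example:)

import modules.scripts as scripts

class BeforeHiresLogger(scripts.Script):
    def title(self):
        return "Before-hires logger"

    def show(self, is_img2img):
        # Always-on scripts get their hooks called for every generation.
        return scripts.AlwaysVisible

    def before_hr(self, p, *args):
        # Runs right before the hires-fix second pass is sampled.
        print(f"hires pass starting at {p.width}x{p.height}, steps={p.hr_second_pass_steps or p.steps}")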
From 45ab7475d61fe42b70c37541974c03736cf73189 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 17:55:58 +0700 Subject: Revision --- modules/images.py | 13 +++++++++---- modules/shared.py | 6 +++--- 2 files changed, 12 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 320008be..913c3c2f 100644 --- a/modules/images.py +++ b/modules/images.py @@ -139,6 +139,11 @@ class GridAnnotation: def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): + + color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB') + color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB') + color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB') + def wrap(drawing, text, font, line_length): lines = [''] for word in text.split(): @@ -156,10 +161,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0: fontsize -= 1 fnt = get_font(fontsize) - drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=ImageColor.getcolor(opts.grid_text_color_active, 'RGB') if line.is_active else ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), anchor="mm", align="center") + drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") if not line.is_active: - drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=ImageColor.getcolor(opts.grid_text_color_inactive, 'RGB'), width=4) + drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4) draw_y += line.size[1] + line_spacing @@ -176,7 +181,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}' assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}' - calc_img = Image.new("RGB", (1, 1), ImageColor.getcolor(opts.grid_background, 'RGB')) + calc_img = Image.new("RGB", (1, 1), color_background) calc_d = ImageDraw.Draw(calc_img) for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)): @@ -197,7 +202,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 - result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), ImageColor.getcolor(opts.grid_background, 'RGB')) + result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background) for row in range(rows): for col in range(cols): diff --git a/modules/shared.py b/modules/shared.py index 4a83cca4..22e6bd0b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -414,9 +414,9 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), "font": OptionInfo("", "Font for image grids that have text"), - "grid_text_color_active": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), - "grid_text_color_inactive": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), - "grid_background": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), + "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), })) options_templates.update(options_section(('optimizations', "Optimizations"), { -- cgit v1.2.1 From d22eb8a17f8d8c0e8018d9f9c71f7a96108544ee Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 17:57:34 +0700 Subject: Fix lint --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 913c3c2f..3e6988fc 100644 --- a/modules/images.py +++ b/modules/images.py @@ -139,7 +139,7 @@ class GridAnnotation: def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): - + color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB') color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB') color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB') -- cgit v1.2.1 From f74fb5049506b85a98b02b1c2fd7361e9f751980 Mon Sep 17 00:00:00 2001 From: NoCrypt <57245077+NoCrypt@users.noreply.github.com> Date: Wed, 28 Jun 2023 20:24:57 +0700 Subject: Move change colors options to Saving images/grids --- modules/shared.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 22e6bd0b..76d8e221 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -311,6 +311,10 @@ 
options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), + "font": OptionInfo("", "Font for image grids that have text"), + "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), @@ -413,10 +417,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "font": OptionInfo("", "Font for image grids that have text"), - "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), - "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), - "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), })) options_templates.update(options_section(('optimizations', "Optimizations"), { -- cgit v1.2.1 From 9c2a7f1e8bafcb59e566bf568fdefe1be95905fe Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 19 Jun 2023 15:37:20 +0900 Subject: add callback after_extra_networks_activate --- modules/extra_networks.py | 3 +++ modules/scripts.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) (limited to 'modules') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 1f093df2..41799b0a 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -103,6 +103,9 @@ def activate(p, extra_network_data): except Exception as e: errors.display(e, f"activating extra network {extra_network_name}") + if p.scripts is not None: + p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data) + def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..340f1480 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -116,6 +116,21 @@ 
class Script: + def after_extra_networks_activate(self, p, *args, **kwargs): + """ + Called after extra network activation, before conds calculation. + Allows modification of the network after extra-network activation has been applied. + Won't be called if p.disable_extra_networks is set. + + **kwargs will have these items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch + - extra_network_data - list of ExtraNetworkParams for current stage + """ + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -483,6 +498,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) + def after_extra_networks_activate(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.after_extra_networks_activate(p, *script_args, **kwargs) + except Exception: + errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: -- cgit v1.2.1 From cc9c1719786de00d4a5bfcf83be4bf2808cf0cb5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 14:21:28 +0900 Subject: rename --add-stop-route to --api-server-stop --- modules/api/api.py | 2 +- modules/cmd_args.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..adc633db 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -202,7 +202,7 @@ class Api: self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) - if shared.cmd_opts.add_stop_route: + if shared.cmd_opts.api_server_stop: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"]) self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"]) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 624dcb4f..278a605e 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -106,4 +106,4 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='enable server stop/restart/kill via api') +parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') -- cgit v1.2.1
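(With the flag renamed, the opt-in server-control endpoints can be exercised like this — a sketch assuming the webui was launched with --api-server-stop and is listening on the default local port:)

import requests

# POST with an empty body; /server-restart and /server-kill work the same way.
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/server-stop")
print(resp.status_code)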
From 0bc0e652a3d8cde5533af52c4f232c213b9989f0 Mon Sep 17 00:00:00 2001 From: hunshcn Date: Thu, 29 Jun 2023 18:12:55 +0800 Subject: sync default value of process_focal_crop_entropy_weight between ui and api --- modules/textual_inversion/preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 0d4c3f84..dbd856bd 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru from modules.textual_inversion import autocrop -def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): try: if process_caption: shared.interrogator.load() -- cgit v1.2.1 From d47324b898d057c0f854b9be891f2483a2b7001f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 19:25:18 +0900 Subject: add stars --- modules/ui_extensions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 278bf5e4..ac239d64 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -424,6 +424,7 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')), (False, lambda x: 'z'), + (True, lambda x: x.get('stars', 0)), ] @@ -451,6 +452,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): name = ext.get("name", "noname") + stars = int(ext.get("stars", 0)) added = ext.get('added', 'unknown') url = ext.get("url", None) description = ext.get("description", "") @@ -478,7 +480,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f""" - + @@ -562,7 +564,7 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "stars"], type="index") with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) -- cgit v1.2.1 From
0416a7bfbaecab20a4ae4cd8330faee556bb3d89 Mon Sep 17 00:00:00 2001 From: Akiba Date: Thu, 29 Jun 2023 18:46:52 +0800 Subject: fix can't get current hash --- modules/launch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 97539e68..0e0dbca4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -142,7 +142,7 @@ def git_clone(url, dir, name, commithash=None): if commithash is None: return - current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() + current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() if current_hash == commithash: return -- cgit v1.2.1 From 2ccc832b3333fe520961466aa1f05b24aafdd792 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 29 Jun 2023 22:46:59 +0900 Subject: add extensions Update Created dates with sorting --- modules/ui_extensions.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ac239d64..dff522ef 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -424,10 +424,19 @@ sort_ordering = [ (False, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')), (False, lambda x: 'z'), + (True, lambda x: x.get('commit_time', '')), + (True, lambda x: x.get('created_at', '')), (True, lambda x: x.get('stars', 0)), ] +def get_date(info: dict, key): + try: + return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d") + except (ValueError, TypeError): + return '' + + def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} @@ -454,6 +463,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" name = ext.get("name", "noname") stars = int(ext.get("stars", 0)) added = ext.get('added', 'unknown') + update_time = get_date(ext, 'commit_time') + create_time = get_date(ext, 'created_at') url = ext.get("url", None) description = ext.get("description", "") extension_tags = ext.get("tags", []) @@ -480,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" code += f""" - + @@ -564,7 +576,7 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) - sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", "stars"], type="index") + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") with gr.Row(): search_extensions_text = gr.Text(label="Search").style(container=False) -- cgit v1.2.1 From 7f46f81dd7b517e829395734750f0eb8360675d4 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 1 Jul 2023 17:20:56 -0600 Subject: Change default seed_resize to 0 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 8da73884..9e838aad 
100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -109,7 +109,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = 0, seed_resize_from_w: int = 0, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) -- cgit v1.2.1 From 74d001bc68c2106aa963e3474eee70327b8f3760 Mon Sep 17 00:00:00 2001 From: ramyma Date: Sun, 2 Jul 2023 04:59:59 +0300 Subject: Hotfix: call processing close to cleanup API generation calls --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..f10e3fe3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -335,6 +335,7 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -392,6 +393,7 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] -- cgit v1.2.1 From f44feb6a10aacc6a5ff4c9275fba2546b2858935 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:11:31 +0300 Subject: Add job argument to State.begin() --- modules/api/api.py | 14 +++++++------- modules/call_queue.py | 2 +- modules/extras.py 
| 3 +-- modules/interrogate.py | 3 +-- modules/postprocessing.py | 3 +-- modules/shared.py | 4 ++-- 6 files changed, 13 insertions(+), 16 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 279c384a..3ea099ad 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -327,7 +327,7 @@ class Api: p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples - shared.state.begin() + shared.state.begin(job="scripts_txt2img") if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here @@ -384,7 +384,7 @@ class Api: p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples - shared.state.begin() + shared.state.begin(job="scripts_img2img") if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here @@ -599,7 +599,7 @@ class Api: def create_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used shared.state.end() @@ -610,7 +610,7 @@ class Api: def create_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") @@ -620,7 +620,7 @@ class Api: def preprocess(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() return models.PreprocessResponse(info = 'preprocess complete') @@ -636,7 +636,7 @@ class Api: def train_embedding(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_embedding") apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -657,7 +657,7 @@ class Api: def train_hypernetwork(self, args: dict): try: - shared.state.begin() + shared.state.begin(job="train_hypernetwork") shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None diff --git a/modules/call_queue.py b/modules/call_queue.py index 69bf63d2..3b94f8a4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): id_task = None with queue_lock: - shared.state.begin() + shared.state.begin(job=id_task) progress.start_task(id_task) try: diff --git a/modules/extras.py b/modules/extras.py index 830b53aa..e9c0263e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -73,8 +73,7 @@ def to_half(tensor, enable): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): - shared.state.begin() - shared.state.job = 'model-merge' + shared.state.begin(job="model-merge") def fail(message): shared.state.textinfo = message diff --git a/modules/interrogate.py b/modules/interrogate.py index 9b2c5b60..a3ae1dd5 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -184,8 +184,7 @@ class InterrogateModels: def 
interrogate(self, pil_image): res = "" - shared.state.begin() - shared.state.job = 'interrogate' + shared.state.begin(job="interrogate") try: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 736315e2..544b2f72 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -9,8 +9,7 @@ from modules.shared import opts def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() - shared.state.begin() - shared.state.job = 'extras' + shared.state.begin(job="extras") image_data = [] image_names = [] diff --git a/modules/shared.py b/modules/shared.py index 203ee1b9..7df2879c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,7 +173,7 @@ class State: return obj - def begin(self): + def begin(self, job: str = "(unknown)"): self.sampling_step = 0 self.job_count = -1 self.processing_has_refined_job_count = False @@ -187,7 +187,7 @@ class State: self.interrupted = False self.textinfo = None self.time_start = time.time() - + self.job = job devices.torch_gc() def end(self): -- cgit v1.2.1 From e4303443477b9ac3c90ec4dd58a4810f7ac1eabe Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:11:49 +0300 Subject: API: use finally: for state.end() --- modules/api/api.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 3ea099ad..8b79495d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -602,37 +602,35 @@ class Api: shared.state.begin(job="create_embedding") filename = create_embedding(**args) # create empty embedding sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used - shared.state.end() return models.CreateResponse(info=f"create embedding filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create embedding error: {e}") + finally: + shared.state.end() + def create_hypernetwork(self, args: dict): try: shared.state.begin(job="create_hypernetwork") filename = create_hypernetwork(**args) # create empty embedding - shared.state.end() return models.CreateResponse(info=f"create hypernetwork filename: {filename}") except AssertionError as e: - shared.state.end() return models.TrainResponse(info=f"create hypernetwork error: {e}") + finally: + shared.state.end() def preprocess(self, args: dict): try: shared.state.begin(job="preprocess") preprocess(**args) # quick operation unless blip/booru interrogation is enabled shared.state.end() - return models.PreprocessResponse(info = 'preprocess complete') + return models.PreprocessResponse(info='preprocess complete') except KeyError as e: - shared.state.end() return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") - except AssertionError as e: - shared.state.end() + except Exception as e: return models.PreprocessResponse(info=f"preprocess error: {e}") - except FileNotFoundError as e: + finally: shared.state.end() - return models.PreprocessResponse(info=f'preprocess error: {e}') def train_embedding(self, args: dict): try: @@ -649,11 +647,11 @@ class Api: finally: if not apply_optimizations: sd_hijack.apply_optimizations() - shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except 
AssertionError as msg: - shared.state.end() + except Exception as msg: return models.TrainResponse(info=f"train embedding error: {msg}") + finally: + shared.state.end() def train_hypernetwork(self, args: dict): try: @@ -675,9 +673,10 @@ class Api: sd_hijack.apply_optimizations() shared.state.end() return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") - except AssertionError: + except Exception as exc: + return models.TrainResponse(info=f"train embedding error: {exc}") + finally: shared.state.end() - return models.TrainResponse(info=f"train embedding error: {error}") def get_memory(self): try: -- cgit v1.2.1 From 522a8b9f629940a205812b5b023f25c051f3c8d8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 30 Jun 2023 13:24:17 +0300 Subject: Add a status logger in modules.shared --- modules/shared.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 7df2879c..9ab9d98b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import os import sys import threading import time +import logging import gradio as gr import torch @@ -18,6 +19,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi from ldm.models.diffusion.ddpm import LatentDiffusion from typing import Optional +log = logging.getLogger(__name__) + demo = None parser = cmd_args.parser @@ -144,12 +147,15 @@ class State: def request_restart(self) -> None: self.interrupt() self.server_command = "restart" + log.info("Received restart request") def skip(self): self.skipped = True + log.info("Received skip request") def interrupt(self): self.interrupted = True + log.info("Received interrupt request") def nextjob(self): if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: @@ -189,8 +195,11 @@ class State: self.time_start = time.time() self.job = job devices.torch_gc() + log.info("Starting job %s", job) def end(self): + duration = time.time() - self.time_start + log.info("Ending job %s (%.2f seconds)", self.job, duration) self.job = "" self.job_count = 0 -- cgit v1.2.1 From 08f9b705cda4277aed49ed00c405ada2925e3b50 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:08:28 +0300 Subject: Use read_info_from_image in postprocessing Avoids bad keys such as `exif` ending up in the "PNG info" passed forward --- modules/postprocessing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 736315e2..38544c38 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -54,7 +54,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, for image, name in zip(image_data, image_names): shared.state.textinfo = name - existing_pnginfo = image.info or {} + parameters, existing_pnginfo = images.read_info_from_image(image) + if parameters: + existing_pnginfo["parameters"] = parameters pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) -- cgit v1.2.1 From b2c574891f492d00e310e387a024638a7bcf2353 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:09:37 +0300 Subject: read_info_from_image: add `photoshop` to ignored --- modules/images.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 1906e2ab..ac53a3c5 100644 --- a/modules/images.py +++ b/modules/images.py @@ -662,6 +662,13 @@ def 
save_image(image, path, basename, seed=None, prompt=None, extension='png', i return fullfn, txt_fullfn +IGNORED_INFO_KEYS = { + 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', + 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression', + 'icc_profile', 'chromaticity', 'photoshop', +} + + def read_info_from_image(image): items = image.info or {} @@ -679,9 +686,7 @@ def read_info_from_image(image): items['exif comment'] = exif_comment geninfo = exif_comment - for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', - 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression', - 'icc_profile', 'chromaticity']: + for field in IGNORED_INFO_KEYS: items.pop(field, None) if items.get("Software", None) == "NovelAI": -- cgit v1.2.1 From 96f0593c8fcfb5d31da9731d995c6d6f2ad77829 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:10:20 +0300 Subject: read_info_from_image: add type --- modules/images.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index ac53a3c5..74a10a7b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import pytz @@ -669,7 +671,7 @@ IGNORED_INFO_KEYS = { } -def read_info_from_image(image): +def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: items = image.info or {} geninfo = items.pop('parameters', None) -- cgit v1.2.1 From 5c6a33b3e11f5aa7b2fc56753c5a724e1351ce81 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 13:10:42 +0300 Subject: read_info_from_image: don't mutate info in passed-in image --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 74a10a7b..ec421993 100644 --- a/modules/images.py +++ b/modules/images.py @@ -672,7 +672,7 @@ IGNORED_INFO_KEYS = { def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]: - items = image.info or {} + items = (image.info or {}).copy() geninfo = items.pop('parameters', None) -- cgit v1.2.1 From 32788873176e9d79e1fffd6f89f94b6d0ec8bb91 Mon Sep 17 00:00:00 2001 From: ramyma Date: Mon, 3 Jul 2023 20:02:30 +0300 Subject: Handle cleanup in case there's an exception thrown --- modules/api/api.py | 57 +++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 26 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index f10e3fe3..d9278e9e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -323,19 +323,21 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - p.scripts = script_runner - p.outpath_grids = opts.outdir_txt2img_grids - p.outpath_samples = opts.outdir_txt2img_samples - - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() - p.close() + try: + p.scripts = script_runner + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + + shared.state.begin() + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here 
+ else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() + finally: + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -380,20 +382,23 @@ class Api: with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - p.init_images = [decode_base64_to_image(x) for x in init_images] - p.scripts = script_runner - p.outpath_grids = opts.outdir_img2img_grids - p.outpath_samples = opts.outdir_img2img_samples + try: + p.init_images = [decode_base64_to_image(x) for x in init_images] + p.scripts = script_runner + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + + shared.state.begin() + if selectable_scripts is not None: + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here + else: + p.script_args = tuple(script_args) # Need to pass args as tuple here + processed = process_images(p) + shared.state.end() - shared.state.begin() - if selectable_scripts is not None: - p.script_args = script_args - processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here - else: - p.script_args = tuple(script_args) # Need to pass args as tuple here - processed = process_images(p) - shared.state.end() - p.close() + finally: + p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] -- cgit v1.2.1 From c1c04928596f69ddb39b8841a8435ecefb0594e9 Mon Sep 17 00:00:00 2001 From: ramyma Date: Mon, 3 Jul 2023 20:17:47 +0300 Subject: Use contextlib for closing the generation process --- modules/api/api.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index d9278e9e..e92c2938 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -30,6 +30,7 @@ from modules import devices from typing import Dict, List, Any import piexif import piexif.helper +from contextlib import closing def script_name_to_index(name, scripts): @@ -322,8 +323,7 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) - try: + with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples @@ -336,8 +336,6 @@ class Api: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) shared.state.end() - finally: - p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -381,8 +379,7 @@ class Api: args.pop('save_images', None) with self.queue_lock: - p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) - try: + with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids @@ -397,8 +394,6 @@ class Api: processed = process_images(p) shared.state.end() - finally: - p.close() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] -- cgit v1.2.1 From f731a728c68035ee36317ed0096ac5ecbfd50553 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 3 Jul 2023 11:41:10 -0600 Subject: Check seed_resize_from <= 0 --- modules/processing.py | 4 ++-- 
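The hunks below change the seed-resize defaults from 0 to -1 and relax the "is it set?" test accordingly, so any non-positive pair now means disabled. A tiny sketch of the sentinel convention; the function name is hypothetical:

    def seed_resize_label(w: int, h: int):
        # -1 (the new default) and 0 (the old default) both mean "not set"
        if w <= 0 or h <= 0:
            return None
        return f"{w}x{h}"

    print(seed_resize_label(-1, -1))    # None
    print(seed_resize_label(0, 512))    # None
    print(seed_resize_label(640, 448))  # 640x448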
1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 9e838aad..dc552121 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -109,7 +109,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = 0, seed_resize_from_w: int = 0, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -573,7 +573,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), - "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), + "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, -- cgit v1.2.1 From 
c602471b85d270e8c36707817d9bad92b0ff991e Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 5 Jul 2023 03:19:26 -0600 Subject: Allow gif for extra network previews --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a7d3bc79..1efd00b0 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -30,8 +30,8 @@ def fetch_file(filename: str = ""): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") ext = os.path.splitext(filename)[1].lower() - if ext not in (".png", ".jpg", ".jpeg", ".webp"): - raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.") + if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): + raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") # would profit from returning 304 return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) -- cgit v1.2.1 From fb661e089f24a3056b9724c580e3badc214467cc Mon Sep 17 00:00:00 2001 From: semjon00 Date: Wed, 5 Jul 2023 15:39:04 +0300 Subject: Fix exception thrown when trying to resize an image with I;16 mode --- modules/images.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 1906e2ab..91e3fae2 100644 --- a/modules/images.py +++ b/modules/images.py @@ -639,12 +639,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i oversize = image.width > opts.target_side_length or image.height > opts.target_side_length if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024): ratio = image.width / image.height - + resize_to = None if oversize and ratio > 1: - image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS) + resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width) elif oversize: - image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS) + resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length) + if resize_to is not None: + try: + # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16 + image = image.resize(resize_to, LANCZOS) + except: + image = image.resize(resize_to) try: _atomically_save_image(image, fullfn_without_extension, ".jpg") except Exception as e: -- cgit v1.2.1 From daf41a273485e865c9c9ef458b2c26be4422bcb2 Mon Sep 17 00:00:00 2001 From: Hao-Wu Date: Thu, 6 Jul 2023 15:37:10 +0800 Subject: Fix 'has_mps is deprecated' warning from PyTorch --- modules/mac_specific.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/mac_specific.py b/modules/mac_specific.py index d74c6b95..735847f5 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc from packaging import version -# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. -# check `getattr` and try it for compatibility +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use `getattr` and try it for compatibility.
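The check described in this comment block is easiest to follow outside the diff; a standalone, runnable version of the same version-gated probe, assuming torch and packaging are installed:

    import torch
    from packaging import version

    def mps_ready() -> bool:
        if version.parse(torch.__version__) <= version.parse("2.0.1"):
            # old torch: feature-detect, then probe by allocating on the device
            if not getattr(torch, "has_mps", False):
                return False
            try:
                torch.zeros(1).to(torch.device("mps"))
                return True
            except Exception:
                return False
        # newer torch exposes explicit capability checks
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()

    print(mps_ready())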
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() were introduced to check MPS availability, +# since the torch 2.0.1+ nightly builds, getattr(torch, 'has_mps', False) has been deprecated, see https://github.com/pytorch/pytorch/pull/103279 def check_for_mps() -> bool: - if not getattr(torch, 'has_mps', False): - return False - try: - torch.zeros(1).to(torch.device("mps")) - return True - except Exception: - return False + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() has_mps = check_for_mps() -- cgit v1.2.1 From 259967b7c60cbd2aeb091e691b5f49d9fb64b872 Mon Sep 17 00:00:00 2001 From: jovijovi Date: Thu, 6 Jul 2023 18:43:17 +0800 Subject: fix(api): convert to "RGB" if image mode is "RGBA" --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 2e49526e..6507f641 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -84,6 +84,8 @@ def encode_pil_to_base64(image): image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + if image.mode == "RGBA": + image = image.convert("RGB") parameters = image.info.get('parameters', None) exif_bytes = piexif.dump({ "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } -- cgit v1.2.1 From c258dd34a888b7c6c9e4c9bbef76732d9d7db6e7 Mon Sep 17 00:00:00 2001 From: Neil Mahseth Date: Thu, 6 Jul 2023 22:02:47 +0530 Subject: Fix UnicodeEncodeError when writing to file in CLIP Interrogator Batch Mode The code snippet print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a')) raises a UnicodeEncodeError with the message "'charmap' codec can't encode character '\u016b' in position 129". This error occurs because the default encoding used by the open() function cannot handle certain Unicode characters. To fix this issue, the encoding parameter needs to be specified explicitly when opening the file. By using an appropriate encoding, such as 'utf-8', we can ensure that Unicode characters are properly encoded and written to the file. The updated code should be modified as follows: print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8')) By making this change, the code will no longer raise the UnicodeEncodeError and will correctly handle Unicode characters during the file write operation.
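A self-contained reproduction of the failure and the fix described above; the sample string is arbitrary, and the error only surfaces where the platform default encoding (e.g. cp1252 on Windows) cannot represent the character:

    import os
    import tempfile

    caption = "j\u016bra"  # contains U+016B, which cp1252 cannot encode
    path = os.path.join(tempfile.gettempdir(), "caption.txt")

    try:
        with open(path, "a") as f:  # encoding falls back to the locale default
            print(caption, file=f)
    except UnicodeEncodeError as e:
        print("default encoding failed:", e)

    with open(path, "a", encoding="utf-8") as f:  # explicit encoding handles any str
        print(caption, file=f)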
--- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index e2e3b6da..10e35ec3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di img = Image.open(image) filename = os.path.basename(image) left, _ = os.path.splitext(filename) - print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a')) + print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8')) return [gr.update(), None] -- cgit v1.2.1 From 19772c3c97647bdda76cd7f652ae517840431e88 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 13:43:42 +0300 Subject: fix problem with extra network saving images as previews losing generation info add a description for save_image_with_geninfo --- modules/images.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 1906e2ab..04f55f14 100644 --- a/modules/images.py +++ b/modules/images.py @@ -497,13 +497,23 @@ def get_next_sequence_number(path, basename): return result + 1 -def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None): +def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'): + """ + Saves image to filename, including geninfo as text information for generation info. + For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key. + For JPG images, there's no dictionary and geninfo just replaces the EXIF description. + """ + if extension is None: extension = os.path.splitext(filename)[1] image_format = Image.registered_extensions()[extension] if extension.lower() == '.png': + existing_pnginfo = existing_pnginfo or {} + if opts.enable_pnginfo: + existing_pnginfo[pnginfo_section_name] = geninfo + if opts.enable_pnginfo: pnginfo_data = PngImagePlugin.PngInfo() for k, v in (existing_pnginfo or {}).items(): @@ -622,7 +632,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i """ temp_file_path = f"{filename_without_extension}.tmp" - save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo) + save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name) os.replace(temp_file_path, filename_without_extension + extension) -- cgit v1.2.1 From 7a7fa25d02d469533dab5084bbd08d96d2df45a2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 14:21:40 +0300 Subject: lint fix for #11492 --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index fdccec09..9a5d9585 100644 --- a/modules/images.py +++ b/modules/images.py @@ -661,7 +661,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i try: # Resizing image with LANCZOS could throw an exception if e.g. 
image mode is I;16 image = image.resize(resize_to, LANCZOS) - except: + except Exception: image = image.resize(resize_to) try: _atomically_save_image(image, fullfn_without_extension, ".jpg") -- cgit v1.2.1 From 3602602260abaa325850e4768b7e253834e207d0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 14:44:02 +0300 Subject: whitespace for #11477 --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/scripts.py b/modules/scripts.py index d96f88b0..a07adc42 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -187,7 +187,7 @@ class Script: return f'script_{tabname}{title}_{item_id}' - def before_hr(self, p ,*args): + def before_hr(self, p, *args): """ This function is called before hires fix start. """ -- cgit v1.2.1 From b88645d9ebddfa26aaf6ee25519a95c967a23138 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:14:14 +0300 Subject: additional changes for merge conflict for #11337 --- modules/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index b8ea3a3c..5e18bab9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -230,7 +230,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_input_dir, png_info_props=img2img_batch_output_dir, png_info_dir=img2img_batch_inpaint_mask_dir) processed = Processed(p, [], p.seed, "") else: -- cgit v1.2.1 From 9043b91649f35adaa732d811184e81afb7a34b71 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:14:24 +0300 Subject: additional changes for merge conflict for #11337 --- modules/ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index c752a64d..e83f2651 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -934,13 +934,13 @@ def create_ui(): inpaint_full_res, inpaint_full_res_padding, inpainting_mask_invert, - img2img_batch_use_png_info, - img2img_batch_png_info_props, - img2img_batch_png_info_dir, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, override_settings, + img2img_batch_use_png_info, + img2img_batch_png_info_props, + img2img_batch_png_info_dir, ] + custom_inputs, outputs=[ img2img_gallery, -- cgit v1.2.1 From 1d71c36de2d7bbbcd290ba4dc5afd8ba909c74f8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:21:29 +0300 Subject: third time's the charm --- modules/img2img.py | 2 +- modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 5e18bab9..881212fc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -230,7 +230,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not 
shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_input_dir, png_info_props=img2img_batch_output_dir, png_info_dir=img2img_batch_inpaint_mask_dir) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index e83f2651..39d226ad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -733,7 +733,7 @@ def create_ui(): img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") - with gr.Accordion("PNG info"): + with gr.Accordion("PNG info", open=False): img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info") img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir") img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.") -- cgit v1.2.1 From 274a3e21babe5fa913b4a34d49b5d7cd72c5fa89 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 15:42:00 +0300 Subject: small rework for img2img PNG info --- modules/img2img.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 881212fc..a5f1c148 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -96,27 +96,16 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal info_img = Image.open(info_img_path) geninfo, _ = imgutil.read_info_from_image(info_img) parsed_parameters = parse_generation_parameters(geninfo) - if("Prompt" in png_info_props): - p.prompt = prompt + " " + parsed_parameters["Prompt"] - if("Negative prompt" in png_info_props): - p.negative_prompt = negative_prompt + " " + parsed_parameters["Negative prompt"] - if("Seed" in png_info_props): - p.seed = int(parsed_parameters["Seed"]) - if("CFG scale" in png_info_props): - p.cfg_scale = float(parsed_parameters["CFG scale"]) - if("Sampler" in png_info_props): - p.sampler_name = parsed_parameters["Sampler"] - if("Steps" in png_info_props): - p.steps = int(parsed_parameters["Steps"]) - except Exception as e: - print(f"batch png info: using ui set prompts; failed to get png info for {image}") - print(e) - p.prompt = prompt - p.negative_prompt = negative_prompt - p.seed = seed - p.cfg_scale = cfg_scale - p.sampler_name = sampler_name - p.steps = steps + parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} + 
except Exception: + parsed_parameters = {} + + p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "") + p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "") + p.seed = int(parsed_parameters.get("Seed", seed)) + p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale)) + p.sampler_name = parsed_parameters.get("Sampler", sampler_name) + p.steps = int(parsed_parameters.get("Steps", steps)) proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: -- cgit v1.2.1 From 7a6abc59ea1ecd8bb311de1719b018fb5960cd80 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 16:15:28 +0300 Subject: for #10650: change key to alt+arrows, enable by default --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index b9c53875..b29c3307 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -494,7 +494,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), - "keyedit_move": OptionInfo(False, "Ctrl+left/right moves prompt elements"), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), -- cgit v1.2.1 From d7d6e8cfc8b85a99a48f82975ee213d487783c28 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 16:45:59 +0300 Subject: use natural sort for shared.walk_files and shared.listfiles, as well as for dirs in extra networks --- modules/shared.py | 14 +++++++++++--- modules/ui_extra_networks.py | 4 ++-- 2 files changed, 13 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index b29c3307..48478a68 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,6 +1,7 @@ import datetime import json import os +import re import sys import threading import time @@ -832,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) mem_mon.start() +def natural_sort_key(s, regex=re.compile('([0-9]+)')): + return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)] + + def listfiles(dirname): - filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")] + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] @@ -858,8 +863,11 @@ def walk_files(path, 
allowed_extensions=None): if allowed_extensions is not None: allowed_extensions = set(allowed_extensions) - for root, _, files in os.walk(path, followlinks=True): - for filename in files: + items = list(os.walk(path, followlinks=True)) + items = sorted(items, key=lambda x: natural_sort_key(x[0])) + + for root, _, files in items: + for filename in sorted(files, key=natural_sort_key): if allowed_extensions is not None: _, ext = os.path.splitext(filename) if ext not in allowed_extensions: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 1efd00b0..693cafb6 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -90,8 +90,8 @@ class ExtraNetworksPage: subdirs = {} for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: - for root, dirs, _ in os.walk(parentdir, followlinks=True): - for dirname in dirs: + for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): + for dirname in sorted(dirs, key=shared.natural_sort_key): x = os.path.join(root, dirname) if not os.path.isdir(x): -- cgit v1.2.1 From e161b5a0259c870b9d01408d02c504c3281dbdb1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 16:54:03 +0300 Subject: rework #10436 to use shared.walk_files --- modules/img2img.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 3b83814b..ef87eb0f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -18,13 +18,7 @@ import modules.scripts def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): processing.fix_seed(p) - images = [] - for root, directories, files in os.walk(input_dir): - for filename in files: - filepath = os.path.join(root, filename) - if filepath.endswith(".jpg") or filepath.endswith(".jpeg") or filepath.endswith(".png") or filepath.endswith(".webp"): - images.append(filepath) - + images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp"))) is_inpaint_batch = False if inpaint_mask_dir: -- cgit v1.2.1 From da8916f92649fc4d947cb46d9d8f8ea1621b2a59 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:13:18 +0300 Subject: added torch.mps.empty_cache() to torch_gc() changed a bunch of places that use torch.cuda.empty_cache() to use torch_gc() instead --- modules/codeformer_model.py | 2 +- modules/devices.py | 3 +++ modules/sd_models.py | 1 - 3 files changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index f293acf5..da42b5e9 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -99,7 +99,7 @@ def setup_model(dirname): output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output - torch.cuda.empty_cache() + devices.torch_gc() except Exception: errors.report('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) diff --git a/modules/devices.py b/modules/devices.py index 620ed1a6..c5ad950f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -49,10 +49,13 @@ def get_device_for(task): def torch_gc(): + if torch.cuda.is_available(): with 
torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() + elif has_mps() and hasattr(torch.mps, 'empty_cache'): + torch.mps.empty_cache() def enable_tf32(): diff --git a/modules/sd_models.py b/modules/sd_models.py index f65f4e36..653c4cc0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -590,7 +590,6 @@ def unload_model_weights(sd_model=None, info=None): sd_model = None gc.collect() devices.torch_gc() - torch.cuda.empty_cache() print(f"Unloaded weights {timer.summary()}.") -- cgit v1.2.1 From da468a585bb631bc91c3435f349dfb7ce7fe3895 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 3 Jul 2023 12:17:20 +0300 Subject: Fix typo: checkpoint_alisases --- modules/api/api.py | 4 ++-- modules/processing.py | 2 +- modules/sd_models.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 224bbfc6..5793bb44 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases +from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models @@ -519,7 +519,7 @@ class Api: def set_config(self, req: Dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases: + if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): diff --git a/modules/processing.py b/modules/processing.py index 21d1492c..cd568a20 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -606,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint - if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: + if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() diff --git a/modules/sd_models.py b/modules/sd_models.py index 653c4cc0..060e0007 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -23,7 +23,8 @@ model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} -checkpoint_alisases = {} +checkpoint_aliases = {} +checkpoint_alisases = checkpoint_aliases # for compatibility with old name checkpoints_loaded = collections.OrderedDict() @@ -66,7 +67,7 @@ class CheckpointInfo: def register(self): checkpoints_list[self.title] = self for id in self.ids: - checkpoint_alisases[id] = self + checkpoint_aliases[id] = self def calculate_shorthash(self): self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}") @@ -112,7 +113,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - 
checkpoint_alisases.clear() + checkpoint_aliases.clear() cmd_ckpt = shared.cmd_opts.ckpt if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt): @@ -136,7 +137,7 @@ def list_models(): def get_closet_checkpoint_match(search_string): - checkpoint_info = checkpoint_alisases.get(search_string, None) + checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: return checkpoint_info @@ -166,7 +167,7 @@ def select_checkpoint(): """Raises `FileNotFoundError` if no checkpoints are found.""" model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) + checkpoint_info = checkpoint_aliases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info -- cgit v1.2.1 From 4981c7d3704e50dd93fe1b68d299239a4ded1ec2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:52:03 +0300 Subject: move github proxy to settings, System page. --- modules/shared.py | 1 + modules/ui_extensions.py | 33 ++++++++++++++------------------- 2 files changed, 15 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 48478a68..b7518de6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -391,6 +391,7 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "github_proxy": OptionInfo("None", "Github proxy", ui_components.DropdownEditable, lambda: {"choices": ["None", "ghproxy.com", "hub.yzuu.cf", "hub.njuu.cf", "hub.nuaa.cf"]}).info("for custom inputs will just replace github.com with the input"), })) options_templates.update(options_section(('training', "Training"), { diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ac523bcf..a208012d 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -325,7 +325,18 @@ def normalize_git_url(url): return url -def install_extension_from_url(dirname, proxy, url, branch_name=None): +def github_proxy(url): + proxy = shared.opts.github_proxy + + if proxy == 'None': + return url + if proxy == 'ghproxy.com': + return "https://ghproxy.com/" + url + + return url.replace('github.com', proxy) + + +def install_extension_from_url(dirname, url, branch_name=None): check_access() if isinstance(dirname, str): @@ -335,18 +346,7 @@ def install_extension_from_url(dirname, proxy, url, branch_name=None): assert url, 'No URL specified' - proxy_list = { - "none": "", - "ghproxy": "https://ghproxy.com/", - "yzuu": "hub.yzuu.cf", - "njuu": "hub.njuu.cf", - "nuaa": "hub.nuaa.cf", - } - - if proxy in ['yzuu', 'njuu', 'nuaa']: - url = url.replace('github.com', proxy_list[proxy]) - elif proxy == 'ghproxy': - url = proxy_list[proxy] + url + url = github_proxy(url) if dirname is None or dirname == "": *parts, last_part = url.split('/') @@ -628,11 +628,6 @@ def create_ui(): ) with gr.TabItem("Install from URL", id="install_from_url"): - install_proxy = gr.Radio( - label="Install Proxy", choices=["none", "ghproxy", "nuaa", "yzuu", "njuu"], value="none", - info="If you can't access github.com, you can use a proxy to install 
extensions from github.com" - ) - install_url = gr.Text(label="URL for extension's git repository") install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") @@ -641,7 +636,7 @@ def create_ui(): install_button.click( fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]), - inputs=[install_dirname, install_proxy, install_url, install_branch], + inputs=[install_dirname, install_url, install_branch], outputs=[install_url, extensions_table, install_result], ) -- cgit v1.2.1 From 75f56406cede0095eb0e5dcc4e0d5759063e89dc Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sun, 9 Jul 2023 22:40:23 +0800 Subject: Revert Pull Request #11244 Revert "Add github mirror for the download extension" This reverts commit 9ec2ba2d28bb0d8f01e19e2919b7bf2e3e864773. Revert "Update code style" This reverts commit de022c4c80240a430a8099fb27a41aa505bf5b2f. Revert "Update call method" This reverts commit e9bd18c57bd83363d38c7409263fe87f3ed3a7f0. Revert "move github proxy to settings, System page." This reverts commit 4981c7d3704e50dd93fe1b68d299239a4ded1ec2. --- modules/shared.py | 1 - modules/ui_extensions.py | 17 ++--------------- 2 files changed, 2 insertions(+), 16 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index b7518de6..48478a68 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -391,7 +391,6 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), - "github_proxy": OptionInfo("None", "Github proxy", ui_components.DropdownEditable, lambda: {"choices": ["None", "ghproxy.com", "hub.yzuu.cf", "hub.njuu.cf", "hub.nuaa.cf"]}).info("for custom inputs will just replace github.com with the input"), })) options_templates.update(options_section(('training', "Training"), { diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index a208012d..dff522ef 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -325,17 +325,6 @@ def normalize_git_url(url): return url -def github_proxy(url): - proxy = shared.opts.github_proxy - - if proxy == 'None': - return url - if proxy == 'ghproxy.com': - return "https://ghproxy.com/" + url - - return url.replace('github.com', proxy) - - def install_extension_from_url(dirname, url, branch_name=None): check_access() @@ -346,8 +335,6 @@ def install_extension_from_url(dirname, url, branch_name=None): assert url, 'No URL specified' - url = github_proxy(url) - if dirname is None or dirname == "": *parts, last_part = url.split('/') last_part = normalize_git_url(last_part) @@ -367,12 +354,12 @@ def install_extension_from_url(dirname, url, branch_name=None): shutil.rmtree(tmpdir, True) if not branch_name: # if no branch is specified, use the default branch - with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], verbose=False) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() else: - with 
git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name, verbose=False) as repo: + with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo: repo.remote().fetch() for submodule in repo.submodules: submodule.update() -- cgit v1.2.1 From 44c27ebc7393ea793245aa565ace6c9bf1313980 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 20:08:23 +0300 Subject: Use closing() with processing classes everywhere Follows up on #11569 --- modules/hypernetworks/hypernetwork.py | 6 ++++-- modules/img2img.py | 20 ++++++++++---------- modules/textual_inversion/textual_inversion.py | 6 ++++-- modules/txt2img.py | 11 ++++++----- 4 files changed, 24 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 51941c11..79670b87 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -3,6 +3,7 @@ import glob import html import os import inspect +from contextlib import closing import modules.textual_inversion.dataset import torch @@ -711,8 +712,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images) > 0 else None + with closing(p): + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/img2img.py b/modules/img2img.py index ef87eb0f..4d9a02cc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -1,4 +1,5 @@ import os +from contextlib import closing from pathlib import Path import numpy as np @@ -217,18 +218,17 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur - if is_batch: - assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" + with closing(p): + if is_batch: + assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir) - processed = Processed(p, [], p.seed, "") - else: - processed = modules.scripts.scripts_img2img.run(p, *args) - if processed is None: - processed = process_images(p) - - p.close() + processed = Processed(p, [], p.seed, "") + else: + processed = modules.scripts.scripts_img2img.run(p, *args) + if processed is None: + processed = process_images(p) shared.total_tqdm.clear() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb6f211c..cbe975b7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,5 +1,6 @@ import os from collections import namedtuple +from contextlib import closing import torch import tqdm @@ -584,8 +585,9 @@ def 
train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st preview_text = p.prompt - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images) > 0 else None + with closing(p): + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images) > 0 else None if unload: shared.sd_model.first_stage_model.to(devices.cpu) diff --git a/modules/txt2img.py b/modules/txt2img.py index 6aa79f23..d0be2e73 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,3 +1,5 @@ +from contextlib import closing + import modules.scripts from modules import sd_samplers, processing from modules.generation_parameters_copypaste import create_override_settings_dict @@ -53,12 +55,11 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) - processed = modules.scripts.scripts_txt2img.run(p, *args) - - if processed is None: - processed = processing.process_images(p) + with closing(p): + processed = modules.scripts.scripts_txt2img.run(p, *args) - p.close() + if processed is None: + processed = processing.process_images(p) shared.total_tqdm.clear() -- cgit v1.2.1 From 10d4e4ace2d243c020e6a83060c938dee7d8c02d Mon Sep 17 00:00:00 2001 From: TangJicheng Date: Tue, 11 Jul 2023 17:30:57 +0900 Subject: add cmd_args: --timeout-keep-alive --- modules/cmd_args.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index de905caa..982d9055 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -107,3 +107,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') -- cgit v1.2.1 From 14501f56aaf3c97fb2c38633350dc747b9651f43 Mon Sep 17 00:00:00 2001 From: TangJicheng Date: Tue, 11 Jul 2023 17:32:04 +0900 Subject: set timeout_keep_alive --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 7f7e3a9b..4ea5d825 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -715,4 +715,4 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive) -- cgit v1.2.1 From b85fc7187d953828340d4e3af34af46d9fc70b9e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 21:18:34 +0300 Subject: Fix MPS cache cleanup Importing torch does not import torch.mps so the call failed. 
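The root cause named here is that `import torch` does not necessarily import the torch.mps submodule, so `torch.mps.empty_cache()` can fail with an AttributeError. A standalone sketch of the defensive shape the fix below takes, importing the submodule explicitly and logging rather than crashing:

    import logging

    log = logging.getLogger(__name__)

    def torch_mps_gc() -> None:
        try:
            # import the submodule explicitly rather than relying on
            # `torch.mps` having been bound by `import torch`
            from torch.mps import empty_cache
            empty_cache()
        except Exception:
            log.warning("MPS garbage collection failed", exc_info=True)

    torch_mps_gc()  # on non-MPS setups this logs a warning instead of raising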
--- modules/devices.py | 5 +++-- modules/mac_specific.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index c5ad950f..57e51da3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -54,8 +54,9 @@ def torch_gc(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() - elif has_mps() and hasattr(torch.mps, 'empty_cache'): - torch.mps.empty_cache() + + if has_mps(): + mac_specific.torch_mps_gc() def enable_tf32(): diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 735847f5..2c2f15ca 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,8 +1,12 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger() + # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. @@ -19,9 +23,19 @@ def check_for_mps() -> bool: return False else: return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + has_mps = check_for_mps() +def torch_mps_gc() -> None: + try: + from torch.mps import empty_cache + empty_cache() + except Exception: + log.warning("MPS garbage collection failed", exc_info=True) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': -- cgit v1.2.1 From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 11 Jul 2023 21:16:43 +0300 Subject: getting SD2.1 to run on SDXL repo --- modules/launch_utils.py | 3 ++ modules/paths.py | 1 + modules/prompt_parser.py | 64 +++++++++++++++++++++++++++++++-------- modules/sd_hijack.py | 9 ++++++ modules/sd_hijack_open_clip.py | 4 +++ modules/sd_models.py | 8 +++-- modules/sd_models_config.py | 2 ++ modules/sd_models_xl.py | 40 ++++++++++++++++++++++++ modules/sd_samplers_kdiffusion.py | 45 +++++++++++++++++++++------ 9 files changed, 152 insertions(+), 24 deletions(-) create mode 100644 modules/sd_models_xl.py (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0e0dbca4..3b740dbd 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -235,11 +235,13 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") codeformer_commit_hash = 
os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -297,6 +299,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index bada804e..f509a85f 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,6 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 0069d8b0..d7f9e9a9 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps): cond_schedule = [] for i, (end_at_step, _) in enumerate(prompt_schedule): - cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i])) + if isinstance(conds, dict): + cond = {k: v[i] for k, v in conds.items()} + else: + cond = conds[i] + + cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond)) cache[prompt] = cond_schedule res.append(cond_schedule) @@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) +class DictWithShape(dict): + def __init__(self, x, shape): + super().__init__() + self.update(x) + + @property + def shape(self): + return self["crossattn"].shape + + def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): param = c[0][0].cond - res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + is_dict = isinstance(param, dict) + + if is_dict: + dict_cond = param + res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()} + res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape) + else: + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, entry in enumerate(cond_schedule): if current_step <= entry.end_at_step: target_index = current break - res[i] = cond_schedule[target_index].cond + + if is_dict: + for k, param in cond_schedule[target_index].cond.items(): + res[k][i] = param + else: + res[i] = cond_schedule[target_index].cond return res +def stack_conds(tensors): + # if prompts have wildly different lengths above the limit we'll get tensors of different shapes + # and won't be able to torch.stack them. So this fixes that. 
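The padding trick that the next lines implement, run standalone: shorter conditioning tensors get their last row repeated until all lengths match, after which torch.stack succeeds. Shapes here are arbitrary stand-ins for token-embedding tensors:

    import torch

    def stack_padded(tensors):
        token_count = max(t.shape[0] for t in tensors)
        padded = []
        for t in tensors:
            if t.shape[0] != token_count:
                last = t[-1:]
                t = torch.vstack([t, last.repeat(token_count - t.shape[0], 1)])
            padded.append(t)
        return torch.stack(padded)

    a = torch.ones(77, 4)    # e.g. one prompt chunk
    b = torch.ones(154, 4)   # e.g. two prompt chunks
    print(stack_padded([a, b]).shape)  # torch.Size([2, 154, 4])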
+ token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + + return torch.stack(tensors) + + + def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): param = c.batch[0][0].schedules[0].cond @@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) - # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes - # and won't be able to torch.stack them. So this fixes that. - token_count = max([x.shape[0] for x in tensors]) - for i in range(len(tensors)): - if tensors[i].shape[0] != token_count: - last_vector = tensors[i][-1:] - last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) - tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + if isinstance(tensors[0], dict): + keys = list(tensors[0].keys()) + stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys} + stacked = DictWithShape(stacked, stacked['crossattn'].shape) + else: + stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype) - return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + return conds_list, stacked re_attention = re.compile(r""" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce..c4b9211f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -166,6 +166,15 @@ class StableDiffusionModelHijack: undo_optimizations() def hijack(self, m): + conditioner = getattr(m, 'conditioner', None) + if conditioner: + for i in range(len(conditioner.embedders)): + embedder = conditioner.embedders[i] + if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index f733e852..6ac5bda6 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,6 +16,10 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' diff --git a/modules/sd_models.py b/modules/sd_models.py index 060e0007..8d639583 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, 
sd_models_config, sd_unet, sd_models_xl from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + if hasattr(model, 'conditioner'): + sd_models_xl.extend_sdxl(model) + model.load_state_dict(state_dict, strict=False) del state_dict timer.record("apply weights to model") @@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.sd_checkpoint_info = checkpoint_info shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - model.logvar = model.logvar.to(devices.device) # fix for training + if hasattr(model, 'logvar'): + model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9bfe1237..96501569 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -6,11 +6,13 @@ from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py new file mode 100644 index 00000000..d43b8868 --- /dev/null +++ b/modules/sd_models_xl.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import torch + +import sgm.models.diffusion +import sgm.modules.diffusionmodules.denoiser_scaling +import sgm.modules.diffusionmodules.discretizer +from modules import devices + + +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): + for embedder in self.conditioner.embedders: + embedder.ucg_rate = 0.0 + + c = self.conditioner({'txt': batch}) + + return c + + +def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + return self.model(x, t, cond) + + +def extend_sdxl(model): + dtype = next(model.model.diffusion_model.parameters()).dtype + model.model.diffusion_model.dtype = dtype + model.model.conditioning_key = 'crossattn' + + model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_key = model.cond_stage_model.input_key + + model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" + + discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model + diff --git 
a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b76..73289ce4 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,6 +53,28 @@ k_diffusion_scheduler = { } +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + class CFGDenoiser(torch.nn.Module): """ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) @@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: - tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) + tensor = pad_cond(tensor, -num_repeats, empty) self.padded_cond_uncond = True elif num_repeats > 0: - uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) + uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = catenate_conds([tensor, uncond, uncond]) elif skip_uncond: cond_in = tensor else: - cond_in = torch.cat([tensor, uncond]) + cond_in = catenate_conds([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size @@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): b = min(a + batch_size, tensor.shape[0]) if not is_edit_model: - c_crossattn = [tensor[a:b]] + c_crossattn = subscript_cond(tensor, a, b) else: c_crossattn = torch.cat([tensor[a:b]], uncond) x_out[a:b] = 
self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list] if skip_uncond: -- cgit v1.2.1 From 3fee3c34f1b01d21770ab0a226b432cdd8444792 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Wed, 12 Jul 2023 02:45:03 -0600 Subject: Save img2img batch with images.save_image() --- modules/img2img.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 2c497020..15306972 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -8,6 +8,7 @@ from modules import sd_samplers from modules.generation_parameters_copypaste import create_override_settings_dict from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state +from modules.images import save_image import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html @@ -84,17 +85,17 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal proc = process_images(p) for n, processed_image in enumerate(proc.images): - filename = image_path.name + filename = image_path.stem + infotext = proc.infotext(p, n) if n > 0: - left, right = os.path.splitext(filename) - filename = f"{left}-{n}{right}" + filename += f"-{n}" if not save_normally: os.makedirs(output_dir, exist_ok=True) if processed_image.mode == 'RGBA': processed_image = processed_image.convert("RGB") - processed_image.save(os.path.join(output_dir, filename)) + save_image(processed_image, output_dir, None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False) def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): -- cgit v1.2.1 From 6c0d5d1198576dbe664f55cffec27b03d0789efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=80=E5=AE=97?= Date: Wed, 12 Jul 2023 16:51:50 +0800 Subject: fix: check fill size none zero when resize (fixes #11425) --- modules/images.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 7bbfc3e0..7935b122 100644 --- a/modules/images.py +++ b/modules/images.py @@ -302,12 +302,14 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): if ratio < src_ratio: 
fill_height = height // 2 - src_h // 2 - res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + if fill_height > 0: + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 - res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + if fill_width > 0: + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) return res -- cgit v1.2.1 From 8f6b24ce5922174d96eb9776126488cb28694ff8 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:16:42 +0300 Subject: Add correct logger name --- modules/mac_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 2c2f15ca..328b5973 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -5,7 +5,7 @@ import platform from modules.sd_hijack_utils import CondFunc from packaging import version -log = logging.getLogger() +log = logging.getLogger(__name__) # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, -- cgit v1.2.1 From 3d524fd3f1bdb17946bf6fa8a3cdf7b10859c495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 12 Jul 2023 15:17:13 +0300 Subject: Don't do MPS GC when there's a latent that could still be sampled --- modules/mac_specific.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 328b5973..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -30,6 +30,10 @@ has_mps = check_for_mps() def torch_mps_gc() -> None: try: + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return from torch.mps import empty_cache empty_cache() except Exception: -- cgit v1.2.1 From ea49bb06125262e61c44003e42219bc04e38b10b Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 12 Jul 2023 23:30:22 +0900 Subject: use submit blur for quick settings textbox --- modules/ui_settings.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 0c560b30..a6076bf3 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -260,13 +260,20 @@ class UiSettings: component = self.component_dict[k] info = opts.data_labels[k] - change_handler = component.release if hasattr(component, 'release') else component.change - change_handler( - fn=lambda value, k=k: self.run_settings_single(value, key=k), - inputs=[component], - outputs=[component, self.text_settings], - show_progress=info.refresh is not None, - ) + if isinstance(component, gr.Textbox): + methods = [component.submit, component.blur] + elif hasattr(component, 'release'): + methods = [component.release] + else: + methods = [component.change] + + 
for method in methods: + method( + fn=lambda value, k=k: self.run_settings_single(value, key=k), + inputs=[component], + outputs=[component, self.text_settings], + show_progress=info.refresh is not None, + ) button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) button_set_checkpoint.click( -- cgit v1.2.1 From da464a3fb39ecc6ea7b22fe87271194480d8501c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 12 Jul 2023 23:52:43 +0300 Subject: SDXL support --- modules/launch_utils.py | 17 +++++++++++++ modules/lowvram.py | 51 +++++++++++++++++++++++++++----------- modules/paths.py | 9 ++++++- modules/processing.py | 7 ++++-- modules/prompt_parser.py | 23 ++++++++++++++--- modules/sd_hijack.py | 23 ++++++++++++++++- modules/sd_hijack_clip.py | 16 +++++++++--- modules/sd_hijack_open_clip.py | 38 +++++++++++++++++++++++++--- modules/sd_hijack_optimizations.py | 51 ++++++++++++++++++++++++++++++++------ modules/sd_models.py | 14 +++++++++-- modules/sd_models_config.py | 5 +++- modules/sd_models_xl.py | 27 +++++++++++++++++--- modules/sd_samplers_kdiffusion.py | 2 +- modules/shared.py | 2 ++ 14 files changed, 240 insertions(+), 45 deletions(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 3b740dbd..aa9d1880 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,6 +224,20 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + import importlib + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -319,11 +333,14 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) + mute_sdxl_imports() + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) + def configure_for_tests(): if "--api" not in sys.argv: sys.argv.append("--api") diff --git a/modules/lowvram.py b/modules/lowvram.py index d95bcfbf..da4f33a8 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram): send_me_to_gpu(first_stage_model, None) return first_stage_model_decode(z) - # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model - - # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then - # send the model to GPU. Then put modules back. the modules will be in CPU. 
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None + to_remain_in_cpu = [ + (sd_model, 'first_stage_model'), + (sd_model, 'depth_model'), + (sd_model, 'embedder'), + (sd_model, 'model'), + (sd_model, 'embedder'), + ] + + is_sdxl = hasattr(sd_model, 'conditioner') + is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') + + if is_sdxl: + to_remain_in_cpu.append((sd_model, 'conditioner')) + elif is_sd2: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) + else: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer')) + + # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model + stored = [] + for obj, field in to_remain_in_cpu: + module = getattr(obj, field, None) + stored.append(module) + setattr(obj, field, None) + + # send the model to GPU. sd_model.to(devices.device) - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored + + # put modules back. the modules will be in CPU. + for (obj, field), module in zip(to_remain_in_cpu, stored): + setattr(obj, field, module) # register hooks for those the first three models - sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + if is_sdxl: + sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) + elif is_sd2: + sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) + else: + sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap @@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer - del sd_model.cond_stage_model.transformer - if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) else: diff --git a/modules/paths.py b/modules/paths.py index f509a85f..1100a8dc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), @@ -36,6 +36,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) + elif "sgm" in options: + # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we + # import sgm and remove it from sys.path so that when a script imports scripts.something, it 
doesn't use sgm's scripts dir. + + sys.path.insert(0, d) + import sgm + sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/processing.py b/modules/processing.py index cd568a20..85d35423 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -343,10 +343,13 @@ class StableDiffusionProcessing: return cache[1] def setup_conds(self): + prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index d7f9e9a9..33810669 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import namedtuple from typing import List @@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -def get_learned_conditioning(model, prompts, steps): +class SdConditioning(list): + """ + A list with prompts for stable diffusion's conditioner model. + Can also specify width and height of created image - SDXL needs it. + """ + def __init__(self, prompts, width=None, height=None): + super().__init__() + self.extend(prompts) + self.width = width or getattr(prompts, 'width', None) + self.height = height or getattr(prompts, 'height', None) + + +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond), and the sampling step at which this condition is to be replaced by the next one.
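For illustration, a minimal self-contained sketch of the SdConditioning pattern introduced above — a list subclass that still behaves like a plain list of strings for SD1/SD2 code paths while carrying the target resolution for SDXL. The class name PromptList and the example values are hypothetical, not part of the patch:

    class PromptList(list):
        """Acts like list[str] but carries conditioning metadata, mirroring SdConditioning."""
        def __init__(self, prompts, width=None, height=None):
            super().__init__()
            self.extend(prompts)
            # fall back to metadata on the source object, so re-wrapping keeps width/height
            self.width = width or getattr(prompts, 'width', None)
            self.height = height or getattr(prompts, 'height', None)

    prompts = PromptList(["a cat", "a dog"], width=1024, height=1024)
    assert list(prompts) == ["a cat", "a dog"]               # old call sites see a plain list
    assert (prompts.width, prompts.height) == (1024, 1024)   # SDXL code reads the resolution
    assert PromptList(prompts).height == 1024                # metadata survives re-wrapping

This is why setup_conds in the processing.py hunk above can wrap self.prompts once and pass the result through the existing get_conds_with_caching call sites unchanged.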
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b") re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") -def get_multicond_prompt_list(prompts): + +def get_multicond_prompt_list(prompts: SdConditioning | list[str]): res_indexes = [] - prompt_flat_list = [] prompt_indexes = {} + prompt_flat_list = SdConditioning(prompts) + prompt_flat_list.clear() for prompt in prompts: subprompts = re_AND.split(prompt) @@ -201,6 +217,7 @@ class MulticondLearnedConditioning: self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c4b9211f..266811f9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules +import sgm.modules.attention +import sgm.modules.diffusionmodules.model +import sgm.modules.diffusionmodules.openaimodel +import sgm.modules.encoders.modules + attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward @@ -56,6 +61,9 @@ def apply_optimizations(option=None): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + sgm.modules.diffusionmodules.model.nonlinearity = silu + sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + if current_optimizer is not None: current_optimizer.undo() current_optimizer = None @@ -89,6 +97,10 @@ def undo_optimizations(): ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + def fix_checkpoint(): """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want @@ -170,10 +182,19 @@ class StableDiffusionModelHijack: if conditioner: for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] - if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + typename = type(embedder).__name__ + if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenCLIPEmbedder': + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + 
conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenOpenCLIPEmbedder2': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..6c17a81d 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): """ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will - be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor as above, it returns a tuple of two: the second with shape (B, 1280) holding the pooled values. Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ @@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) self.hijack.comments.append(f"Used embeddings: {embeddings_list}") - return torch.hstack(zs) + if getattr(self.wrapped, 'return_pooled', False): + return torch.hstack(zs), zs[0].pooled + else: + return torch.hstack(zs) def process_tokens(self, remade_batch_tokens, batch_multipliers): """ @@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z = z * (original_mean / new_mean) + z *= (original_mean / new_mean) return z diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index 6ac5bda6..fcf5ad07 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,10 +16,6 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder["<end_of_text>"] self.id_pad = 0 - self.is_trainable = getattr(wrapped, 'is_trainable', False) - self.input_key = getattr(wrapped, 'input_key', 'txt') - self.legacy_ucg_val = None - def tokenize(self, texts): assert not
opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' @@ -39,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) return embedded + + +class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0] + self.id_start = tokenizer.encoder["<start_of_text>"] + self.id_end = tokenizer.encoder["<end_of_text>"] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + d = self.wrapped.encode_with_transformer(tokens) + z = d[self.wrapped.layer] + + pooled = d.get("pooled") + if pooled is not None: + z.pooled = pooled + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 53e27ade..e99c9ba5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork import ldm.modules.attention import ldm.modules.diffusionmodules.model +import sgm.modules.attention +import sgm.modules.diffusionmodules.model + diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward class SdOptimization: @@ -39,6 +43,9 @@ class SdOptimization: ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward + class SdOptimizationXformers(SdOptimization): name = "xformers" @@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + sgm.modules.attention.CrossAttention.forward = xformers_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward class SdOptimizationSdpNoMem(SdOptimization): @@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward class SdOptimizationSdp(SdOptimizationSdpNoMem): @@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward class SdOptimizationSubQuad(SdOptimization): @@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward class SdOptimizationV1(SdOptimization): @@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization): cmd_opt = "opt_split_attention_v1" priority = 10 - def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 class SdOptimizationInvokeAI(SdOptimization): @@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI class SdOptimizationDoggettx(SdOptimization): @@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def list_optimizers(res): @@ -155,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None): +def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None): # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- mem_total_gb = psutil.virtual_memory().total // (1 << 30) + def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) s = s.softmax(dim=-1, dtype=s.dtype) return einsum('b i j, b j d -> b i d', s, v) + def einsum_op_slice_0(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): @@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size): r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) return r + def einsum_op_slice_1(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): @@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size): r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) return r + def 
einsum_op_mps_v1(q, k, v): if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 return einsum_op_compvis(q, k, v) @@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v): slice_size -= 1 return einsum_op_slice_1(q, k, v, slice_size) + def einsum_op_mps_v2(q, k, v): if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: return einsum_op_compvis(q, k, v) else: return einsum_op_slice_0(q, k, v, 1) + def einsum_op_tensor_mem(q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: @@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb): return einsum_op_slice_0(q, k, v, q.shape[0] // div) return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + def einsum_op_cuda(q, k, v): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] @@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v): # Divide factor of safety as there's copying and fragmentation return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + def einsum_op(q, k, v): if q.device.type == 'cuda': return einsum_op_cuda(q, k, v) @@ -328,7 +354,8 @@ def einsum_op(q, k, v): # Tested on i7 with 8MB L3 cache. return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q = self.to_q(x) @@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None): +def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
h = self.heads @@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x + def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape @@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None): +def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) + # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states) return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None): + +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x): return h3 + def xformers_attnblock_forward(self, x): try: h_ = x @@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x): except NotImplementedError: return cross_attention_attnblock_forward(self, x) + def sdp_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x): out = self.proj_out(out) return x + out + def sdp_no_mem_attnblock_forward(self, x): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return sdp_attnblock_forward(self, x) + def sub_quad_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8d639583..e4aae597 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -411,6 +411,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' +sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' class SdModelData: @@ -445,6 +446,15 @@ class SdModelData: model_data = SdModelData() +def get_empty_cond(sd_model): + if hasattr(sd_model, 'conditioner'): + d = sd_model.get_learned_conditioning([""]) + return d['crossattn'] + else: + return sd_model.cond_stage_model([""]) + + + def 
load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict timer.record("find config") @@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks") with devices.autocast(), torch.no_grad(): - sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""]) + sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 96501569..2e92479a 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -70,7 +71,9 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) - if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: + return config_sdxl + elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: return config_unclip diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d43b8868..e8e270c3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,18 +1,30 @@ from __future__ import annotations +import sys + import torch import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer -from modules import devices +from modules import devices, shared, prompt_parser -def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 - c = self.conditioner({'txt': batch}) + width = getattr(self, 'target_width', 1024) + height = getattr(self, 'target_height', 1024) + + sdxl_conds = { + "txt": batch, + "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "crop_coords_top_left": 
torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + } + + c = self.conditioner(sdxl_conds) return c @@ -26,7 +38,7 @@ def extend_sdxl(model): model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] model.cond_stage_key = model.cond_stage_model.input_key model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" @@ -34,7 +46,14 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.is_xl = True + sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.modules.attention.print = lambda *args: None +sgm.modules.diffusionmodules.model.print = lambda *args: None +sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None +sgm.modules.encoders.modules.print = lambda *args: None + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 73289ce4..5552a8dc 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size diff --git a/modules/shared.py b/modules/shared.py index b7518de6..71afd94f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,6 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), + "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { -- cgit v1.2.1 From 5cf623c58ef3c158e8b25f7c3d516ffc16769fa4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 00:08:19 +0300 Subject: linter --- modules/paths.py | 2 +- modules/sd_models_xl.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/paths.py 
b/modules/paths.py index 1100a8dc..c6f8904e 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -41,7 +41,7 @@ for d, must_exist, what, options in path_dirs: # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir. sys.path.insert(0, d) - import sgm + import sgm # noqa: F401 sys.path.pop(0) else: sys.path.append(d) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index e8e270c3..9224c1a3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys - import torch import sgm.models.diffusion -- cgit v1.2.1 From a04c95512148fc6df64535a995fbc8f499cae206 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 00:12:25 +0300 Subject: fix importlib.machinery issue on github's autotests #yolo --- modules/launch_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index aa9d1880..4f48f3a1 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -227,13 +227,14 @@ def run_extensions_installers(settings_file): def mute_sdxl_imports(): """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - import importlib + class Dummy: + pass - module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module = Dummy() module.LPIPS = None sys.modules['taming.modules.losses.lpips'] = module - module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module = Dummy() module.StableDataModuleFromConfig = None sys.modules['sgm.data'] = module -- cgit v1.2.1 From b717eb7e56a4e620e77a2225e80223c89cb4f0d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 08:29:37 +0300 Subject: mute unneeded SDXL imports for tests too --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 4f48f3a1..56b972d5 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -334,8 +334,6 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) - mute_sdxl_imports() - if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) @@ -357,6 +355,8 @@ def configure_for_tests(): def start(): + mute_sdxl_imports() + print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: -- cgit v1.2.1 From ac4ccfa1369e74492b467294eab96c3f558b297b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:30:33 +0300 Subject: get attention optimizations to work --- modules/hypernetworks/hypernetwork.py | 2 +- modules/launch_utils.py | 1 + modules/sd_hijack_optimizations.py | 14 +++++++------- modules/sd_models_xl.py | 3 +++ 4 files changed, 12 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 79670b87..c4821d21 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None): return context_k, context_v -def attention_CrossAttention_forward(self, x, context=None, mask=None): +def
attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 56b972d5..183730d2 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -239,6 +239,7 @@ def mute_sdxl_imports(): sys.modules['sgm.data'] = module + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e99c9ba5..b5f85ba5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -173,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -355,7 +355,7 @@ def einsum_op(q, k, v): return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) @@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
h = self.heads @@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 9224c1a3..4d1aa497 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -55,3 +55,6 @@ sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.encoders.modules.print = lambda *args: None +# this gets the code to load the vanilla attention that we override +sgm.modules.attention.SDP_IS_AVAILABLE = True +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file -- cgit v1.2.1 From 21aec6f567f52271efbbe33a2ab6561f9a47b787 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:38:54 +0300 Subject: lint --- modules/sd_models_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 4d1aa497..1dd4459f 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -57,4 +57,4 @@ sgm.modules.encoders.modules.print = lambda *args: None # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True -sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False -- cgit v1.2.1 From 594c8e7b263d9b37f4b18b56b159aeb6d1bba1b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 11:35:52 +0300 Subject: fix CLIP doing the unneeded normalization revert SD2.1 back to use the original repo add SDXL's force_zero_embeddings to negative prompt --- modules/processing.py | 2 +- modules/prompt_parser.py | 14 ++++++++++---- modules/sd_hijack.py | 2 +- modules/sd_hijack_clip.py | 15 +++++++++++++++ modules/sd_models_config.py | 1 - modules/sd_models_xl.py | 3 ++- 6 files changed, 29 insertions(+), 8 
deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index 85d35423..f01a6907 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -344,7 +344,7 @@ class StableDiffusionProcessing: def setup_conds(self): prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) - negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 33810669..b29d079d 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -116,11 +116,17 @@ class SdConditioning(list): A list with prompts for stable diffusion's conditioner model. Can also specify width and height of created image - SDXL needs it. """ - def __init__(self, prompts, width=None, height=None): + def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None): super().__init__() self.extend(prompts) - self.width = width or getattr(prompts, 'width', None) - self.height = height or getattr(prompts, 'height', None) + + if copy_from is None: + copy_from = prompts + + self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False) + self.width = width or getattr(copy_from, 'width', None) + self.height = height or getattr(copy_from, 'height', None) + def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): @@ -153,7 +159,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): res.append(cached) continue - texts = [x[1] for x in prompt_schedule] + texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) conds = model.get_learned_conditioning(texts) cond_schedule = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 266811f9..647cdfbe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -190,7 +190,7 @@ class StableDiffusionModelHijack: if typename == 'FrozenCLIPEmbedder': model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 6c17a81d..b3771909 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -323,3 +323,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) return embedded + + +class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, 
output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 2e92479a..04c09ab0 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -12,7 +12,6 @@ sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "conf config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1dd4459f..b799ff46 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -22,7 +22,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), } - c = self.conditioner(sdxl_conds) + force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c -- cgit v1.2.1 From 6f23da603d3cbba82262a3c62cc44c8d5cb9e6db Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 16:18:39 +0300 Subject: fix broken img2img --- modules/sd_models_xl.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index b799ff46..b19036f1 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -32,6 +32,9 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): return self.model(x, t, cond) +def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility + return x + def extend_sdxl(model): dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype @@ -50,6 +53,7 @@ def extend_sdxl(model): sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None -- cgit v1.2.1 From b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:24:54 +0300 Subject: add XL support for live previews: approx and TAESD --- modules/sd_models_xl.py | 2 +- modules/sd_vae_approx.py | 37 ++++++++++++++++++++++++++----------- modules/sd_vae_taesd.py | 26 +++++++++++++------------- 3 files changed, 40 insertions(+), 25 deletions(-) (limited to 'modules') diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index b19036f1..af445a61 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,7 @@ def extend_sdxl(model): discretization = 
sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_xl = True + model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index e2f00468..b348f3ae 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -2,9 +2,9 @@ import os import torch from torch import nn -from modules import devices, paths +from modules import devices, paths, shared -sd_vae_approx_model = None +sd_vae_approx_models = {} class VAEApprox(nn.Module): @@ -31,19 +31,34 @@ class VAEApprox(nn.Module): return x +def download_model(model_path, model_url): + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading VAEApprox model to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + def model(): - global sd_vae_approx_model + model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" + loaded_model = sd_vae_approx_models.get(model_name) - if sd_vae_approx_model is None: - model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") - sd_vae_approx_model = VAEApprox() + if loaded_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) if not os.path.exists(model_path): - model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") - sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - sd_vae_approx_model.eval() - sd_vae_approx_model.to(devices.device, devices.dtype) + model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name) + + if not os.path.exists(model_path): + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) + download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) + + loaded_model = VAEApprox() + loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_approx_models[model_name] = loaded_model - return sd_vae_approx_model + return loaded_model def cheap_approximation(sample): diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5e8496e8..5bf7c76e 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -8,9 +8,9 @@ import os import torch import torch.nn as nn -from modules import devices, paths_internal +from modules import devices, paths_internal, shared -sd_vae_taesd = None +sd_vae_taesd_models = {} def conv(n_in, n_out, **kwargs): @@ -61,9 +61,7 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def download_model(model_path): - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - +def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) @@ -72,17 +70,19 @@ def download_model(model_path): def model(): - global sd_vae_taesd + model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) - if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-taesd", 
"taesd_decoder.pth") - download_model(model_path) + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - sd_vae_taesd = TAESD(model_path) - sd_vae_taesd.eval() - sd_vae_taesd.to(devices.device, devices.dtype) + loaded_model = TAESD(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return sd_vae_taesd.decoder + return loaded_model.decoder -- cgit v1.2.1 From e16ebc917dfc902f041963df0d4e99e8141cf82f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:32:35 +0300 Subject: repair --no-half for SDXL --- modules/sd_models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index e4aae597..9e8cb3cf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -395,10 +395,11 @@ def repair_config(sd_config): if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + if hasattr(sd_config.model.params, 'unet_config'): + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" -- cgit v1.2.1 From ff73841c608f5f02e6352bb235d9dbf63d922990 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:42:16 +0300 Subject: mute SDXL imports in the place there SDXL is imported for the first time instead of launch.py --- modules/launch_utils.py | 18 ------------------ modules/paths.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 18 deletions(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 183730d2..01ea7c91 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,22 +224,6 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) -def mute_sdxl_imports(): - """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - - class Dummy: - pass - - module = Dummy() - module.LPIPS = None - sys.modules['taming.modules.losses.lpips'] = module - - module = Dummy() - module.StableDataModuleFromConfig = None - sys.modules['sgm.data'] = module - - - def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -356,8 +340,6 @@ def configure_for_tests(): def start(): - mute_sdxl_imports() - print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: diff --git a/modules/paths.py b/modules/paths.py index 
c6f8904e..25052339 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio import modules.safe # noqa: F401 +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + class Dummy: + pass + + module = Dummy() + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = Dummy() + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + # data_path = cmd_opts_pre.data sys.path.insert(0, script_path) @@ -18,6 +33,8 @@ for possible_sd_path in possible_sd_paths: assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" +mute_sdxl_imports() + path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), -- cgit v1.2.1 From 6c5f83b19b331d51bde28c5033d13d0d64c11e54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 21:17:50 +0300 Subject: add support for SDXL loras with te1/te2 modules --- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 9e8cb3cf..07702175 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,7 +289,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if hasattr(model, 'conditioner'): + model.is_sdxl = hasattr(model, 'conditioner') + if model.is_sdxl: sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index af445a61..a7240dc0 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,6 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -- cgit v1.2.1 From a3db187e4f1ef59a571b1023309923f5e5e6dda3 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 14 Jul 2023 05:48:14 +0900 Subject: handles model hash cache.json error --- modules/hashes.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hashes.py b/modules/hashes.py index 8b7ea0ac..ec1187fe 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -5,7 +5,7 @@ import os.path import filelock from modules import shared -from modules.paths import data_path +from modules.paths import data_path, script_path cache_filename = os.path.join(data_path, "cache.json") @@ -26,8 +26,13 @@ def cache(subsection): if not os.path.isfile(cache_filename): cache_data = {} else: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} s = cache_data.get(subsection, {}) 
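# ---------------------------------------------------------------------------
# Illustration (not part of the patch above): the hunk wraps the cache read so
# that a corrupt cache.json is moved aside and rebuilt instead of crashing
# startup. The same recover-by-renaming idea as a self-contained sketch; the
# function name and paths are hypothetical:

import json
import os

def load_json_cache_or_reset(path, quarantine_path):
    try:
        with open(path, "r", encoding="utf8") as file:
            return json.load(file)
    except (OSError, json.JSONDecodeError):
        if os.path.exists(path):
            os.replace(path, quarantine_path)  # keep the bad file for inspection
        return {}  # start over with an empty cache
# ---------------------------------------------------------------------------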
cache_data[subsection] = s -- cgit v1.2.1 From 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:16:01 +0300 Subject: initial SDXL refiner support --- modules/sd_hijack.py | 18 ++++++++++---- modules/sd_models.py | 3 ++- modules/sd_models_config.py | 3 +++ modules/sd_models_xl.py | 57 ++++++++++++++++++++++++++++++++++++--------- modules/shared.py | 9 +++++-- 5 files changed, 71 insertions(+), 19 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbe..2b274c18 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: + text_cond_models = [] + for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] typename = type(embedder).__name__ if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) - m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenCLIPEmbedder': - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) + + if len(text_cond_models) == 1: + m.cond_stage_model = text_cond_models[0] + else: + m.cond_stage_model = conditioner if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_models.py b/modules/sd_models.py index 07702175..267f4d8e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -414,6 +414,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' +sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: @@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict + clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) timer.record("find config") 
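# ---------------------------------------------------------------------------
# Illustration (not part of the patch above): clip_is_included_into_sd probes
# the checkpoint's state dict for weight keys that only one model family
# contains. The same idea in isolation, reusing the key names defined above;
# guess_model_family itself is a hypothetical helper:

def guess_model_family(state_dict):
    if 'conditioner.embedders.1.model.ln_final.weight' in state_dict:
        return 'sdxl-base'     # SDXL base: OpenCLIP encoder at embedders.1
    if 'conditioner.embedders.0.model.ln_final.weight' in state_dict:
        return 'sdxl-refiner'  # refiner: single OpenCLIP encoder at embedders.0
    if 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' in state_dict:
        return 'sd2'           # SD2: OpenCLIP-style cond_stage_model
    return 'sd1'               # otherwise assume the SD1 CLIP layout
# ---------------------------------------------------------------------------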
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 04c09ab0..8266fa39 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename): if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: + return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index a7240dc0..01320c7a 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: width = getattr(self, 'target_width', 1024) height = getattr(self, 'target_height', 1024) + is_negative_prompt = getattr(batch, 'is_negative_prompt', False) + aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score + + devices_args = dict(device=devices.device, dtype=devices.dtype) sdxl_conds = { "txt": batch, - "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), + "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), } - force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c @@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility return x + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = 
get_first_stage_encoding + + +def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): + res = [] + + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: + encoded = embedder.encode_embedding_init_text(init_text, nvpt) + res.append(encoded) + + return torch.cat(res, dim=1) + + +def process_texts(self, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: + return embedder.process_texts(texts) + + +def get_target_prompt_token_count(self, token_count): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: + return embedder.get_target_prompt_token_count(token_count) + + +# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.process_texts = process_texts +sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count + + def extend_sdxl(model): + """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" + dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - - model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] - model.cond_stage_key = model.cond_stage_model.input_key + model.cond_stage_key = 'txt' + # model.cond_stage_model will be set in sd_hijack model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.conditioner.wrapped = torch.nn.Module() -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None diff --git a/modules/shared.py b/modules/shared.py index 71afd94f..234ede0d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), - "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), +})) + +options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { + "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), 
+ "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { -- cgit v1.2.1 From b7dbeda0d9e475aafa9db0cfe015bf724502ec20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:19:08 +0300 Subject: linter --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 267f4d8e..729f03d7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -478,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) + clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) timer.record("find config") -- cgit v1.2.1 From abb948dab09841571dd24c6be9ff9d6b212778ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:28:01 +0300 Subject: raise maximum Negative Guidance minimum sigma due to request in PR discussion --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 234ede0d..89b7132e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -439,7 +439,7 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('optimizations', "Optimizations"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), - "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), -- cgit v1.2.1 From 9a3f35b028a8026291679c35e1df5b2aea327a1d Mon Sep 17 00:00:00 
2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:56:01 +0300 Subject: repair medvram and lowvram --- modules/lowvram.py | 4 +++- modules/sd_hijack_open_clip.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/lowvram.py b/modules/lowvram.py index da4f33a8..6bbc11eb 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -100,7 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) if sd_model.embedder: sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) - parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model + + if hasattr(sd_model, 'cond_stage_model'): + parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index fcf5ad07..bb0b96c7 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit def encode_embedding_init_text(self, init_text, nvpt): ids = tokenizer.encode(init_text) ids = torch.asarray([ids], device=devices.device, dtype=torch.int) - embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) return embedded @@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi def encode_embedding_init_text(self, init_text, nvpt): ids = tokenizer.encode(init_text) ids = torch.asarray([ids], device=devices.device, dtype=torch.int) - embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) return embedded -- cgit v1.2.1 From 471a5a66b73921d569242daccc5275cb195e3f06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 17:54:09 +0300 Subject: add more relevant fields to caching conds --- modules/processing.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f01a6907..f68e010d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -330,8 +330,21 @@ class StableDiffusionProcessing: caches is a list with items described above. 
""" + + cached_params = ( + required_prompts, + steps, + opts.CLIP_stop_at_last_layers, + shared.sd_model.sd_checkpoint_info, + extra_network_data, + opts.sdxl_crop_left, + opts.sdxl_crop_top, + self.width, + self.height, + ) + for cache in caches: - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: + if cache[0] is not None and cached_params == cache[0]: return cache[1] cache = caches[0] @@ -339,7 +352,7 @@ class StableDiffusionProcessing: with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) + cache[0] = cached_params return cache[1] def setup_conds(self): -- cgit v1.2.1 From ac2d47ff4c00b041cae3d882c2832662c2c64935 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 20:27:41 +0300 Subject: add cheap VAE approximation coeffs for SDXL --- modules/sd_vae_approx.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index b348f3ae..86bd658a 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -64,12 +64,22 @@ def model(): def cheap_approximation(sample): # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - coefs = torch.tensor([ - [0.298, 0.207, 0.208], - [0.187, 0.286, 0.173], - [-0.158, 0.189, 0.264], - [-0.184, -0.271, -0.473], - ]).to(sample.device) + if shared.sd_model.is_sdxl: + coeffs = [ + [ 0.3448, 0.4168, 0.4395], + [-0.1953, -0.0290, 0.0250], + [ 0.1074, 0.0886, -0.0163], + [-0.3730, -0.2499, -0.2088], + ] + else: + coeffs = [ + [ 0.298, 0.207, 0.208], + [ 0.187, 0.286, 0.173], + [-0.158, 0.189, 0.264], + [-0.184, -0.271, -0.473], + ] + + coefs = torch.tensor(coeffs).to(sample.device) x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) -- cgit v1.2.1 From 5dee0fa1f812cf9f5fa6675c22c9a57afad39983 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 21:41:21 +0300 Subject: add a message about unsupported samplers --- modules/sd_samplers.py | 3 +++ modules/sd_samplers_compvis.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f22aad8f..bea2684c 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -28,6 +28,9 @@ def create_sampler(name, model): assert config is not None, f'bad sampler name: {name}' + if model.is_sdxl and config.options.get("no_sdxl", False): + raise Exception(f"Sampler {config.name} is not supported for SDXL") + sampler = config.constructor(model) sampler.config = config diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index bdae8b40..4a8396f9 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), - sd_samplers_common.SamplerData('UniPC', lambda model: 
VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), ] -- cgit v1.2.1 From 95ee0cb18817df3c4fae2e7ba7063b79b0c60b9c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 22:51:58 +0300 Subject: restyle time taken/VRAM display --- modules/call_queue.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/call_queue.py b/modules/call_queue.py index 3b94f8a4..61aa240f 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -85,9 +85,9 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" + elapsed_text = f"{elapsed_s:.1f} sec." if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text + elapsed_text = f"{elapsed_m} min. "+elapsed_text if run_memmon: mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} @@ -95,14 +95,22 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): reserved_peak = mem_stats['reserved_peak'] sys_peak = mem_stats['system_peak'] sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + sys_pct = sys_peak/max(sys_total, 1) * 100 - vram_html = f"
<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>" + toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)" + toltip_r = "Reserved: total amount of video memory allocated by the Torch library " + toltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity" + + text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>" + text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>" + text_sys = f"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)" + + vram_html = f"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>" else: vram_html = '' # last item is always HTML - res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>" + res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>
" return tuple(res) -- cgit v1.2.1 From 5d94088eac401545286ef8b16455cc88e9797300 Mon Sep 17 00:00:00 2001 From: Marcus Adams Date: Fri, 14 Jul 2023 21:52:00 -0400 Subject: Added [none] filename token. --- modules/images.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 4bdedb7f..fb5d2e75 100644 --- a/modules/images.py +++ b/modules/images.py @@ -380,6 +380,7 @@ class FilenameGenerator: 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, 'user': lambda self: self.p.user, 'vae_filename': lambda self: self.get_vae_filename(), + 'none': lambda self: '', # Overrides the default so you can get just the sequence number } default_time_format = '%Y%m%d%H%M%S' @@ -601,13 +602,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i else: file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]" + file_decoration = namegen.apply(file_decoration) + suffix + add_number = opts.save_images_add_number or file_decoration == '' if file_decoration != "" and add_number: file_decoration = f"-{file_decoration}" - file_decoration = namegen.apply(file_decoration) + suffix - if add_number: basecount = get_next_sequence_number(path, basename) fullfn = None -- cgit v1.2.1 From 14cf434bc36d0ef31f31d4c6cd2bd15d7857d5c8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 07:33:16 +0300 Subject: fix an issue in live previews that happens when you use SDXL with fp16 VAE --- modules/processing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index f68e010d..eb4a60eb 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -539,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): - with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae)) return x -- cgit v1.2.1 From b8bd8ce4cf687e9e02000387adf4d751b22a4a36 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 07:44:37 +0300 Subject: disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable --- modules/api/api.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 11045292..2a4cd8a2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,5 +1,6 @@ import base64 import io +import os import time import datetime import uvicorn @@ -98,14 +99,16 @@ def encode_pil_to_base64(image): def api_middleware(app: FastAPI): - rich_available = True + rich_available = False try: - import anyio # importing just so it can be placed on silent list - import starlette # importing just so it can be placed on silent list - from rich.console import Console - console = Console() + if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: + import anyio # importing just so it can be placed on silent list + import starlette # importing just so it can be placed on silent list + from rich.console import Console + console = Console() + rich_available = True except Exception: - rich_available = False + pass @app.middleware("http") async def log_and_time(req: Request, call_next): @@ 
-116,14 +119,14 @@ def api_middleware(app: FastAPI): endpoint = req.scope.get('path', 'err') if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( - t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), - code = res.status_code, - ver = req.scope.get('http_version', '0.0'), - cli = req.scope.get('client', ('0:0.0.0', 0))[0], - prot = req.scope.get('scheme', 'err'), - method = req.scope.get('method', 'err'), - endpoint = endpoint, - duration = duration, + t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code=res.status_code, + ver=req.scope.get('http_version', '0.0'), + cli=req.scope.get('client', ('0:0.0.0', 0))[0], + prot=req.scope.get('scheme', 'err'), + method=req.scope.get('method', 'err'), + endpoint=endpoint, + duration=duration, )) return res @@ -134,7 +137,7 @@ def api_middleware(app: FastAPI): "body": vars(e).get('body', ''), "errors": str(e), } - if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions + if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions message = f"API error: {request.method}: {request.url} {err}" if rich_available: print(message) -- cgit v1.2.1 From 127635409a7959f6c057a68ccb8e70734cbaf9f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 08:07:25 +0300 Subject: add padding and identification to generation log section (Failed to find Loras, Used embeddings, etc...) --- modules/img2img.py | 2 +- modules/txt2img.py | 2 +- modules/ui.py | 3 +-- modules/ui_common.py | 9 +++++---- 4 files changed, 8 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/img2img.py b/modules/img2img.py index 664e2688..a811e7a4 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -240,4 +240,4 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") diff --git a/modules/txt2img.py b/modules/txt2img.py index d0be2e73..29d94e8c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -70,4 +70,4 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") diff --git a/modules/ui.py b/modules/ui.py index 39d226ad..07ecee7b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -83,8 +83,7 @@ detect_image_size_symbol = '\U0001F4D0' # 📐 up_down_symbol = '\u2195\ufe0f' # ↕️ -def plaintext_to_html(text): - return ui_common.plaintext_to_html(text) +plaintext_to_html = ui_common.plaintext_to_html def send_gradio_gallery_to_image(x): diff --git a/modules/ui_common.py b/modules/ui_common.py index 57c2d0ad..11eb2a4b 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -29,9 +29,10 @@ def update_generation_info(generation_info, html_info, img_index): return html_info, gr.update() -def plaintext_to_html(text): - text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" - return text +def plaintext_to_html(text, classname=None): + content = "<br>\n".join(html.escape(x) for x in text.split('\n')) + + return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>
" def save_files(js_data, images, do_make_zip, index): @@ -157,7 +158,7 @@ Requested path was: {f} with gr.Group(): html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': -- cgit v1.2.1 From 2b1bae0d755c2d5201f6a6aadeadb5588208d43f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 08:41:22 +0300 Subject: add textual inversion hashes to infotext --- modules/processing.py | 7 ++++--- modules/sd_hijack.py | 5 ++++- modules/sd_hijack_clip.py | 15 ++++++++++++--- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 9 ++++++++- 5 files changed, 29 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index cd568a20..49441e77 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -732,9 +732,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - if len(model_hijack.comments) > 0: - for comment in model_hijack.comments: - comments[comment] = 1 + for comment in model_hijack.comments: + comments[comment] = 1 + + p.extra_generation_params.update(model_hijack.extra_generation_params) if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce..6b5aae4b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -147,7 +147,6 @@ def undo_weighted_forward(sd_model): class StableDiffusionModelHijack: fixes = None - comments = [] layers = None circular_enabled = False clip = None @@ -156,6 +155,9 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() def __init__(self): + self.extra_generation_params = {} + self.comments = [] + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): @@ -236,6 +238,7 @@ class StableDiffusionModelHijack: def clear_comments(self): self.comments = [] + self.extra_generation_params = {} def get_prompt_lengths(self, text): if self.clip is None: diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..c1d780a3 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -229,9 +229,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.process_tokens(tokens, multipliers) zs.append(z) - if len(used_embeddings) > 0: - embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) - self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + if opts.textual_inversion_add_hashes_to_infotext and used_embeddings: + hashes = [] + for name, embedding in used_embeddings.items(): + shorthash = embedding.shorthash + if not shorthash: + continue + + name = name.replace(":", "").replace(",", "") + hashes.append(f"{name}: {shorthash}") + + if hashes: + self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) return torch.hstack(zs) diff --git a/modules/shared.py b/modules/shared.py index 48478a68..a32fd4ed 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -472,6 +472,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_height": OptionInfo(0, "Card height for Extra 
Networks").info("in pixels"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), + "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), })) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cbe975b7..38e072a8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -49,6 +49,8 @@ class Embedding: self.sd_checkpoint_name = None self.optimizer_state_dict = None self.filename = None + self.hash = None + self.shorthash = None def save(self, filename): embedding_data = { @@ -82,6 +84,10 @@ class Embedding: self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + class DirWithTextualInversionEmbeddings: def __init__(self, path): @@ -199,6 +205,7 @@ class EmbeddingDatabase: embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] embedding.filename = path + embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) -- cgit v1.2.1 From 510e5fc8c60dd6278d0bc52effc23257c717dc1b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 09:20:43 +0300 Subject: cache git extension repo information --- modules/cache.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++++ modules/extensions.py | 26 ++++++++++--- modules/hashes.py | 38 ++----------------- modules/ui_extensions.py | 6 --- 4 files changed, 119 insertions(+), 47 deletions(-) create mode 100644 modules/cache.py (limited to 'modules') diff --git a/modules/cache.py b/modules/cache.py new file mode 100644 index 00000000..4c2db604 --- /dev/null +++ b/modules/cache.py @@ -0,0 +1,96 @@ +import json +import os.path + +import filelock + +from modules.paths import data_path, script_path + +cache_filename = os.path.join(data_path, "cache.json") +cache_data = None + + +def dump_cache(): + """ + Saves all cache data to a file. + """ + + with filelock.FileLock(f"{cache_filename}.lock"): + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + +def cache(subsection): + """ + Retrieves or initializes a cache for a specific subsection. + + Parameters: + subsection (str): The subsection identifier for the cache. + + Returns: + dict: The cache data for the specified subsection. 
+ """ + + global cache_data + + if cache_data is None: + with filelock.FileLock(f"{cache_filename}.lock"): + if not os.path.isfile(cache_filename): + cache_data = {} + else: + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} + + s = cache_data.get(subsection, {}) + cache_data[subsection] = s + + return s + + +def cached_data_for_file(subsection, title, filename, func): + """ + Retrieves or generates data for a specific file, using a caching mechanism. + + Parameters: + subsection (str): The subsection of the cache to use. + title (str): The title of the data entry in the subsection of the cache. + filename (str): The path to the file to be checked for modifications. + func (callable): A function that generates the data if it is not available in the cache. + + Returns: + dict or None: The cached or generated data, or None if data generation fails. + + The `cached_data_for_file` function implements a caching mechanism for data stored in files. + It checks if the data associated with the given `title` is present in the cache and compares the + modification time of the file with the cached modification time. If the file has been modified, + the cache is considered invalid and the data is regenerated using the provided `func`. + Otherwise, the cached data is returned. + + If the data generation fails, None is returned to indicate the failure. Otherwise, the generated + or cached data is returned as a dictionary. + """ + + existing_cache = cache(subsection) + ondisk_mtime = os.path.getmtime(filename) + + entry = existing_cache.get(title) + if entry: + cached_mtime = existing_cache[title].get("mtime", 0) + if ondisk_mtime > cached_mtime: + entry = None + + if not entry: + entry = func() + if entry is None: + return None + + entry['mtime'] = ondisk_mtime + existing_cache[title] = entry + + dump_cache() + + return entry diff --git a/modules/extensions.py b/modules/extensions.py index abc6e2b1..c561159a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os import threading -from modules import shared, errors +from modules import shared, errors, cache from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -21,6 +21,7 @@ def active(): class Extension: lock = threading.Lock() + cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] def __init__(self, name, path, enabled=True, is_builtin=False): self.name = name @@ -36,15 +37,29 @@ class Extension: self.remote = None self.have_info_from_repo = False + def to_dict(self): + return {x: getattr(self, x) for x in self.cached_fields} + + def from_dict(self, d): + for field in self.cached_fields: + setattr(self, field, d[field]) + def read_info_from_repo(self): if self.is_builtin or self.have_info_from_repo: return - with self.lock: - if self.have_info_from_repo: - return + def read_from_repo(): + with self.lock: + if self.have_info_from_repo: + return + + self.do_read_info_from_repo() + + return self.to_dict() - self.do_read_info_from_repo() + d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) + self.from_dict(d) + self.status = 'unknown' def do_read_info_from_repo(self): repo = 
None @@ -58,7 +73,6 @@ class Extension: self.remote = None else: try: - self.status = 'unknown' self.remote = next(repo.remote().urls, None) commit = repo.head.commit self.commit_date = commit.committed_date diff --git a/modules/hashes.py b/modules/hashes.py index ec1187fe..b7a33b42 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -1,43 +1,11 @@ import hashlib -import json import os.path -import filelock - from modules import shared -from modules.paths import data_path, script_path - - -cache_filename = os.path.join(data_path, "cache.json") -cache_data = None - - -def dump_cache(): - with filelock.FileLock(f"{cache_filename}.lock"): - with open(cache_filename, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) - - -def cache(subsection): - global cache_data - - if cache_data is None: - with filelock.FileLock(f"{cache_filename}.lock"): - if not os.path.isfile(cache_filename): - cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} - - s = cache_data.get(subsection, {}) - cache_data[subsection] = s +import modules.cache - return s +dump_cache = modules.cache.dump_cache +cache = modules.cache.cache def calculate_sha256(filename): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index dff522ef..3fa3dea2 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -513,14 +513,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" def preload_extensions_git_metadata(): - t0 = time.time() for extension in extensions.extensions: extension.read_info_from_repo() - print( - f"preload_extensions_git_metadata for " - f"{len(extensions.extensions)} extensions took " - f"{time.time() - t0:.2f}s" - ) def create_ui(): -- cgit v1.2.1 From 0aa8d538e147ba87df36d8196845807c8fa3f4e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 09:24:22 +0300 Subject: suppress printing TI embedding into console by default --- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index a32fd4ed..427dcc50 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -472,6 +472,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), + "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), })) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 38e072a8..6166c76f 100644 --- 
a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -256,7 +256,7 @@ class EmbeddingDatabase: self.word_embeddings.update(sorted_word_embeddings) displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) - if self.previously_displayed_embeddings != displayed_embeddings: + if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if self.skipped_embeddings: -- cgit v1.2.1 From c58cf73c806f08eb8b96bccc2af64403d903695f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 09:33:21 +0300 Subject: remove "## " from changelog.md version --- modules/launch_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0e0dbca4..ff77cbfd 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -69,10 +69,12 @@ def git_tag(): return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() except Exception: try: - from pathlib import Path - changelog_md = Path(__file__).parent.parent / "CHANGELOG.md" - with changelog_md.open(encoding="utf-8") as file: - return next((line.strip() for line in file if line.strip()), "") + + changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md") + with open(changelog_md, "r", encoding="utf-8") as file: + line = next((line.strip() for line in file if line.strip()), "") + line = line.replace("## ", "") + return line except Exception: return "" -- cgit v1.2.1 From 2d9d53be21d339a0723276517fff067db0181af5 Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Sat, 15 Jul 2023 17:09:51 +0800 Subject: allow replacing extensions index with environment variable --- modules/ui_extensions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3fa3dea2..f3e4fba7 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,5 +1,5 @@ import json -import os.path +import os import threading import time from datetime import datetime @@ -564,7 +564,8 @@ def create_ui(): with gr.TabItem("Available", id="available"): with gr.Row(): refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") - available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False) + extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json") + available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) -- cgit v1.2.1 From e5d3ae2bf4e9d39c35e6edc96d6449fd42528e55 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 20:39:04 +0300 Subject: user metadata system for custom networks --- modules/ui_extra_networks.py | 54 +++++++-- 
modules/ui_extra_networks_checkpoints.py | 2 +- modules/ui_extra_networks_hypernets.py | 6 +- modules/ui_extra_networks_user_metadata.py | 169 +++++++++++++++++++++++++++++ 4 files changed, 220 insertions(+), 11 deletions(-) create mode 100644 modules/ui_extra_networks_user_metadata.py (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 693cafb6..eaae6217 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared +from modules import shared, ui_extra_networks_user_metadata, errors from modules.images import read_info_from_image, save_image_with_geninfo from modules.ui import up_down_symbol import gradio as gr @@ -60,13 +60,34 @@ class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() + self.id_page = self.name.replace(" ", "_") self.card_page = shared.html("extra-networks-card.html") self.allow_negative_prompt = False self.metadata = {} + self.items = {} def refresh(self): pass + def read_user_metadata(self, item): + filename = item.get("filename", None) + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + desc = metadata.get("description", None) + if desc is not None: + item["description"] = desc + + item["user_metadata"] = metadata + def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) mtime = os.path.getmtime(filename) @@ -119,11 +140,15 @@ class ExtraNetworksPage: """ for subdir in subdirs]) - for item in self.list_items(): + self.items = {x["name"]: x for x in self.list_items()} + for item in self.items.values(): metadata = item.get("metadata") if metadata: self.metadata[item["name"]] = metadata + if "user_metadata" not in item: + self.read_user_metadata(item) + items_html += self.create_html_for_item(item, tabname) if items_html == '': @@ -166,7 +191,9 @@ class ExtraNetworksPage: metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" + + edit_button = f"
" local_path = "" filename = item.get("filename", "") @@ -200,6 +227,7 @@ class ExtraNetworksPage: "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), "metadata_button": metadata_button, + "edit_button": edit_button, "search_only": " search_only" if search_only else "", "sort_keys": sort_keys, } @@ -247,6 +275,9 @@ class ExtraNetworksPage: pass return None + def create_user_metadata_editor(self, ui, tabname): + return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self) + def initialize(): extra_pages.clear() @@ -297,20 +328,23 @@ def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] + ui.user_metadata_editors = [] ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs"): for page in ui.stored_extra_pages: - page_id = page.title.lower().replace(" ", "_") - - with gr.Tab(page.title, id=page_id): - elem_id = f"{tabname}_{page_id}_cards_html" + with gr.Tab(page.title, id=page.id_page): + elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) + editor = page.create_user_metadata_editor(ui, tabname) + editor.create_ui() + ui.user_metadata_editors.append(editor) + gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder") @@ -363,6 +397,8 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): def save_preview(index, images, filename): + # this function is here for backwards compatibility and likely will be removed soon + if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -394,3 +430,7 @@ def setup_ui(ui, gallery): outputs=[*ui.pages] ) + for editor in ui.user_metadata_editors: + editor.setup_ui(gallery) + + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 8b9ab71b..bb5071e6 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -18,7 +18,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): path, ext = os.path.splitext(checkpoint.filename) yield { "name": checkpoint.name_for_extra, - "filename": path, + "filename": checkpoint.filename, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 7c19b532..ea0b7a44 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -12,12 +12,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): shared.reload_hypernetworks() def list_items(self): - for index, (name, path) in enumerate(shared.hypernetworks.items()): - path, ext = 
os.path.splitext(path) + for index, (name, full_path) in enumerate(shared.hypernetworks.items()): + path, ext = os.path.splitext(full_path) yield { "name": name, - "filename": path, + "filename": full_path, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(path), diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py new file mode 100644 index 00000000..8d20d026 --- /dev/null +++ b/modules/ui_extra_networks_user_metadata.py @@ -0,0 +1,169 @@ +import datetime +import html +import json +import os.path + +import gradio as gr + +from modules import generation_parameters_copypaste, images, sysinfo, errors + + +class UserMetadataEditor: + + def __init__(self, ui, tabname, page): + self.ui = ui + self.tabname = tabname + self.page = page + self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata" + + self.box = None + + self.edit_name_input = None + self.button_edit = None + + self.edit_name = None + self.edit_description = None + self.html_filedata = None + self.html_preview = None + + self.button_cancel = None + self.button_replace_preview = None + self.button_save = None + + def get_user_metadata(self, name): + item = self.page.items.get(name, {}) + + user_metadata = item.get('user_metadata', None) + if user_metadata is None: + user_metadata = {} + item['user_metadata'] = user_metadata + + return user_metadata + + def create_default_editor_elems(self): + with gr.Row(): + with gr.Column(scale=2): + self.edit_name = gr.HTML(elem_classes="extra-network-name") + self.edit_description = gr.Textbox(label="Description", lines=4) + self.html_filedata = gr.HTML() + + with gr.Column(scale=1, min_width=0): + self.html_preview = gr.HTML() + + def create_default_buttons(self): + + with gr.Row(): + self.button_cancel = gr.Button('Cancel') + self.button_replace_preview = gr.Button('Replace preview', variant='primary') + self.button_save = gr.Button('Save', variant='primary') + + self.button_cancel.click(fn=None, _js="closePopup") + + def get_card_html(self, name): + item = self.page.items.get(name, {}) + + preview_url = item.get("preview", None) + + if not preview_url: + filename, _ = os.path.splitext(item["filename"]) + preview_url = self.page.find_preview(filename) + item["preview"] = preview_url + + if preview_url: + preview = f''' +
+            <div class='card standalone-card-preview'>
+                <img src="{html.escape(preview_url)}" class="preview">
+            </div>
+            '''
+        else:
+            preview = "<div class='card standalone-card-preview'></div>"
+
+        return preview
+
+    def get_metadata_table(self, name):
+        item = self.page.items.get(name, {})
+        try:
+            filename = item["filename"]
+
+            stats = os.stat(filename)
+            params = [
+                ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+                ('Created: ', datetime.datetime.fromtimestamp(stats.st_ctime).strftime('%Y-%m-%d %H:%M')),
+                ('Last modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
+            ]
+
+            return params
+        except Exception as e:
+            errors.display(e, f"reading info for {name}")
+            return []
+
+    def put_values_into_components(self, name):
+        user_metadata = self.get_user_metadata(name)
+
+        params = self.get_metadata_table(name)
+        table = '<table class="file-metadata">
' + "".join(f"" for name, value in params) + '' + + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) + + def write_user_metadata(self, name, metadata): + item = self.page.items.get(name, {}) + filename = item.get("filename", None) + basename, ext = os.path.splitext(filename) + + with open(basename + '.json', "w", encoding="utf8") as file: + json.dump(metadata, file) + + def save_user_metadata(self, name, desc): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + + self.write_user_metadata(name, user_metadata) + + def create_editor(self): + self.create_default_editor_elems() + + self.create_default_buttons() + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + self.button_save.click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[]).then(fn=None, _js="closePopup") + + def create_ui(self): + with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: + self.box = box + + self.edit_name_input = gr.Textbox("Edit user metadata card id", visible=False, elem_id=f"{self.id_part}_name") + self.button_edit = gr.Button("Edit user metadata", visible=False, elem_id=f"{self.id_part}_button") + + self.create_editor() + + def save_preview(self, index, gallery, name): + if len(gallery) == 0: + print("There is no image in gallery to save as a preview.") + return [self.get_card_html(name)] + [page.create_html(self.ui.tabname) for page in self.ui.stored_extra_pages] + + item = self.page.items.get(name, {}) + + index = int(index) + index = 0 if index < 0 else index + index = len(gallery) - 1 if index >= len(gallery) else index + + img_info = gallery[index if index >= 0 else 0] + image = generation_parameters_copypaste.image_from_url_text(img_info) + geninfo, items = images.read_info_from_image(image) + + images.save_image_with_geninfo(image, geninfo, item["local_preview"]) + + return [self.get_card_html(name)] + [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + + def setup_ui(self, gallery): + self.button_replace_preview.click( + fn=self.save_preview, + _js="function(x, y, z){return [selected_gallery_index(), y, z]}", + inputs=[self.edit_name_input, gallery, self.edit_name_input], + outputs=[self.html_preview, *self.ui.pages] + ) + + -- cgit v1.2.1 From 11f339733de860b0b51adebe15dc945df7189edf Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 00:56:53 +0300 Subject: add lora user metadata editor dialog inspired by MrKuenning's mockup from #7458 --- modules/ui_extra_networks_user_metadata.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 8d20d026..0dbd7419 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -52,7 +52,7 @@ class UserMetadataEditor: def create_default_buttons(self): - with gr.Row(): + with gr.Row(elem_classes="edit-user-metadata-buttons"): self.button_cancel = gr.Button('Cancel') self.button_replace_preview = gr.Button('Replace preview', variant='primary') self.button_save = gr.Button('Save', variant='primary') @@ -88,8 +88,7 @@ class UserMetadataEditor: 
stats = os.stat(filename) params = [ ('File size: ', sysinfo.pretty_bytes(stats.st_size)), - ('Created: ', datetime.datetime.fromtimestamp(stats.st_ctime).strftime('%Y-%m-%d %H:%M')), - ('Last modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), + ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), ] return params @@ -100,7 +99,12 @@ class UserMetadataEditor: def put_values_into_components(self, name): user_metadata = self.get_user_metadata(name) - params = self.get_metadata_table(name) + try: + params = self.get_metadata_table(name) + except Exception as e: + errors.display(e, f"reading metadata info for {name}") + params = [] + table = '' + "".join(f"" for name, value in params) + '' return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) @@ -128,7 +132,9 @@ class UserMetadataEditor: .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - self.button_save.click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[]).then(fn=None, _js="closePopup") + self.button_save\ + .click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[])\ + .then(fn=None, _js="extraNetworksReloadAll") def create_ui(self): with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: @@ -142,7 +148,7 @@ class UserMetadataEditor: def save_preview(self, index, gallery, name): if len(gallery) == 0: print("There is no image in gallery to save as a preview.") - return [self.get_card_html(name)] + [page.create_html(self.ui.tabname) for page in self.ui.stored_extra_pages] + return [self.get_card_html(name)] + self.regenerate_ui_pages() item = self.page.items.get(name, {}) @@ -156,7 +162,10 @@ class UserMetadataEditor: images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - return [self.get_card_html(name)] + [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + return [self.get_card_html(name)] + self.regenerate_ui_pages() + + def regenerate_ui_pages(self): + return [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] def setup_ui(self, gallery): self.button_replace_preview.click( -- cgit v1.2.1 From 8c11b126e5bd5052154c095177390f249e8e9889 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sat, 15 Jul 2023 23:43:49 -0400 Subject: 404 when thumb file not found --- modules/ui_extra_networks.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 693cafb6..a2565315 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -8,6 +8,7 @@ from modules.ui import up_down_symbol import gradio as gr import json import html +from fastapi.exceptions import HTTPException from modules.generation_parameters_copypaste import image_from_url_text @@ -26,6 +27,9 @@ def register_page(page): def fetch_file(filename: str = ""): from starlette.responses import FileResponse + if not os.path.isfile(filename): + raise HTTPException(status_code=404, detail="File not found") + if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): raise ValueError(f"File cannot be fetched: {filename}. 
Must be in one of directories registered by extra pages.") -- cgit v1.2.1 From a1d6ada69ac686a628e79b61b8f86d01592a7209 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 08:38:23 +0300 Subject: allow refreshing single card after editing user metadata instead of all cards --- modules/ui_extra_networks.py | 20 ++++++++++++++ modules/ui_extra_networks_checkpoints.py | 31 +++++++++++---------- modules/ui_extra_networks_hypernets.py | 31 +++++++++++---------- modules/ui_extra_networks_textual_inversion.py | 30 +++++++++++--------- modules/ui_extra_networks_user_metadata.py | 38 +++++++++++++++++--------- 5 files changed, 96 insertions(+), 54 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index eaae6217..42c4d0ac 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -51,9 +51,26 @@ def get_metadata(page: str = "", item: str = ""): return JSONResponse({"metadata": metadata}) +def get_single_card(page: str = "", tabname: str = "", name: str = ""): + from starlette.responses import JSONResponse + + page = next(iter([x for x in extra_pages if x.name == page]), None) + + try: + item = page.create_item(name) + except Exception as e: + errors.display(e, "creating item for extra network") + item = page.items.get(name) + + item_html = page.create_html_for_item(item, tabname) + + return JSONResponse({"html": item_html}) + + def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"]) app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"]) + app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) class ExtraNetworksPage: @@ -168,6 +185,9 @@ class ExtraNetworksPage: return res + def create_item(self, name, index=None): + raise NotImplementedError() + def list_items(self): raise NotImplementedError() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index bb5071e6..ef8cdf35 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -12,21 +12,24 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.refresh_checkpoints() + def create_item(self, name, index=None): + checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name) + path, ext = os.path.splitext(checkpoint.filename) + return { + "name": checkpoint.name_for_extra, + "filename": checkpoint.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "local_preview": f"{path}.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, + + } + def list_items(self): - checkpoint: sd_models.CheckpointInfo - for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()): - path, ext = os.path.splitext(checkpoint.filename) - yield { - "name": checkpoint.name_for_extra, - "filename": checkpoint.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), - "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', - "local_preview": 
f"{path}.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, - - } + for index, name in enumerate(sd_models.checkpoints_list): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index ea0b7a44..8dae23c6 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -11,21 +11,24 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.reload_hypernetworks() + def create_item(self, name, index=None): + full_path = shared.hypernetworks[name] + path, ext = os.path.splitext(full_path) + + return { + "name": name, + "filename": full_path, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(path), + "prompt": json.dumps(f""), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, + } + def list_items(self): - for index, (name, full_path) in enumerate(shared.hypernetworks.items()): - path, ext = os.path.splitext(full_path) - - yield { - "name": name, - "filename": full_path, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(path), - "prompt": json.dumps(f""), - "local_preview": f"{path}.preview.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, - - } + for index, name in enumerate(shared.hypernetworks): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return [shared.cmd_opts.hypernetwork_dir] diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 58a61c55..159f2d64 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -12,20 +12,24 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def refresh(self): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + def create_item(self, name, index=None): + embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) + + path, ext = os.path.splitext(embedding.filename) + return { + "name": name, + "filename": embedding.filename, + "preview": self.find_preview(path), + "description": self.find_description(path), + "search_term": self.search_terms_from_path(embedding.filename), + "prompt": json.dumps(embedding.name), + "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, + } + def list_items(self): - for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()): - path, ext = os.path.splitext(embedding.filename) - yield { - "name": embedding.name, - "filename": embedding.filename, - "preview": self.find_preview(path), - "description": self.find_description(path), - "search_term": self.search_terms_from_path(embedding.filename), - "prompt": json.dumps(embedding.name), - "local_preview": f"{path}.preview.{shared.opts.samples_format}", - "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, - - } + for index, name in 
enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings): + yield self.create_item(name, index) def allowed_directories_for_previews(self): return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 0dbd7419..01ff4e4b 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -23,8 +23,10 @@ class UserMetadataEditor: self.edit_name = None self.edit_description = None + self.edit_notes = None self.html_filedata = None self.html_preview = None + self.html_status = None self.button_cancel = None self.button_replace_preview = None @@ -57,6 +59,8 @@ class UserMetadataEditor: self.button_replace_preview = gr.Button('Replace preview', variant='primary') self.button_save = gr.Button('Save', variant='primary') + self.html_status = gr.HTML(elem_classes="edit-user-metadata-status") + self.button_cancel.click(fn=None, _js="closePopup") def get_card_html(self, name): @@ -107,7 +111,7 @@ class UserMetadataEditor: table = '' + "".join(f"" for name, value in params) + '' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name) + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) @@ -117,24 +121,30 @@ class UserMetadataEditor: with open(basename + '.json', "w", encoding="utf8") as file: json.dump(metadata, file) - def save_user_metadata(self, name, desc): + def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) + def setup_save_handler(self, button, func, components): + button\ + .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\ + .then(fn=None, _js="function(name){closePopup(); extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", inputs=[self.edit_name_input], outputs=[]) + def create_editor(self): self.create_default_editor_elems() + self.edit_notes = gr.TextArea(label='Notes', lines=4) + self.create_default_buttons() self.button_edit\ - .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview])\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) - self.button_save\ - .click(fn=self.save_user_metadata, inputs=[self.edit_name_input, self.edit_description], outputs=[])\ - .then(fn=None, _js="extraNetworksReloadAll") + self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes]) def create_ui(self): with gr.Box(visible=False, elem_id=self.id_part, elem_classes="edit-user-metadata") as box: @@ -147,8 +157,7 @@ class UserMetadataEditor: def save_preview(self, index, gallery, name): if len(gallery) == 0: - print("There is no image in gallery to save as a preview.") - return [self.get_card_html(name)] + self.regenerate_ui_pages() + return self.get_card_html(name), "There is no image in gallery to save as a preview." 
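An editorial aside on the refactor these hunks converge on: every extra-networks page now builds a single item dict from a name in create_item(), and list_items() merely enumerates names, which is what lets the new /sd_extra_networks/get-single-card route re-render one card after a metadata edit instead of regenerating every page. A minimal sketch of that shape — DemoPage and its fields are illustrative stand-ins, not webui's actual classes:

    class DemoPage:
        def __init__(self, names):
            self.names = names  # stand-in for the checkpoint/hypernetwork/embedding registries
            self.items = {}

        def create_item(self, name, index=None):
            # the real pages also fill preview, description, prompt, local_preview, etc.
            return {"name": name, "sort_keys": {"default": index}}

        def list_items(self):
            for index, name in enumerate(self.names):
                yield self.create_item(name, index)

    def get_single_card(page, name):
        # mirrors the new endpoint: rebuild just one item and cache it back
        item = page.create_item(name)
        page.items[name] = item
        return item

    page = DemoPage(["modelA", "modelB"])
    page.items = {x["name"]: x for x in page.list_items()}
    print(get_single_card(page, "modelA"))
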
item = self.page.items.get(name, {}) @@ -162,17 +171,20 @@ class UserMetadataEditor: images.save_image_with_geninfo(image, geninfo, item["local_preview"]) - return [self.get_card_html(name)] + self.regenerate_ui_pages() - - def regenerate_ui_pages(self): - return [page.create_html(self.tabname) for page in self.ui.stored_extra_pages] + return self.get_card_html(name), '' def setup_ui(self, gallery): self.button_replace_preview.click( fn=self.save_preview, _js="function(x, y, z){return [selected_gallery_index(), y, z]}", inputs=[self.edit_name_input, gallery, self.edit_name_input], - outputs=[self.html_preview, *self.ui.pages] + outputs=[self.html_preview, self.html_status] + ).then( + fn=None, + _js="function(name){extraNetworksRefreshSingleCard(" + json.dumps(self.page.name) + "," + json.dumps(self.tabname) + ", name);}", + inputs=[self.edit_name_input], + outputs=[] ) + -- cgit v1.2.1 From 47d9dd0240872dc70fd26bc1bf309f49fe17c104 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 09:25:32 +0300 Subject: speedup extra networks listing --- modules/cache.py | 27 +++++++++++++------------- modules/ui_extra_networks.py | 20 ++++++++++++------- modules/ui_extra_networks_checkpoints.py | 4 ++-- modules/ui_extra_networks_hypernets.py | 4 ++-- modules/ui_extra_networks_textual_inversion.py | 4 ++-- 5 files changed, 33 insertions(+), 26 deletions(-) (limited to 'modules') diff --git a/modules/cache.py b/modules/cache.py index 4c2db604..07180602 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,12 +1,12 @@ import json import os.path - -import filelock +import threading from modules.paths import data_path, script_path cache_filename = os.path.join(data_path, "cache.json") cache_data = None +cache_lock = threading.Lock() def dump_cache(): @@ -14,7 +14,7 @@ def dump_cache(): Saves all cache data to a file. 
""" - with filelock.FileLock(f"{cache_filename}.lock"): + with cache_lock: with open(cache_filename, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) @@ -33,17 +33,18 @@ def cache(subsection): global cache_data if cache_data is None: - with filelock.FileLock(f"{cache_filename}.lock"): - if not os.path.isfile(cache_filename): - cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + with cache_lock: + if cache_data is None: + if not os.path.isfile(cache_filename): cache_data = {} + else: + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') + cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 42c4d0ac..f9d1fa31 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -73,6 +73,12 @@ def add_pages_to_demo(app): app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"]) +def quote_js(s): + s = s.replace('\\', '\\\\') + s = s.replace('"', '\\"') + return f'"{s}"' + + class ExtraNetworksPage: def __init__(self, title): self.title = title @@ -203,7 +209,7 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' @@ -211,9 +217,9 @@ class ExtraNetworksPage: metadata_button = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" + metadata_button = f"" - edit_button = f"
" + edit_button = f"
" local_path = "" filename = item.get("filename", "") @@ -239,12 +245,12 @@ class ExtraNetworksPage: "background_image": background_image, "style": f"'display: none; {height}{width}'", "prompt": item.get("prompt", None), - "tabname": json.dumps(tabname), - "local_preview": json.dumps(item["local_preview"]), + "tabname": quote_js(tabname), + "local_preview": quote_js(item["local_preview"]), "name": item["name"], "description": (item.get("description") or ""), "card_clicked": onclick, - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), "metadata_button": metadata_button, "edit_button": edit_button, @@ -359,7 +365,7 @@ def create_ui(container, button, tabname): page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index ef8cdf35..e73b5b1f 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -1,8 +1,8 @@ import html -import json import os from modules import shared, ui_extra_networks, sd_models +from modules.ui_extra_networks import quote_js class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -21,7 +21,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), - "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', + "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 8dae23c6..e53ccb42 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,7 +1,7 @@ -import json import os from modules import shared, ui_extra_networks +from modules.ui_extra_networks import quote_js class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): @@ -21,7 +21,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(path), - "prompt": json.dumps(f""), + "prompt": quote_js(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, } diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 159f2d64..d1794e50 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,7 +1,7 @@ -import json import os from modules 
import ui_extra_networks, sd_hijack, shared +from modules.ui_extra_networks import quote_js class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): @@ -22,7 +22,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(embedding.filename), - "prompt": json.dumps(embedding.name), + "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, } -- cgit v1.2.1 From ccd97886da1f659472cdca3de8731f59a70bbc28 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 09:49:22 +0300 Subject: fix bogus metadata for extra networks appearing out of cache fix description editing for checkpoint not immediately appearing on cards --- modules/cache.py | 10 +++++----- modules/ui_extra_networks.py | 3 ++- modules/ui_extra_networks_checkpoints.py | 3 +-- 3 files changed, 8 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/cache.py b/modules/cache.py index 07180602..28d42a8c 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -80,18 +80,18 @@ def cached_data_for_file(subsection, title, filename, func): entry = existing_cache.get(title) if entry: - cached_mtime = existing_cache[title].get("mtime", 0) + cached_mtime = entry.get("mtime", 0) if ondisk_mtime > cached_mtime: entry = None if not entry: - entry = func() - if entry is None: + value = func() + if value is None: return None - entry['mtime'] = ondisk_mtime + entry = {'mtime': ondisk_mtime, 'value': value} existing_cache[title] = entry dump_cache() - return entry + return entry['value'] diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 760fba43..a4927c11 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -52,7 +52,7 @@ def get_metadata(page: str = "", item: str = ""): if metadata is None: return JSONResponse({}) - return JSONResponse({"metadata": metadata}) + return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)}) def get_single_card(page: str = "", tabname: str = "", name: str = ""): @@ -66,6 +66,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): errors.display(e, "creating item for extra network") item = page.items.get(name) + page.read_user_metadata(item) item_html = page.create_html_for_item(item, tabname) return JSONResponse({"html": item_html}) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index e73b5b1f..76780cfd 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -13,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): shared.refresh_checkpoints() def create_item(self, name, index=None): - checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name) + checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, @@ -24,7 +24,6 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, - } def list_items(self): 
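Aside: after the cache commits above, modules/cache.py settles into a simple contract — each entry is {"mtime": ..., "value": ...}, and an entry is trusted only if it has the new format and the file on disk is no newer than the cached mtime. A condensed, self-contained sketch of that contract (the JSON persistence behind dump_cache() is deliberately omitted here):

    import os
    import threading

    cache_data = {}
    cache_lock = threading.Lock()

    def cached_data_for_file(subsection, title, filename, func):
        with cache_lock:
            subcache = cache_data.setdefault(subsection, {})

        ondisk_mtime = os.path.getmtime(filename)

        entry = subcache.get(title)
        # old-format entries (no 'value' key) and stale files both force a recompute
        if entry and "value" in entry and ondisk_mtime <= entry.get("mtime", 0):
            return entry["value"]

        value = func()  # e.g. hash the file or read git metadata
        if value is None:
            return None

        subcache[title] = {"mtime": ondisk_mtime, "value": value}
        return value
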
-- cgit v1.2.1 From 690d56f3c10e5359e15eeba9c68e56b2eb193ac3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 10:25:34 +0300 Subject: nuke thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs) --- modules/shared.py | 5 +++-- modules/ui_extra_networks.py | 9 ++++----- 2 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 427dcc50..f6604ef9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -466,10 +466,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), - "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), - "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), + "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), + "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a4927c11..d9deccb2 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -132,7 +132,6 @@ class ExtraNetworksPage: return "" def create_html(self, tabname): - view = shared.opts.extra_networks_default_view items_html = '' self.metadata = {} @@ -186,10 +185,10 @@ class ExtraNetworksPage: self_name_id = self.name.replace(" ", "_") res = f""" -
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
{subdirs_html}
</div>
-<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
{items_html}
</div>
""" @@ -248,12 +247,12 @@ class ExtraNetworksPage: args = { "background_image": background_image, - "style": f"'display: none; {height}{width}'", + "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", "prompt": item.get("prompt", None), "tabname": quote_js(tabname), "local_preview": quote_js(item["local_preview"]), "name": item["name"], - "description": (item.get("description") or ""), + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), -- cgit v1.2.1 From 9d3dd64fe9e95873347710ca1df1f1e88d1908e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 10:44:04 +0300 Subject: minor restyling for extra networks --- modules/ui_extra_networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index d9deccb2..6c73998f 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -11,6 +11,7 @@ import html from fastapi.exceptions import HTTPException from modules.generation_parameters_copypaste import image_from_url_text +from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() @@ -377,7 +378,7 @@ def create_ui(container, button, tabname): gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) - gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder") + ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) -- cgit v1.2.1 From 570f42afd122405116b39b880cdb5163fd5ca3e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 12:28:50 +0300 Subject: possible fix for FP16 VAE failing in img2img SDXL --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/processing.py b/modules/processing.py index e7b10808..6567b3cf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1303,7 +1303,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. 
- image = image.to(shared.device) + image = image.to(shared.device, dtype=devices.dtype_vae) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) -- cgit v1.2.1 From 67ea4eabc3e78c4b496a9fcd21aca95fd5ef7027 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 13:46:33 +0300 Subject: fix cache loading wrong entries from old cache files --- modules/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/cache.py b/modules/cache.py index 28d42a8c..ddf44637 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -84,7 +84,7 @@ def cached_data_for_file(subsection, title, filename, func): if ondisk_mtime > cached_mtime: entry = None - if not entry: + if not entry or 'value' not in entry: value = func() if value is None: return None -- cgit v1.2.1 From 9251ae3bc78e465058c286e86f3c26cb6f819a31 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:29:36 +0300 Subject: delay writing cache to prevent writing the same thing over and over --- modules/cache.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/cache.py b/modules/cache.py index ddf44637..71fe6302 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,6 +1,7 @@ import json import os.path import threading +import time from modules.paths import data_path, script_path @@ -8,15 +9,37 @@ cache_filename = os.path.join(data_path, "cache.json") cache_data = None cache_lock = threading.Lock() +dump_cache_after = None +dump_cache_thread = None + def dump_cache(): """ - Saves all cache data to a file. + Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written. 
""" + global dump_cache_after + global dump_cache_thread + + def thread_func(): + global dump_cache_after + global dump_cache_thread + + while dump_cache_after is not None and time.time() < dump_cache_after: + time.sleep(1) + + with cache_lock: + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + dump_cache_after = None + dump_cache_thread = None + with cache_lock: - with open(cache_filename, "w", encoding="utf8") as file: - json.dump(cache_data, file, indent=4) + dump_cache_after = time.time() + 5 + if dump_cache_thread is None: + dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func) + dump_cache_thread.start() def cache(subsection): -- cgit v1.2.1 From 35510f7529dc05437a82496187ef06b852be9ab1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 10:06:02 +0300 Subject: add alias to lyco network read networks from LyCORIS dir if it exists add credits --- modules/extra_networks.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 41799b0a..6ae07e91 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -4,16 +4,22 @@ from collections import defaultdict from modules import errors extra_network_registry = {} +extra_network_aliases = {} def initialize(): extra_network_registry.clear() + extra_network_aliases.clear() def register_extra_network(extra_network): extra_network_registry[extra_network.name] = extra_network +def register_extra_network_alias(extra_network, alias): + extra_network_aliases[alias] = extra_network + + def register_default_extra_networks(): from modules.extra_networks_hypernet import ExtraNetworkHypernet register_extra_network(ExtraNetworkHypernet()) @@ -82,20 +88,26 @@ def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" + activated = [] + for extra_network_name, extra_network_args in extra_network_data.items(): extra_network = extra_network_registry.get(extra_network_name, None) + + if extra_network is None: + extra_network = extra_network_aliases.get(extra_network_name, None) + if extra_network is None: print(f"Skipping unknown extra network: {extra_network_name}") continue try: extra_network.activate(p, extra_network_args) + activated.append(extra_network) except Exception as e: errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: + if extra_network in activated: continue try: -- cgit v1.2.1 From 699108bfbb05c2a7d2ee4a2c7abcfaa0a244d8ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 18:56:14 +0300 Subject: hide cards for networks of incompatible stable diffusion version in Lora extra networks interface --- modules/sd_models.py | 3 +++ modules/ui_extra_networks.py | 3 ++- modules/ui_extra_networks_user_metadata.py | 7 ++++++- 3 files changed, 11 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_models.py b/modules/sd_models.py index 729f03d7..4d9382dd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer 
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) model.is_sdxl = hasattr(model, 'conditioner') + model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') + model.is_sd1 = not model.is_sdxl and not model.is_sd2 + if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6c73998f..49612298 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -62,7 +62,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): page = next(iter([x for x in extra_pages if x.name == page]), None) try: - item = page.create_item(name) + item = page.create_item(name, enable_filter=False) + page.items[name] = item except Exception as e: errors.display(e, "creating item for extra network") item = page.items.get(name) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 01ff4e4b..63d4b503 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -42,6 +42,9 @@ class UserMetadataEditor: return user_metadata + def create_extra_default_items_in_left_column(self): + pass + def create_default_editor_elems(self): with gr.Row(): with gr.Column(scale=2): @@ -49,6 +52,8 @@ class UserMetadataEditor: self.edit_description = gr.Textbox(label="Description", lines=4) self.html_filedata = gr.HTML() + self.create_extra_default_items_in_left_column() + with gr.Column(scale=1, min_width=0): self.html_preview = gr.HTML() @@ -111,7 +116,7 @@ class UserMetadataEditor: table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) -- cgit v1.2.1 From a99d5708e6d603e8f7cfd1b8c6595f8026219ba0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 20:10:24 +0300 Subject: skip installing packages with pip if they are already installed record time it took to launch --- modules/launch_utils.py | 46 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 434facbc..03552bc2 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,4 +1,5 @@ # this script installs necessary requirements and launches main program in webui.py +import re import subprocess import os import sys @@ -9,6 +10,9 @@ from functools import lru_cache from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir +from modules import timer + +timer.startup_timer.record("start") args, _ = cmd_args.parser.parse_known_args() @@ -226,6 +230,44 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) +re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") + + +def requirements_met(requirements_file): + """ + Does a simple parse of a requirements.txt file to determine if all requirements in it + are already installed. Returns True if so, False if not installed or parsing fails.
+ """ + + import importlib.metadata + import packaging.version + + with open(requirements_file, "r", encoding="utf8") as file: + for line in file: + if line.strip() == "": + continue + + m = re.match(re_requirement, line) + if m is None: + return False + + package = m.group(1).strip() + version_required = (m.group(2) or "").strip() + + if version_required == "": + continue + + try: + version_installed = importlib.metadata.version(package) + except Exception: + return False + + if packaging.version.parse(version_required) != packaging.version.parse(version_installed): + return False + + return True + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -311,7 +353,9 @@ def prepare_environment(): if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) - run_pip(f"install -r \"{requirements_file}\"", "requirements") + + if not requrements_met(requirements_file): + run_pip(f"install -r \"{requirements_file}\"", "requirements") run_extensions_installers(settings_file=args.ui_settings_file) -- cgit v1.2.1 From f0e2098f1a533c88396536282c1d6cd7d847a51c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 17 Jul 2023 23:39:38 -0400 Subject: Add support for `--upcast-sampling` with SD XL --- modules/sd_hijack_unet.py | 8 +++++++- modules/sd_models.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index ca1daf45..2101f1a0 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict): for y in cond.keys(): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + if isinstance(cond[y], list): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + else: + cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() @@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) + +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4d9382dd..5813b550 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -326,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer timer.record("apply half()") - devices.dtype_unet = model.model.diffusion_model.dtype + devices.dtype_unet = torch.float16 if 
+ devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.1
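
The delayed cache write in commit 9251ae3 above is a debounce: every call to dump_cache() pushes the deadline 5 seconds into the future, and a single background thread performs one write once the deadline stops moving. Below is a self-contained sketch reassembled from the hunks in modules/cache.py; the local cache.json path and the demo driver at the bottom are illustrative additions, not part of the patch.

import json
import threading
import time

cache_filename = "cache.json"  # illustrative path; the real code builds it from data_path
cache_data = {}
cache_lock = threading.Lock()

dump_cache_after = None
dump_cache_thread = None


def dump_cache():
    """Marks the cache for writing; the write happens 5 seconds after the last mark."""
    global dump_cache_after
    global dump_cache_thread

    def thread_func():
        global dump_cache_after
        global dump_cache_thread

        # keep sleeping until 5 seconds pass without a new call moving the deadline
        while dump_cache_after is not None and time.time() < dump_cache_after:
            time.sleep(1)

        with cache_lock:
            with open(cache_filename, "w", encoding="utf8") as file:
                json.dump(cache_data, file, indent=4)

            dump_cache_after = None
            dump_cache_thread = None

    with cache_lock:
        dump_cache_after = time.time() + 5  # push the deadline out on every call
        if dump_cache_thread is None:
            dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
            dump_cache_thread.start()


if __name__ == "__main__":
    for i in range(100):  # 100 rapid updates...
        cache_data[str(i)] = i
        dump_cache()
    time.sleep(7)  # ...coalesce into a single write, about 5 seconds after the last one

The writer thread polls once per second instead of computing a precise sleep; that keeps the logic simple at the cost of up to a second of extra latency before the file hits disk.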
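Commit 35510f7 above resolves an extra-network name against the main registry first and the alias table second, which is what lets a LyCORIS-style name reuse an existing handler. A toy illustration of that lookup order follows; the ToyNetwork class and the network names are hypothetical stand-ins for the real ExtraNetwork implementations.

extra_network_registry = {}
extra_network_aliases = {}


class ToyNetwork:  # hypothetical stand-in for the real ExtraNetwork classes
    def __init__(self, name):
        self.name = name

    def activate(self, p, args):
        print(f"activating {self.name} with args {args}")


def register_extra_network(extra_network):
    extra_network_registry[extra_network.name] = extra_network


def register_extra_network_alias(extra_network, alias):
    extra_network_aliases[alias] = extra_network


lora = ToyNetwork("lora")
register_extra_network(lora)
register_extra_network_alias(lora, "lyco")  # "lyco" now resolves to the same handler

for name in ("lora", "lyco", "unknown"):
    extra_network = extra_network_registry.get(name, None)
    if extra_network is None:
        extra_network = extra_network_aliases.get(name, None)  # alias table is the fallback
    if extra_network is None:
        print(f"Skipping unknown extra network: {name}")
    else:
        extra_network.activate(None, [])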
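The requirements check added to launch_utils.py in a99d570 is a line-by-line parse: pinned name==version entries are compared against the installed distribution via importlib.metadata, and any mismatch, missing package, or parse failure makes the launcher fall back to running pip. A minimal sketch of the same logic, operating on an in-memory list of lines rather than a file; the sample entries at the bottom are made up.

import re
import importlib.metadata

import packaging.version  # same third-party dependency the patch itself uses

re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")


def requirements_met(lines):
    """Return True only when every parseable, pinned requirement is installed at that exact version."""
    for line in lines:
        if line.strip() == "":
            continue

        m = re.match(re_requirement, line)
        if m is None:
            return False  # unparseable line: be conservative and let pip run

        package = m.group(1).strip()
        version_required = (m.group(2) or "").strip()

        if version_required == "":
            continue  # unpinned requirement: any installed version is accepted

        try:
            version_installed = importlib.metadata.version(package)
        except Exception:
            return False  # not installed at all

        if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
            return False

    return True


# made-up sample entries: "packaging" is unpinned, the second package does not exist
print(requirements_met(["packaging", "no-such-package==1.0"]))  # prints False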