diff options
Diffstat (limited to 'modules')
39 files changed, 1695 insertions, 829 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 9814bbc2..5d60fc0a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -45,10 +44,8 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): @@ -126,8 +123,6 @@ class Api: self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) - self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) 
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) @@ -246,7 +241,7 @@ class Api: reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -262,7 +257,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) @@ -390,12 +385,6 @@ class Api: return styleList - def get_artists_categories(self): - return shared.artist_db.cats - - def get_artists(self): - return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] - def get_embeddings(self): db = sd_hijack.model_hijack.embedding_db @@ -480,7 +469,7 @@ class Api: def train_hypernetwork(self, args: dict): try: shared.state.begin() - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -491,16 +480,15 @@ class Api: except Exception as e: error = e finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) 
shared.sd_model.first_stage_model.to(devices.device) if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) except AssertionError as msg: shared.state.end() - return TrainResponse(info = "train embedding error: {error}".format(error = error)) + return TrainResponse(info="train embedding error: {error}".format(error=error)) def get_memory(self): try: diff --git a/modules/artists.py b/modules/artists.py deleted file mode 100644 index 3612758b..00000000 --- a/modules/artists.py +++ /dev/null @@ -1,25 +0,0 @@ -import os.path
-import csv
-from collections import namedtuple
-
-Artist = namedtuple("Artist", ['name', 'weight', 'category'])
-
-
-class ArtistsDatabase:
- def __init__(self, filename):
- self.cats = set()
- self.artists = []
-
- if not os.path.exists(filename):
- return
-
- with open(filename, "r", newline='', encoding="utf8") as file:
- reader = csv.DictReader(file)
-
- for row in reader:
- artist = Artist(row["artist"], float(row["score"]), row["category"])
- self.artists.append(artist)
- self.cats.add(artist.category)
-
- def categories(self):
- return sorted(self.cats)
diff --git a/modules/devices.py b/modules/devices.py index 6f034948..524ec7af 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -113,6 +113,9 @@ class NansException(Exception): def test_for_nans(x, where): from modules import shared + if shared.cmd_opts.disable_nan_check: + return + if not torch.all(torch.isnan(x)).item(): return @@ -166,8 +169,10 @@ orig_Tensor_cumsum = torch.Tensor.cumsum def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': output_dtype = kwargs.get('dtype', input.dtype) - if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + if output_dtype == torch.int64: return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) return cumsum_func(input, *args, **kwargs) @@ -178,9 +183,10 @@ if has_mps(): torch.nn.functional.layer_norm = layer_norm_fix torch.Tensor.numpy = numpy_fix elif version.parse(torch.__version__) > version.parse("1.13.1"): - if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): - torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) - torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) + cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, 
*args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) orig_narrow = torch.narrow torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) diff --git a/modules/errors.py b/modules/errors.py index a668c014..f6b80dbb 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -19,11 +19,23 @@ def display(e: Exception, task): message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
print_error_explanation("""
-The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
""")
+already_displayed = {}
+
+
+def display_once(e: Exception, task):
+ if task in already_displayed:
+ return
+
+ display(e, task)
+
+ already_displayed[task] = 1
+
+
def run(code, task):
try:
code()
diff --git a/modules/extra_networks.py b/modules/extra_networks.py new file mode 100644 index 00000000..1978673d --- /dev/null +++ b/modules/extra_networks.py @@ -0,0 +1,147 @@ +import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+ extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+ extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+ def __init__(self, items=None):
+ self.items = items or []
+
+
+class ExtraNetwork:
+ def __init__(self, name):
+ self.name = name
+
+ def activate(self, p, params_list):
+ """
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+ Passes arguments related to this extra network in params_list.
+ User passes arguments by specifying this in his prompt:
+
+ <name:arg1:arg2:arg3>
+
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are one or more text arguments
+ separated by colons.
+
+ Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
+ in this case, all effects of this extra network should be disabled.
+
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+ > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
+
+ params_list will be:
+
+ [
+ ExtraNetworkParams(items=["agm", "1.1"]),
+ ExtraNetworkParams(items=["ray"])
+ ]
+
+ """
+ raise NotImplementedError
+
+ def deactivate(self, p):
+ """
+ Called at the end of processing for housekeeping. Subclasses must override this, but the override may do nothing.
+ """
+
+ raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+ """call activate for extra networks in extra_network_data in specified order, then call
+ activate for all remaining registered networks with an empty argument list"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ print(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+
+ try:
+ extra_network.activate(p, extra_network_args)
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+ """call deactivate for extra networks in extra_network_data in specified order, then call
+ deactivate for all remaining registered networks"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating extra network {extra_network_name}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+ res = defaultdict(list)
+
+ def found(m):
+ name = m.group(1)
+ args = m.group(2)
+
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+ return ""
+
+ prompt = re.sub(re_extra_net, found, prompt)
+
+ return prompt, res
+
+
+def parse_prompts(prompts):
+ res = []
+ extra_data = None
+
+ for prompt in prompts:
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+ if extra_data is None:
+ extra_data = parsed_extra_data
+
+ res.append(updated_prompt)
+
+ return res, extra_data
+
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000..ff279a1f --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,21 @@ +from modules import extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('hypernet')
+
+ def activate(self, p, params_list):
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ hypernetwork.load_hypernetworks(names, multipliers)
+
+ def deactivate(self, p):
+ pass
diff --git a/modules/extras.py b/modules/extras.py index 22668fcd..36123aa5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,230 +1,16 @@ -from __future__ import annotations
-import math
import os
-import sys
-import traceback
+import re
import shutil
-import numpy as np
-from PIL import Image
import torch
import tqdm
-from typing import Callable, List, OrderedDict, Tuple
-from functools import partial
-from dataclasses import dataclass
-
-from modules import processing, shared, images, devices, sd_models, sd_samplers
-from modules.shared import opts
-import modules.gfpgan_model
-from modules.ui import plaintext_to_html
-import modules.codeformer_model
+from modules import shared, images, sd_models, sd_vae
+from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
-class LruCache(OrderedDict):
- @dataclass(frozen=True)
- class Key:
- image_hash: int
- info_hash: int
- args_hash: int
-
- @dataclass
- class Value:
- image: Image.Image
- info: str
-
- def __init__(self, max_size: int = 5, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._max_size = max_size
-
- def get(self, key: LruCache.Key) -> LruCache.Value:
- ret = super().get(key)
- if ret is not None:
- self.move_to_end(key) # Move to end of eviction list
- return ret
-
- def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
- self[key] = value
- while len(self) > self._max_size:
- self.popitem(last=False)
-
-
-cached_images: LruCache = LruCache(max_size=5)
-
-
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
- devices.torch_gc()
-
- shared.state.begin()
- shared.state.job = 'extras'
-
- imageArr = []
- # Also keep track of original file names
- imageNameArr = []
- outputs = []
-
- if extras_mode == 1:
- #convert file to pillow image
- for img in image_folder:
- image = Image.open(img)
- imageArr.append(image)
- imageNameArr.append(os.path.splitext(img.orig_name)[0])
- elif extras_mode == 2:
- assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
-
- if input_dir == '':
- return outputs, "Please select an input directory.", ''
- image_list = shared.listfiles(input_dir)
- for img in image_list:
- try:
- image = Image.open(img)
- except Exception:
- continue
- imageArr.append(image)
- imageNameArr.append(img)
- else:
- imageArr.append(image)
- imageNameArr.append(None)
-
- if extras_mode == 2 and output_dir != '':
- outpath = output_dir
- else:
- outpath = opts.outdir_samples or opts.outdir_extras_samples
-
- # Extra operation definitions
-
- def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- shared.state.job = 'extras-gfpgan'
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
-
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
-
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- return (res, info)
-
- def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- shared.state.job = 'extras-codeformer'
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
-
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
-
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
- return (res, info)
-
- def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
- shared.state.job = 'extras-upscale'
- upscaler = shared.sd_upscalers[scaler_index]
- res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
- if mode == 1 and crop:
- cropped = Image.new("RGB", (resize_w, resize_h))
- cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
- res = cropped
- return res
-
- def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
- nonlocal upscaling_resize
- if resize_mode == 1:
- upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
- crop_info = " (crop)" if upscaling_crop else ""
- info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
- return (image, info)
-
- @dataclass
- class UpscaleParams:
- upscaler_idx: int
- blend_alpha: float
-
- def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
- blended_result: Image.Image = None
- image_hash: str = hash(np.array(image.getdata()).tobytes())
- for upscaler in params:
- upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
- upscaling_resize_w, upscaling_resize_h, upscaling_crop)
- cache_key = LruCache.Key(image_hash=image_hash,
- info_hash=hash(info),
- args_hash=hash(upscale_args))
- cached_entry = cached_images.get(cache_key)
- if cached_entry is None:
- res = upscale(image, *upscale_args)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
- cached_images.put(cache_key, LruCache.Value(image=res, info=info))
- else:
- res, info = cached_entry.image, cached_entry.info
-
- if blended_result is None:
- blended_result = res
- else:
- blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
- return (blended_result, info)
-
- # Build a list of operations to run
- facefix_ops: List[Callable] = []
- facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
- facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
-
- upscale_ops: List[Callable] = []
- upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
-
- if upscaling_resize != 0:
- step_params: List[UpscaleParams] = []
- step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
-
- upscale_ops.append(partial(run_upscalers_blend, step_params))
-
- extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
-
- for image, image_name in zip(imageArr, imageNameArr):
- if image is None:
- return outputs, "Please select an input image.", ''
-
- shared.state.textinfo = f'Processing image {image_name}'
-
- existing_pnginfo = image.info or {}
-
- image = image.convert("RGB")
- info = ""
- # Run each operation on each image
- for op in extras_ops:
- image, info = op(image, info)
-
- if opts.use_original_name_batch and image_name is not None:
- basename = os.path.splitext(os.path.basename(image_name))[0]
- else:
- basename = ''
-
- if opts.enable_pnginfo: # append info before save
- image.info = existing_pnginfo
- image.info["extras"] = info
-
- if save_output:
- # Add upscaler name as a suffix.
- suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
- # Add second upscaler if applicable.
- if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
- suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
-
- images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
-
- if extras_mode != 2 or show_extras_results :
- outputs.append(image)
-
- devices.torch_gc()
-
- return outputs, plaintext_to_html(info), ''
-
-def clear_cache():
- cached_images.clear()
-
def run_pnginfo(image):
if image is None:
@@ -251,7 +37,8 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- return sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models.find_checkpoint_config(x) if x else None
+ return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@@ -274,10 +61,25 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def to_half(tensor, enable):
+ if enable and tensor.dtype == torch.float:
+ return tensor.half()
+
+ return tensor
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
shared.state.begin()
shared.state.job = 'model-merge'
+ def fail(message):
+ shared.state.textinfo = message
+ shared.state.end()
+ return [*[gr.update() for _ in range(4)], message]
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -287,57 +89,96 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- primary_model_info = sd_models.checkpoints_list[primary_model_name]
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
- result_is_inpainting_model = False
+ def filename_weighted_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_difference():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (None, weighted_sum),
- "Add difference": (get_difference, add_difference),
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
+ "Add difference": (filename_add_difference, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
}
- theta_func1, theta_func2 = theta_funcs[interp_method]
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
- if theta_func1 and not tertiary_model_info:
- shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
- shared.state.end()
- return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ if not primary_model_name:
+ return fail("Failed: Merging requires a primary model.")
+
+ primary_model_info = sd_models.checkpoints_list[primary_model_name]
+
+ if theta_func2 and not secondary_model_name:
+ return fail("Failed: Merging requires a secondary model.")
+
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
+
+ if theta_func1 and not tertiary_model_name:
+ return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
- shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
+
+ result_is_inpainting_model = False
+
+ if theta_func2:
+ shared.state.textinfo = f"Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
if theta_func1:
+ shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.textinfo = 'Merging B and C'
+ shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
+ if key in checkpoint_dict_skip_on_merge:
+ continue
+
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
theta_1[key] = theta_func1(theta_1[key], t2)
else:
theta_1[key] = torch.zeros_like(theta_1[key])
+
+ shared.state.sampling_step += 1
del theta_2
+ shared.state.nextjob()
+
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
-
- chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
+ shared.state.textinfo = 'Merging A and B'
+ shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
- if 'model' in key and key in theta_1:
+ if theta_1 and 'model' in key and key in theta_1:
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
a = theta_0[key]
b = theta_1[key]
- shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -352,36 +193,45 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam else:
theta_0[key] = theta_func2(a, b, multiplier)
- if save_as_half:
- theta_0[key] = theta_0[key].half()
-
- # I believe this part should be discarded, but I'll leave it for now until I am sure
- for key in theta_1.keys():
- if 'model' in key and key not in theta_0:
+ theta_0[key] = to_half(theta_0[key], save_as_half)
- if key in chckpoint_dict_skip_on_merge:
- continue
+ shared.state.sampling_step += 1
- theta_0[key] = theta_1[key]
- if save_as_half:
- theta_0[key] = theta_0[key].half()
del theta_1
- ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
+
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
- filename = \
- primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
- secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
- interp_method.replace(" ", "_") + \
- '-merged.' + \
- ("inpainting." if result_is_inpainting_model else "") + \
- checkpoint_format
+ del vae_dict
+
+ if save_as_half and not theta_func2:
+ for key in theta_0.keys():
+ theta_0[key] = to_half(theta_0[key], save_as_half)
+
+ if discard_weights:
+ regex = re.compile(discard_weights)
+ for key in list(theta_0):
+ if re.search(regex, key):
+ theta_0.pop(key, None)
+
+ ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
- shared.state.textinfo = f"Saving to {output_modelname}..."
+ shared.state.nextjob()
+ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -394,8 +244,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
- print("Checkpoint saved.")
- shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
shared.state.end()
- return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a381ff59..46e12dc6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): from modules import ui
settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Clip skip" not in res:
res["Clip skip"] = "1"
- if "Hypernet strength" not in res:
- res["Hypernet strength"] = "1"
-
- if "Hypernet" in res:
- hypernet_name = res["Hypernet"]
- hypernet_hash = res.get("Hypernet hash", None)
- res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ hypernet = res.get("Hypernet", None)
+ if hypernet is not None:
+ res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c963fc40..503534e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -25,7 +25,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
- multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module): add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
+ self.multiplier = 1.0
+
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x
def forward(self, x):
- return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module): return layer_structure
-def apply_strength(value=None):
- HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters():
param.requires_grad = mode
+    def to(self, device):
+        """Move every layer held in self.layers to `device`; returns self so calls can be chained."""
+        for layer_group in self.layers.values():
+            for module in layer_group:
+                module.to(device)
+
+        return self
+
+    def set_multiplier(self, multiplier):
+        """Assign `multiplier` to the `multiplier` attribute of every layer in self.layers; returns self."""
+        for layer_group in self.layers.values():
+            for module in layer_group:
+                module.multiplier = multiplier
+
+        return self
+
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print("Loaded existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
+ if shared.opts.print_hypernet_extra:
+ print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
else:
self.optimizer_name = "AdamW"
- print("No saved optimizer exists in checkpoint")
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path): return res
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- # Prevent any file named "None.pt" from being loaded.
- if path is not None and filename != "None":
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print("Unloading hypernetwork")
+ if path is None:
+ return None
+
+ hypernetwork = Hypernetwork()
+
+ try:
+ hypernetwork.load(path)
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
+
+ shared.loaded_hypernetworks.clear()
+
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+
+ if hypernetwork is None:
+ continue
- shared.loaded_hypernetwork = None
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0]
-def apply_hypernetwork(hypernetwork, context, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
- return context, context
+ return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
+ context_k = hypernetwork_layers[0](context_k)
+ context_v = hypernetwork_layers[1](context_v)
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
return context_k, context_v
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else:
images_dir = None
- hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
@@ -575,6 +618,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -670,6 +715,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi do_not_save_samples=True,
)
+ p.disable_extra_networks = True
+
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
@@ -724,6 +771,9 @@ Last saved image: {html.escape(last_saved_image)}<br/> pbar.close()
hypernetwork.eval()
#report_statistics(loss_dict)
+ sd_hijack_checkpoint.remove()
+
+
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
@@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args):
-
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception:
raise
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
diff --git a/modules/images.py b/modules/images.py index c3a5fc8b..3b1c5f34 100644 --- a/modules/images.py +++ b/modules/images.py @@ -605,8 +605,9 @@ def read_info_from_image(image): except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
- items['exif comment'] = exif_comment
- geninfo = exif_comment
+ if exif_comment:
+ items['exif comment'] = exif_comment
+ geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
diff --git a/modules/img2img.py b/modules/img2img.py index f4a03c57..2168c8e2 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_batch = mode == 5
if mode == 0: # img2img
@@ -101,7 +101,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
diff --git a/modules/interrogate.py b/modules/interrogate.py index 738d8ff7..19938cbb 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -5,12 +5,13 @@ from collections import namedtuple import re
import torch
+import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, lowvram, modelloader, errors
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.")
+def download_default_clip_interrogate_categories(content_dir):
+    """Download the default CLIP interrogator category lists into `content_dir`.
+
+    The files are first fetched into a sibling directory (`content_dir` + "_tmp"),
+    which is renamed to `content_dir` only after every download succeeded, so a
+    partially downloaded set never shows up under `content_dir`.
+
+    Any exception is reported through errors.display() and is not re-raised.
+    """
+    import shutil  # local import: only needed to clean up after a failed download
+
+    print("Downloading CLIP categories...")
+
+    base_url = "https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/"
+
+    # remote file name -> local file name; the ".top3." infix is picked up by re_topn
+    # when the categories are loaded and sets how many matches are taken from that file
+    category_files = {
+        "artists.txt": "artists.txt",
+        "flavors.txt": "flavors.top3.txt",
+        "mediums.txt": "mediums.txt",
+        "movements.txt": "movements.txt",
+    }
+
+    tmpdir = content_dir + "_tmp"
+    try:
+        # exist_ok: a leftover directory from an earlier interrupted run must not block the download
+        os.makedirs(tmpdir, exist_ok=True)
+
+        for remote_name, local_name in category_files.items():
+            torch.hub.download_url_to_file(base_url + remote_name, os.path.join(tmpdir, local_name))
+
+        os.rename(tmpdir, content_dir)
+
+    except Exception as e:
+        errors.display(e, "downloading default CLIP interrogate categories")
+    finally:
+        # tmpdir still exists only if something failed before the rename.
+        # It is a directory (possibly non-empty), so os.remove() cannot delete it.
+        if os.path.exists(tmpdir):
+            shutil.rmtree(tmpdir, ignore_errors=True)
+
+
class InterrogateModels:
blip_model = None
clip_model = None
clip_preprocess = None
- categories = None
dtype = None
running_on_cpu = None
def __init__(self, content_dir):
- self.categories = []
+ self.loaded_categories = None
+ self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
- if os.path.exists(content_dir):
- for filename in os.listdir(content_dir):
+ def categories(self):
+ if self.loaded_categories is not None:
+ return self.loaded_categories
+
+ self.loaded_categories = []
+
+ if not os.path.exists(self.content_dir):
+ download_default_clip_interrogate_categories(self.content_dir)
+
+ if os.path.exists(self.content_dir):
+ for filename in os.listdir(self.content_dir):
m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1))
- with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+ with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.categories.append(Category(name=filename, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+
+ return self.loaded_categories
def load_blip_model(self):
import models.blip
@@ -139,7 +172,6 @@ class InterrogateModels: shared.state.begin()
shared.state.job = 'interrogate'
try:
-
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels: image_features /= image_features.norm(dim=-1, keepdim=True)
- if shared.opts.interrogate_use_builtin_artists:
- artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
- res += ", " + artist[0]
-
- for name, topn, items in self.categories:
+ for name, topn, items in self.categories():
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
diff --git a/modules/postprocessing.py b/modules/postprocessing.py new file mode 100644 index 00000000..8514fea7 --- /dev/null +++ b/modules/postprocessing.py @@ -0,0 +1,103 @@ +import os
+
+from PIL import Image
+
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules.shared import opts
+
+
+def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+ """
+ Run the postprocessing scripts (scripts.scripts_postproc) over one or more images.
+
+ extras_mode selects where the images come from:
+ 1 - every entry of image_folder is opened with PIL;
+ 2 - every file listed by shared.listfiles(input_dir) that PIL can open;
+ anything else - the single `image` argument.
+
+ *args is forwarded unchanged to scripts.scripts_postproc.run().
+ When save_output is true each result is written with images.save_image().
+
+ Returns a tuple (list of result images, infotext rendered as HTML, '').
+ """
+ devices.torch_gc()
+
+ # NOTE(review): no matching shared.state.end() is visible in this function - confirm that every caller ends the job
+ shared.state.begin()
+ shared.state.job = 'extras'
+
+ image_data = []
+ image_names = []
+ outputs = []
+
+ if extras_mode == 1:
+ for img in image_folder:
+ image = Image.open(img)
+ image_data.append(image)
+ # name is the entry's orig_name without its extension
+ image_names.append(os.path.splitext(img.orig_name)[0])
+ elif extras_mode == 2:
+ assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
+ assert input_dir, 'input directory not selected'
+
+ image_list = shared.listfiles(input_dir)
+ for filename in image_list:
+ try:
+ image = Image.open(filename)
+ except Exception:
+ # files that cannot be opened as images are skipped silently
+ continue
+ image_data.append(image)
+ image_names.append(filename)
+ else:
+ assert image, 'image not selected'
+
+ image_data.append(image)
+ # a single image has no source name
+ image_names.append(None)
+
+ # output_dir is only honoured in directory mode; otherwise the configured sample directories are used
+ if extras_mode == 2 and output_dir != '':
+ outpath = output_dir
+ else:
+ outpath = opts.outdir_samples or opts.outdir_extras_samples
+
+ infotext = ''
+
+ for image, name in zip(image_data, image_names):
+ shared.state.textinfo = name
+
+ existing_pnginfo = image.info or {}
+
+ pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
+
+ scripts.scripts_postproc.run(pp, args)
+
+ if opts.use_original_name_batch and name is not None:
+ basename = os.path.splitext(os.path.basename(name))[0]
+ else:
+ basename = ''
+
+ # entries of pp.info with a None value are dropped; when key and value are equal only the key is written
+ infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+
+ if opts.enable_pnginfo:
+ pp.image.info = existing_pnginfo
+ pp.image.info["postprocessing"] = infotext
+
+ if save_output:
+ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+
+ # in directory mode the results are only collected for display when show_extras_results is set
+ if extras_mode != 2 or show_extras_results:
+ outputs.append(pp.image)
+
+ devices.torch_gc()
+
+ # infotext is overwritten on every iteration, so only the last image's text is returned
+ return outputs, ui_common.plaintext_to_html(infotext), ''
+
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
+    """Old handler kept for the API.
+
+    Maps the flat argument list onto per-script argument dictionaries
+    ("Upscale", "GFPGAN", "CodeFormer"), converts them with
+    scripts.scripts_postproc.create_args_for_run() and delegates to run_postprocessing().
+    `upscale_first` is accepted but not used here.
+    """
+
+    upscale_options = {
+        "upscale_mode": resize_mode,
+        "upscale_by": upscaling_resize,
+        "upscale_to_width": upscaling_resize_w,
+        "upscale_to_height": upscaling_resize_h,
+        "upscale_crop": upscaling_crop,
+        "upscaler_1_name": extras_upscaler_1,
+        "upscaler_2_name": extras_upscaler_2,
+        "upscaler_2_visibility": extras_upscaler_2_visibility,
+    }
+    gfpgan_options = {
+        "gfpgan_visibility": gfpgan_visibility,
+    }
+    codeformer_options = {
+        "codeformer_visibility": codeformer_visibility,
+        "codeformer_weight": codeformer_weight,
+    }
+
+    script_args = scripts.scripts_postproc.create_args_for_run({
+        "Upscale": upscale_options,
+        "GFPGAN": gfpgan_options,
+        "CodeFormer": codeformer_options,
+    })
+
+    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *script_args, save_output=save_output)
diff --git a/modules/processing.py b/modules/processing.py index 9c3673de..bc541e2f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -140,6 +140,7 @@ class StableDiffusionProcessing: self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
+ self.disable_extra_networks = False
if not seed_enable_extras:
self.subseed = -1
@@ -438,9 +439,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
- "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -468,14 +466,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork':
- shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights() # make onchange call for changing SD model
+ sd_models.reload_model_weights()
if k == 'sd_vae':
- sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ sd_vae.reload_vae_weights()
res = process_images_inner(p)
@@ -484,9 +480,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights()
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights()
return res
@@ -535,13 +533,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
+
if p.scripts is not None:
p.scripts.process(p)
- with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
infotexts = []
output_images = []
@@ -572,6 +568,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ if not p.disable_extra_networks:
+ extra_networks.activate(p, extra_network_data)
+
+ with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
if state.job_count == -1:
state.job_count = p.n_iter
@@ -592,6 +595,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(prompts) == 0:
break
+ prompts, _ = extra_networks.parse_prompts(prompts)
+
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
@@ -681,6 +686,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ if not p.disable_extra_networks:
+ extra_networks.deactivate(p, extra_network_data)
+
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
@@ -857,7 +865,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob()
- self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch to DDIM
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
diff --git a/modules/progress.py b/modules/progress.py index 3327b883..c69ecf3d 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -67,10 +67,13 @@ def progressapi(req: ProgressRequest): progress = 0
- if shared.state.job_count > 0:
- progress += shared.state.job_no / shared.state.job_count
- if shared.state.sampling_steps > 0:
- progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+ job_count, job_no = shared.state.job_count, shared.state.job_no
+ sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step
+
+ if job_count > 0:
+ progress += job_no / job_count
+ if sampling_steps > 0 and job_count > 0:
+ progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 3ac0b97a..47f70251 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler): return img
info = self.load_model(path)
- if not os.path.exists(info.data_path):
+ if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
- model_path=info.data_path,
+ model_path=info.local_data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
@@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler): def load_model(self, path):
try:
- info = None
- for scaler in self.scalers:
- if scaler.data_path == path:
- info = scaler
+ info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
if info is None:
print(f"Unable to find model info: {path}")
return None
- model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
- info.data_path = model_file
+ info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a9e19236..4bb45ec7 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -73,6 +73,7 @@ callback_map = dict( callbacks_image_grid=[],
callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
+ callbacks_before_ui=[],
)
@@ -189,6 +190,14 @@ def script_unloaded_callback(): report_exception(c, 'script_unloaded')
+def before_ui_callback():
+    """Call every registered before_ui callback, walking the list in reverse order.
+
+    An exception raised by one callback is passed to report_exception() and does
+    not stop the remaining callbacks from running.
+    """
+    registered = callback_map['callbacks_before_ui']
+
+    for entry in reversed(registered):
+        try:
+            entry.callback()
+        except Exception:
+            report_exception(entry, 'before_ui')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -313,3 +322,9 @@ def on_script_unloaded(callback): the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback)
+
+
+def on_before_ui(callback):
+    """Register `callback`, a function taking no arguments, to be called before the UI is created."""
+
+    before_ui_callbacks = callback_map['callbacks_before_ui']
+    add_callback(before_ui_callbacks, callback)
diff --git a/modules/scripts.py b/modules/scripts.py index 4ffc369b..03907a63 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ from collections import namedtuple import gradio as gr
from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions, script_loading
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
@@ -150,8 +150,10 @@ def basedir(): return current_basedir
-scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+
+scripts_data = []
+postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
@@ -190,23 +192,31 @@ def list_files_with_name(filename): def load_scripts():
global current_basedir
scripts_data.clear()
+ postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
+ def register_scripts_from_module(module):
+ for key, script_class in module.__dict__.items():
+ if type(script_class) != type:
+ continue
+
+ if issubclass(script_class, Script):
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+
for scriptfile in sorted(scripts_list):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
- module = script_loading.load_module(scriptfile.path)
-
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ script_module = script_loading.load_module(scriptfile.path)
+ register_scripts_from_module(script_module)
except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -413,6 +423,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
scripts_current: ScriptRunner = None
@@ -423,12 +434,13 @@ def reload_script_body_only(): def reload_scripts():
- global scripts_txt2img, scripts_img2img
+ global scripts_txt2img, scripts_img2img, scripts_postproc
load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def IOComponent_init(self, *args, **kwargs):
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py new file mode 100644 index 00000000..25de02d0 --- /dev/null +++ b/modules/scripts_postprocessing.py @@ -0,0 +1,147 @@ +import os
+import gradio as gr
+
+from modules import errors, shared
+
+
+class PostprocessedImage:
+ def __init__(self, image):
+ self.image = image
+ self.info = {}
+
+
+class ScriptPostprocessing:
+ filename = None
+ controls = None
+ args_from = None
+ args_to = None
+
+ order = 1000
+ """scripts will be ordered by this value in postprocessing UI"""
+
+ name = None
+ """the title of the script."""
+
+ group = None
+ """A gr.Group component that has all script's UI inside it"""
+
+ def ui(self):
+ """
+ This function should create gradio UI elements. See https://gradio.app/docs/#components
+ The return value should be a dictionary that maps parameter names to components used in processing.
+ Values of those components will be passed to process() function.
+ """
+
+ pass
+
+ def process(self, pp: PostprocessedImage, **args):
+ """
+ This function is called to postprocess the image.
+ args contains a dictionary with all values returned by components from ui()
+ """
+
+ pass
+
+ def image_changed(self):
+ pass
+
+
+def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
+ try:
+ res = func(*args, **kwargs)
+ return res
+ except Exception as e:
+ errors.display(e, f"calling {filename}/{funcname}")
+
+ return default
+
+
+class ScriptPostprocessingRunner:
+ def __init__(self):
+ self.scripts = None
+ self.ui_created = False
+
+ def initialize_scripts(self, scripts_data):
+ self.scripts = []
+
+ for script_class, path, basedir, script_module in scripts_data:
+ script: ScriptPostprocessing = script_class()
+ script.filename = path
+
+ self.scripts.append(script)
+
+ def create_script_ui(self, script, inputs):
+ script.args_from = len(inputs)
+ script.args_to = len(inputs)
+
+ script.controls = wrap_call(script.ui, script.filename, "ui")
+
+ for control in script.controls.values():
+ control.custom_script_source = os.path.basename(script.filename)
+
+ inputs += list(script.controls.values())
+ script.args_to = len(inputs)
+
+ def scripts_in_preferred_order(self):
+ if self.scripts is None:
+ import modules.scripts
+ self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
+
+ scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+
+ def script_score(name):
+ name = name.lower()
+ for i, possible_match in enumerate(scripts_order):
+ if possible_match in name:
+ return i
+
+ return len(self.scripts)
+
+ script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
+
+ return sorted(self.scripts, key=lambda x: script_scores[x.name])
+
+ def setup_ui(self):
+ inputs = []
+
+ for script in self.scripts_in_preferred_order():
+ with gr.Box() as group:
+ self.create_script_ui(script, inputs)
+
+ script.group = group
+
+ self.ui_created = True
+ return inputs
+
+ def run(self, pp: PostprocessedImage, args):
+ for script in self.scripts_in_preferred_order():
+ shared.state.job = script.name
+
+ script_args = args[script.args_from:script.args_to]
+
+ process_args = {}
+ for (name, component), value in zip(script.controls.items(), script_args):
+ process_args[name] = value
+
+ script.process(pp, **process_args)
+
+ def create_args_for_run(self, scripts_args):
+ if not self.ui_created:
+ with gr.Blocks(analytics_enabled=False):
+ self.setup_ui()
+
+ scripts = self.scripts_in_preferred_order()
+ args = [None] * max([x.args_to for x in scripts])
+
+ for script in scripts:
+ script_args_dict = scripts_args.get(script.name, None)
+ if script_args_dict is not None:
+
+ for i, name in enumerate(script.controls):
+ args[script.args_from + i] = script_args_dict.get(name, None)
+
+ return args
+
+ def image_changed(self):
+ for script in self.scripts_in_preferred_order():
+ script.image_changed()
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c72d8efc..e90aa9fe 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -41,7 +41,9 @@ class DisableInitialization: return self.create_model_and_transforms(*args, pretrained=None, **kwargs)
def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
- return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)
+ res.name_or_path = pretrained_model_name_or_path
+ return res
def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):
args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6b0d95af..f9652d21 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -70,9 +70,10 @@ def undo_optimizations(): def fix_checkpoint():
- ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
- ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
- ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
+ checkpoints to be added when not training (there's a warning)"""
+
+ pass
class StableDiffusionModelHijack:
@@ -106,8 +107,6 @@ class StableDiffusionModelHijack: self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
-
- fix_checkpoint()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 5712972f..2604d969 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -1,10 +1,46 @@ from torch.utils.checkpoint import checkpoint +import ldm.modules.attention +import ldm.modules.diffusionmodules.openaimodel + + def BasicTransformerBlock_forward(self, x, context=None): return checkpoint(self._forward, x, context) + def AttentionBlock_forward(self, x): return checkpoint(self._forward, x) + def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb)
\ No newline at end of file + return checkpoint(self._forward, x, emb) + + +stored = [] + + +def add(): + if len(stored) != 0: + return + + stored.extend([ + ldm.modules.attention.BasicTransformerBlock.forward, + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward + ]) + + ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward + + +def remove(): + if len(stored) == 0: + return + + ldm.modules.attention.BasicTransformerBlock.forward = stored[0] + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] + + stored.clear() + diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..74452709 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default
from einops import rearrange
-from modules import shared
+from modules import shared, errors
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
del context, context_k, context_v, x
@@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
@@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -279,18 +279,34 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ )
+def get_xformers_flash_attention_op(q, k, v):
+ if not shared.cmd_opts.xformers_flash_attention:
+ return None
+
+ try:
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+ fw, bw = flash_attention_op
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+ return flash_attention_op
+ except Exception as e:
+ errors.display_once(e, "enabling flash attention")
+
+ return None
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -365,7 +381,7 @@ def xformers_attnblock_forward(self, x): q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
- out = xformers.ops.memory_efficient_attention(q, k, v)
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef..12083848 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"):
name = name[1:]
- self.title = name
+ self.name = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.shorthash = self.sha256[0:10] if self.sha256 else None
- self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+ self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+
+ self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256]
self.register()
+ self.title = f'{self.name} [{self.shorthash}]'
+
return self.shorthash
@@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo):
+ title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
+ if checkpoint_info.title != title:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
cache_enabled = shared.opts.sd_checkpoint_cache > 0
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index b2af2ce7..4ce238b8 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -72,6 +72,13 @@ def refresh_vae_list(): os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), ] + if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir): + paths += [ + os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'), + ] + candidates = [] for path in paths: candidates += glob.iglob(path, recursive=True) @@ -113,6 +120,12 @@ def resolve_vae(checkpoint_file): return None, None +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + def load_vae(model, vae_file=None, vae_source="from unknown source"): global vae_dict, loaded_vae_file # save_settings = False @@ -130,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) - vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) _load_vae_dict(model, vae_dict_1) if cache_enabled: diff --git a/modules/shared.py b/modules/shared.py index 483c4c62..a644c0be 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -9,7 +9,6 @@ from PIL import Image import gradio as gr
import tqdm
-import modules.artists
import modules.interrogate
import modules.memmon
import modules.styles
@@ -20,12 +19,15 @@ from modules.paths import models_path, script_path, sd_path demo = None
+sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
+
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
+parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
@@ -55,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
+parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
@@ -64,6 +67,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
@@ -97,6 +101,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
@@ -142,7 +148,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
-loaded_hypernetwork = None
+loaded_hypernetworks = []
def reload_hypernetworks():
@@ -150,8 +156,6 @@ def reload_hypernetworks(): global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
- hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
-
class State:
@@ -252,8 +256,6 @@ class State: state = State()
state.server_start = time.time()
-artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
-
styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)
@@ -368,6 +370,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" }))
options_templates.update(options_section(('system', "System"), {
+ "show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
@@ -395,10 +398,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
- "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
- "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
+ "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
@@ -406,8 +407,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
- "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -418,7 +419,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
- "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
@@ -444,10 +444,13 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
- "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
- 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
- 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
- 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
+ "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
+ "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
+ "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
options_templates.update(options_section(('ui', "Live previews"), {
@@ -472,6 +475,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
}))
+options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+ 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+ 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+}))
+
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
@@ -657,3 +665,17 @@ mem_mon.start() def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
+
+
+def html_path(filename):
+ return os.path.join(script_path, "html", filename)
+
+
+def html(filename):
+ path = html_path(filename)
+
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+
+ return ""
diff --git a/modules/styles.py b/modules/styles.py index ce6e71ca..990d5623 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles): class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
- self.styles = {"None": self.no_style}
+ self.styles = {}
+ self.path = path
- if not os.path.exists(path):
+ self.reload()
+
+ def reload(self):
+ self.styles.clear()
+
+ if not os.path.exists(self.path):
return
- with open(path, "r", encoding="utf-8-sig", newline='') as file:
+ with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
reader = csv.DictReader(file)
for row in reader:
# Support loading old CSV format with "name, text"-columns
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 64abff4d..c0ac11d3 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -20,7 +20,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height, if process_caption_deepbooru:
deepbooru.model.start()
- preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug)
+ preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
finally:
@@ -109,8 +109,30 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio): splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
+def center_crop(image: Image, w: int, h: int):
+ iw, ih = image.size
+ if ih / h < iw / w:
+ sw = w * ih / h
+ box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+ else:
+ sh = h * iw / w
+ box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+ return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+ iw, ih = image.size
+ err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
+ wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
+ if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+ key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
+ default=None
+ )
+ return wh and center_crop(image, *wh)
+
+
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
height = process_height
src = os.path.abspath(process_src)
@@ -194,6 +216,14 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre save_pic(focal, index, params, existing_caption=existing_caption)
process_default_resize = False
+ if process_multicrop:
+ cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
+ if cropped is not None:
+ save_pic(cropped, index, params, existing_caption=existing_caption)
+ else:
+ print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
+ process_default_resize = False
+
if process_default_resize:
img = images.resize_image(1, img, width, height)
save_pic(img, index, params, existing_caption=existing_caption)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7e4a6d24..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -15,7 +15,7 @@ import numpy as np from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
+ self.filename = None
def save(self, filename):
embedding_data = {
@@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
+ embedding.filename = path
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
@@ -452,6 +454,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st pbar = tqdm.tqdm(total=steps - initial_step)
try:
+ sd_hijack_checkpoint.add()
+
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
@@ -617,9 +621,11 @@ Last saved image: {html.escape(last_saved_image)}<br/> pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
shared.parallel_processing_allowed = old_parallel_processing_allowed
+ sd_hijack_checkpoint.remove()
return embedding, filename
+
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
diff --git a/modules/txt2img.py b/modules/txt2img.py index ca5d4550..e945fd69 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,13 +8,13 @@ import modules.processing as processing from modules.ui import plaintext_to_html
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
- styles=[prompt_style, prompt_style2],
+ styles=prompt_styles,
negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
diff --git a/modules/ui.py b/modules/ui.py index 20b66165..85ae62c7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,12 +5,12 @@ import mimetypes import os
import platform
import random
-import subprocess as sp
import sys
import tempfile
import time
import traceback
from functools import partial, reduce
+import warnings
import gradio as gr
import gradio.routes
@@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -40,6 +40,9 @@ from modules.sd_samplers import samplers, samplers_for_img2img from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
+import modules.extras
+
+warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
@@ -72,6 +75,7 @@ css_hide_progressbar = """ .wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
+.wrap.cover-bg .z-20::before { content:"" }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
@@ -82,86 +86,22 @@ css_hide_progressbar = """ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
paste_symbol = '\u2199\ufe0f' # ↙
-folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
+extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text):
- text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
- return text
+ return ui_common.plaintext_to_html(text)
+
def send_gradio_gallery_to_image(x):
if len(x) == 0:
return None
return image_from_url_text(x[0])
-def save_files(js_data, images, do_make_zip, index):
- import csv
- filenames = []
- fullfns = []
-
- #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
- class MyObject:
- def __init__(self, d=None):
- if d is not None:
- for key, value in d.items():
- setattr(self, key, value)
-
- data = json.loads(js_data)
-
- p = MyObject(data)
- path = opts.outdir_save
- save_to_dirs = opts.use_save_to_dirs_for_ui
- extension: str = opts.samples_format
- start_index = 0
-
- if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
- images = [images[index]]
- start_index = index
-
- os.makedirs(opts.outdir_save, exist_ok=True)
-
- with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
- at_start = file.tell() == 0
- writer = csv.writer(file)
- if at_start:
- writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
-
- for image_index, filedata in enumerate(images, start_index):
- image = image_from_url_text(filedata)
-
- is_grid = image_index < p.index_of_first_image
- i = 0 if is_grid else (image_index - p.index_of_first_image)
-
- fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
-
- filename = os.path.relpath(fullfn, path)
- filenames.append(filename)
- fullfns.append(fullfn)
- if txt_fullfn:
- filenames.append(os.path.basename(txt_fullfn))
- fullfns.append(txt_fullfn)
-
- writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
-
- # Make Zip
- if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
-
- from zipfile import ZipFile
- with ZipFile(zip_filepath, "w") as zip_file:
- for i in range(len(fullfns)):
- with open(fullfns[i], mode="rb") as f:
- zip_file.writestr(filenames[i], f.read())
- fullfns.insert(0, zip_filepath)
-
- return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
def visit(x, func, path=""):
if hasattr(x, 'children'):
for c in x.children:
@@ -180,7 +120,7 @@ def add_style(name: str, prompt: str, negative_prompt: str): # reserialize all styles every time we save them
shared.prompt_styles.save_styles(shared.styles_filename)
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+ return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
@@ -197,22 +137,44 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
-def apply_styles(prompt, prompt_neg, style1_name, style2_name):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+def apply_styles(prompt, prompt_neg, styles):
+    """Merge the selected styles into both prompts and clear the style selection.
+
+    Returns updates for the prompt textbox, the negative prompt textbox and
+    the styles dropdown (reset to an empty list), in that order.
+    """
+    styled_prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    styled_negative = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
+
+    textbox_updates = [gr.Textbox.update(value=text) for text in (styled_prompt, styled_negative)]
+    return [*textbox_updates, gr.Dropdown.update(value=[])]
+
+
+def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
+    """Run `interrogation_function` on the image(s) selected by `mode`.
+
+    Modes 0, 1, 3 and 4 pass `ii_singles[mode]` straight to the interrogator;
+    mode 2 first unwraps the "image" entry of a dict. Mode 5 is batch mode:
+    every file listed for `ii_input_dir` is interrogated and the result is
+    appended to `<name>.txt` in `ii_output_dir`, or in the input directory
+    when no output directory is given.
+
+    Returns a two-item list: the interrogation result (or a no-op update in
+    batch mode) followed by None.
+    """
+    if mode in {0, 1, 3, 4}:
+        return [interrogation_function(ii_singles[mode]), None]
+    elif mode == 2:
+        return [interrogation_function(ii_singles[mode]["image"]), None]
+    elif mode == 5:
+        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
+        images = shared.listfiles(ii_input_dir)
+        print(f"Will process {len(images)} images.")
+        if ii_output_dir != "":
+            os.makedirs(ii_output_dir, exist_ok=True)
+        else:
+            ii_output_dir = ii_input_dir
+
+        for image in images:
+            filename = os.path.basename(image)
+            left, _ = os.path.splitext(filename)
+            # Interrogate first, then open the caption file, so a failing
+            # interrogation never creates an empty .txt file. Both handles
+            # are closed per image instead of leaking until garbage collection.
+            with Image.open(image) as img:
+                caption = interrogation_function(img)
+            with open(os.path.join(ii_output_dir, left + ".txt"), 'a') as caption_file:
+                print(caption, file=caption_file)
+
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
+    # Reached after batch mode (and for any unrecognised mode) so the caller
+    # always receives two outputs.
+    return [gr.update(), None]
def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB"))
-
- return gr_show(True) if prompt is None else prompt
+ return gr.update() if prompt is None else prompt
def interrogate_deepbooru(image):
prompt = deepbooru.model.tag(image)
- return gr_show(True) if prompt is None else prompt
+ return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface):
@@ -299,6 +261,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps):
try:
+ text, _ = extra_networks.parse_prompt(text)
+
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
@@ -310,43 +274,23 @@ def update_token_counter(text, steps): flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"<span {style_class}>{token_count}/{max_length}</span>"
+ return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
button_interrogate = None
button_deepbooru = None
@@ -355,7 +299,7 @@ def create_toprow(is_img2img): button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- with gr.Column(scale=1):
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
with gr.Row(elem_id=f"{id_part}_generate_box"):
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
@@ -373,14 +317,30 @@ def create_toprow(is_img2img): outputs=[],
)
- with gr.Row():
- with gr.Column(scale=1, elem_id="style_pos_col"):
- prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ with gr.Row(elem_id=f"{id_part}_tools"):
+ paste = ToolButton(value=paste_symbol, elem_id="paste")
+ clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+ prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
+ save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
+
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
+ negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
- with gr.Column(scale=1, elem_id="style_neg_col"):
- prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+ with gr.Row(elem_id=f"{id_part}_styles_row"):
+ prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
+ create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
def setup_progressbar(*args, **kwargs):
@@ -417,21 +377,7 @@ def apply_setting(key, value): opts.data_labels[key].onchange()
opts.save(shared.config_filename)
- return value
-
-
-def update_generation_info(args):
- generation_info, html_info, img_index = args
- try:
- generation_info = json.loads(generation_info)
- if img_index < 0 or img_index >= len(generation_info["infotexts"]):
- return html_info
- return plaintext_to_html(generation_info["infotexts"][img_index])
- except Exception:
- pass
- # if the json parse or anything else fails, just return the old html_info
- return html_info
-
+ return getattr(opts, key)
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
@@ -453,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele def create_output_panel(tabname, outdir):
- def open_folder(f):
- if not os.path.exists(f):
- print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
- return
- elif not os.path.isdir(f):
- print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
- return
-
- if not shared.cmd_opts.hide_ui_dir_config:
- path = os.path.normpath(f)
- if platform.system() == "Windows":
- os.startfile(path)
- elif platform.system() == "Darwin":
- sp.Popen(["open", path])
- elif "microsoft-standard-WSL2" in platform.uname().release:
- sp.Popen(["wsl-open", path])
- else:
- sp.Popen(["xdg-open", path])
-
- with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
- with gr.Group(elem_id=f"{tabname}_gallery_container"):
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
- generation_info = None
- with gr.Column():
- with gr.Row(elem_id=f"image_buttons_{tabname}"):
- open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
- if tabname != "extras":
- save = gr.Button('Save', elem_id=f'save_{tabname}')
- save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
- buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
- open_folder_button.click(
- fn=lambda: open_folder(opts.outdir_samples or outdir),
- inputs=[],
- outputs=[],
- )
-
- if tabname != "extras":
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
- with gr.Group():
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
- if tabname == 'txt2img' or tabname == 'img2img':
- generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
- generation_info_button.click(
- fn=update_generation_info,
- _js="(x, y) => [x, y, selected_gallery_index()]",
- inputs=[generation_info, html_info],
- outputs=[html_info],
- preprocess=False
- )
-
- save.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- save_zip.click(
- fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
- inputs=[
- generation_info,
- result_gallery,
- html_info,
- html_info,
- ],
- outputs=[
- download_files,
- html_log,
- ]
- )
-
- else:
- html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
- html_info = gr.HTML(elem_id=f'html_info_{tabname}')
- html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+ return ui_common.create_output_panel(tabname, outdir)
def create_sampler_and_steps_selection(choices, tabname):
@@ -576,6 +422,16 @@ def ordered_ui_categories(): yield category
+def get_value_for_setting(key):
+    """Build a gradio update carrying the current value of setting `key`.
+
+    The setting's component arguments (evaluated first when they are given
+    as a callable) are forwarded to the update, with 'precision' left out.
+    """
+    current_value = getattr(opts, key)
+
+    label_info = opts.data_labels[key]
+    if callable(label_info.component_args):
+        component_kwargs = label_info.component_args()
+    else:
+        component_kwargs = label_info.component_args or {}
+
+    # Work on a copy so the stored component arguments are never mutated.
+    forwarded = dict(component_kwargs.items())
+    forwarded.pop('precision', None)
+
+    return gr.update(value=current_value, **forwarded)
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -588,13 +444,17 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
- txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
+ txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
+
+ with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -617,7 +477,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
- with FormRow(elem_id="txt2img_checkboxes"):
+ with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@@ -625,12 +485,12 @@ def create_ui(): elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
- with FormRow(elem_id="txt2img_hires_fix_row1"):
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
- with FormRow(elem_id="txt2img_hires_fix_row2"):
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
@@ -674,8 +534,7 @@ def create_ui(): dummy_component,
txt2img_prompt,
txt2img_negative_prompt,
- txt2img_prompt_style,
- txt2img_prompt_style2,
+ txt2img_prompt_styles,
steps,
sampler_index,
restore_faces,
@@ -765,17 +624,24 @@ def create_ui(): ]
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
- img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
+ img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+
+ with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
with FormRow().style(equal_height=False):
- with gr.Column(variant='panel', elem_id="img2img_settings"):
+ with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
copy_image_destinations = {}
@@ -875,7 +741,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
- with FormRow(elem_id="img2img_checkboxes"):
+ with FormRow(elem_id="img2img_checkboxes", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
@@ -943,8 +809,7 @@ def create_ui(): dummy_component,
img2img_prompt,
img2img_negative_prompt,
- img2img_prompt_style,
- img2img_prompt_style2,
+ img2img_prompt_styles,
init_img,
sketch,
init_img_with_mask,
@@ -983,23 +848,36 @@ def create_ui(): show_progress=False,
)
+ interrogate_args = dict(
+ _js="get_img2img_tab_index",
+ inputs=[
+ dummy_component,
+ img2img_batch_input_dir,
+ img2img_batch_output_dir,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ init_img_inpaint,
+ ],
+ outputs=[img2img_prompt, dummy_component],
+ )
+
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
img2img_interrogate.click(
- fn=interrogate,
- inputs=[init_img],
- outputs=[img2img_prompt],
+ fn=lambda *args: process_interrogate(interrogate, *args),
+ **interrogate_args,
)
img2img_deepbooru.click(
- fn=interrogate_deepbooru,
- inputs=[init_img],
- outputs=[img2img_prompt],
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
+ **interrogate_args,
)
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+ style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
@@ -1009,18 +887,21 @@ def create_ui(): # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
+ outputs=[txt2img_prompt_styles, img2img_prompt_styles],
)
- for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
+ for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click(
fn=apply_styles,
_js=js_func,
- inputs=[prompt, negative_prompt, style1, style2],
- outputs=[prompt, negative_prompt, style1, style2],
+ inputs=[prompt, negative_prompt, styles],
+ outputs=[prompt, negative_prompt, styles],
)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ # img2img tab: count tokens of this tab's own negative prompt (was wired to the txt2img one)
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
@@ -1047,86 +928,7 @@ def create_ui(): modules.scripts.scripts_current = None
with gr.Blocks(analytics_enabled=False) as extras_interface:
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image', elem_id="extras_single_tab"):
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
-
- with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
-
- with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
- show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
-
- submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
- with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
- with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
- with gr.Group():
- with gr.Row():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
- with gr.Group():
- extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
-
- with gr.Group():
- extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
-
- with gr.Group():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
-
- with gr.Group():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
-
- with gr.Group():
- upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
-
- result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
-
- submit.click(
- fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
- _js="get_extras_tab_index",
- inputs=[
- dummy_component,
- dummy_component,
- extras_image,
- image_batch,
- extras_batch_input_dir,
- extras_batch_output_dir,
- show_extras_results,
- gfpgan_visibility,
- codeformer_visibility,
- codeformer_weight,
- upscaling_resize,
- upscaling_resize_w,
- upscaling_resize_h,
- upscaling_crop,
- extras_upscaler_1,
- extras_upscaler_2,
- extras_upscaler_2_visibility,
- upscale_before_face_fix,
- ],
- outputs=[
- result_images,
- html_info_x,
- html_info,
- ]
- )
- parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
- extras_image.change(
- fn=modules.extras.clear_cache,
- inputs=[], outputs=[]
- )
+ ui_postprocessing.create_ui()
with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False):
@@ -1147,12 +949,21 @@ def create_ui(): outputs=[html, generation_info, html2],
)
def update_interp_description(value):
    """Return the HTML help paragraph for the interpolation method *value*.

    *value* must be one of the method names offered by the merge UI;
    any other value raises KeyError.
    """
    explanations = {
        "No interpolation": "No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking.",
        "Weighted sum": "A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M",
        "Add difference": "The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M",
    }

    # every description is wrapped in the same paragraph markup
    return "<p style='margin-bottom: 2.5em'>{}</p>".format(explanations[value])
+
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
+ with gr.Column(variant='compact'):
+ interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
- with FormRow():
+ with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
@@ -1164,24 +975,37 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
- config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+ with FormRow():
+ with gr.Column():
+ config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
- modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+ with gr.Column():
+ with FormRow():
+ bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+ create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
- with gr.Column(variant='panel'):
- submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+ with FormRow():
+ discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
+ with gr.Row():
+ modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+ with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+ with gr.Group(elem_id="modelmerger_results_panel"):
+ modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
- with gr.Row().style(equal_height=False):
+ with gr.Row(variant="compact").style(equal_height=False):
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
@@ -1226,6 +1050,7 @@ def create_ui(): process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
@@ -1238,7 +1063,19 @@ def create_ui(): process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
+
+ with gr.Column(visible=False) as process_multicrop_col:
+ gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+ with gr.Row():
+ process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
+ process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
+ with gr.Row():
+ process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
+ process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
+ with gr.Row():
+ process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
+ process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1260,6 +1097,12 @@ def create_ui(): outputs=[process_focal_crop_row],
)
+ process_multicrop.change(
+ fn=lambda show: gr_show(show),
+ inputs=[process_multicrop],
+ outputs=[process_multicrop_col],
+ )
+
def get_textual_inversion_template_names():
return sorted([x for x in textual_inversion.textual_inversion_templates])
@@ -1379,6 +1222,13 @@ def create_ui(): process_focal_crop_entropy_weight,
process_focal_crop_edges_weight,
process_focal_crop_debug,
+ process_multicrop,
+ process_multicrop_mindim,
+ process_multicrop_maxdim,
+ process_multicrop_minarea,
+ process_multicrop_maxarea,
+ process_multicrop_objective,
+ process_multicrop_threshold,
],
outputs=[
ti_output,
@@ -1532,7 +1382,7 @@ def create_ui(): opts.save(shared.config_filename)
- return gr.update(value=value), opts.dumpjson()
+ return get_value_for_setting(key), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
@@ -1550,6 +1400,7 @@ def create_ui(): previous_section = None
current_tab = None
+ current_row = None
with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
@@ -1558,10 +1409,14 @@ def create_ui(): elem_id, text = item.section
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
+ gr.Group()
current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
previous_section = item.section
@@ -1576,6 +1431,7 @@ def create_ui(): components.append(component)
if current_tab is not None:
+ current_row.__exit__()
current_tab.__exit__()
with gr.TabItem("Actions"):
@@ -1583,10 +1439,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- if os.path.exists("html/licenses.html"):
- with open("html/licenses.html", encoding="utf8") as file:
- with gr.TabItem("Licenses"):
- gr.HTML(file.read(), elem_id="licenses")
+ with gr.TabItem("Licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@@ -1657,7 +1511,7 @@ def create_ui(): interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings"):
+ with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
@@ -1673,11 +1527,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
- if os.path.exists("html/footer.html"):
- with open("html/footer.html", encoding="utf8") as file:
- footer = file.read()
- footer = footer.format(versions=versions_html())
- gr.HTML(footer, elem_id="footer")
+ footer = shared.html("footer.html")
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
@@ -1698,7 +1550,7 @@ def create_ui(): component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_settings_values():
- return [getattr(opts, key) for key in component_keys]
+ return [get_value_for_setting(key) for key in component_keys]
demo.load(
fn=get_settings_values,
@@ -1713,12 +1565,15 @@ def create_ui(): print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
+ return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
+ modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
modelmerger_merge.click(
- fn=modelmerger,
+ fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+ _js='modelmerger',
inputs=[
+ dummy_component,
primary_model_name,
secondary_model_name,
tertiary_model_name,
@@ -1728,13 +1583,15 @@ def create_ui(): custom_name,
checkpoint_format,
config_source,
+ bake_in_vae,
+ discard_weights,
],
outputs=[
- submit_result,
primary_model_name,
secondary_model_name,
tertiary_model_name,
component_dict['sd_model_checkpoint'],
+ modelmerger_result,
]
)
@@ -1766,7 +1623,10 @@ def create_ui(): if saved_value is None:
ui_settings[key] = getattr(obj, field)
elif condition and not condition(saved_value):
- print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
+ pass
+
+ # this warning is generally not useful;
+ # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
else:
setattr(obj, field, saved_value)
if init_field is not None:
@@ -1794,7 +1654,13 @@ def create_ui(): apply_field(x, 'value')
if type(x) == gr.Dropdown:
- apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
def check_dropdown(val):
    """Tell whether the saved value *val* is still a valid selection for dropdown ``x``."""
    if not getattr(x, 'multiselect', False):
        return val in x.choices

    # a multiselect dropdown stores several values; each one must be a known choice
    return all(value in x.choices for value in val)
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
@@ -1806,28 +1672,27 @@ def create_ui(): with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
return demo
def reload_javascript():
- with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
- javascript = f'<script>{jsfile.read()}</script>'
-
- scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
- for basedir, filename, path in scripts_list:
- with open(path, "r", encoding="utf8") as jsfile:
- javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
+ head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n'
+ inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
- javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
+ inline += f"set_theme('{cmd_opts.theme}');"
+
+ for script in modules.scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
- javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+ head += f'<script type="text/javascript">{inline}</script>\n'
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(
- b'</head>', f'{javascript}</head>'.encode("utf8"))
+ res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
res.init_headers()
return res
diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 00000000..9405ac1f --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,202 @@ +import json
+import html
+import os
+import platform
+import sys
+
+import gradio as gr
+import subprocess as sp
+
+from modules import call_queue, shared
+from modules.generation_parameters_copypaste import image_from_url_text
+import modules.images
+
+folder_symbol = '\U0001f4c2' # 📂
+
+
def update_generation_info(generation_info, html_info, img_index):
    """Return the infotext HTML for gallery image *img_index*, plus a no-op update.

    *generation_info* is a JSON string with an "infotexts" list. When it cannot
    be parsed, or the index is out of range, the current *html_info* is
    returned unchanged.
    """
    try:
        infotexts = json.loads(generation_info)["infotexts"]
        if 0 <= img_index < len(infotexts):
            return plaintext_to_html(infotexts[img_index]), gr.update()
    except Exception:
        # deliberate best effort: whatever goes wrong, keep showing the old info
        pass

    return html_info, gr.update()
+
+
def plaintext_to_html(text):
    """Wrap *text* in one <p> element, HTML-escaping it and turning newlines into <br>."""
    escaped_lines = (html.escape(line) for line in text.split('\n'))
    return "<p>" + "<br>\n".join(escaped_lines) + "</p>"
+
+
def save_files(js_data, images, do_make_zip, index):
    """Save gallery images to the configured save directory and log them to log.csv.

    Args:
        js_data: JSON string with the generation data (prompt, seeds, infotexts,
            index_of_first_image, ...).
        images: gallery entries understood by image_from_url_text.
        do_make_zip: when true, also bundle every saved file into images.zip.
        index: selected gallery index; used only when the save_selected_only
            option is on and the index points at a non-grid image.

    Returns:
        A (gr.File update, html) pair: the update lists the saved files (the zip
        first when one was made), the html names the first saved file.
    """
    import csv
    from types import SimpleNamespace

    filenames = []
    fullfns = []

    data = json.loads(js_data)

    # save_image expects an object with attributes rather than a dict
    # (the filename pattern code reads them), so expose the keys as attributes
    p = SimpleNamespace(**data)
    path = shared.opts.outdir_save
    save_to_dirs = shared.opts.use_save_to_dirs_for_ui
    extension: str = shared.opts.samples_format
    start_index = 0

    # only a specific non-grid picture is saved when save_selected_only is enabled
    if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]):
        images = [images[index]]
        start_index = index

    os.makedirs(path, exist_ok=True)

    with open(os.path.join(path, "log.csv"), "a", encoding="utf8", newline='') as file:
        at_start = file.tell() == 0
        writer = csv.writer(file)
        if at_start:
            # a fresh log gets a header row first
            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])

        for image_index, filedata in enumerate(images, start_index):
            image = image_from_url_text(filedata)

            # images before index_of_first_image are grids; they reuse the first seed/prompt
            is_grid = image_index < p.index_of_first_image
            i = 0 if is_grid else (image_index - p.index_of_first_image)

            fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

            filenames.append(os.path.relpath(fullfn, path))
            fullfns.append(fullfn)
            if txt_fullfn:
                filenames.append(os.path.basename(txt_fullfn))
                fullfns.append(txt_fullfn)

        # one log row per save action, naming the first file written
        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

    # Make Zip
    if do_make_zip:
        zip_filepath = os.path.join(path, "images.zip")

        from zipfile import ZipFile
        with ZipFile(zip_filepath, "w") as zip_file:
            for fullfn, filename in zip(fullfns, filenames):
                with open(fullfn, mode="rb") as f:
                    zip_file.writestr(filename, f.read())
        fullfns.insert(0, zip_filepath)

    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
+
+
+def create_output_panel(tabname, outdir):
# Builds the results panel for the tab named `tabname`: an output gallery, an
# open-folder button, Save/Zip buttons (every tab except "extras"), the
# send-to buttons made by parameters_copypaste.create_buttons, and the
# info/log HTML areas. Returns a 4-tuple: (result_gallery,
# generation_info or html_info_x, html_info, html_log).
+ from modules import shared
+ import modules.generation_parameters_copypaste as parameters_copypaste
+
# open_folder prints a message and returns for a path that does not exist or
# is not a directory; otherwise, unless shared.cmd_opts.hide_ui_dir_config is
# set, it opens the normalized path with a platform-specific command.
+ def open_folder(f):
+ if not os.path.exists(f):
+ print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+ return
+ elif not os.path.isdir(f):
+ print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+ return
+
+ if not shared.cmd_opts.hide_ui_dir_config:
+ path = os.path.normpath(f)
+ if platform.system() == "Windows":
+ os.startfile(path)
+ elif platform.system() == "Darwin":
+ sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
+ else:
+ sp.Popen(["xdg-open", path])
+
# Layout starts here; element ids are derived from `tabname`.
+ with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+ with gr.Group(elem_id=f"{tabname}_gallery_container"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
# generation_info stays None unless the non-"extras" branch below creates it.
+ generation_info = None
+ with gr.Column():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
# The folder button gets the id "hidden_element" when hide_ui_dir_config is set.
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+ if tabname != "extras":
+ save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+ buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
# shared.opts.outdir_samples, when non-empty, wins over the per-tab `outdir`.
+ open_folder_button.click(
+ fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
+ inputs=[],
+ outputs=[],
+ )
+
+ if tabname != "extras":
+ with gr.Row():
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+ with gr.Group():
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+ if tabname == 'txt2img' or tabname == 'img2img':
# Invisible button; its _js replaces the third argument with
# selected_gallery_index(), so the repeated html_info input only pads the
# argument list to three values.
+ generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+ generation_info_button.click(
+ fn=update_generation_info,
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
+ inputs=[generation_info, html_info, html_info],
+ outputs=[html_info, html_info],
+ )
+
# Save and Zip both call save_files; their _js passes false/true as
# do_make_zip and selected_gallery_index() as index, overriding the two
# html_info placeholder inputs.
+ save.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ],
+ show_progress=False,
+ )
+
+ save_zip.click(
+ fn=call_queue.wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
+ html_info,
+ html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
+ ]
+ )
+
+ else:
# The "extras" tab has no generation_info textbox; it gets html_info_x instead.
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
# bind_buttons receives "txt2img" only for the txt2img tab and None for every
# other tab. NOTE(review): the meaning of that third argument is defined in
# parameters_copypaste and is not visible here.
+ parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
# The second element is html_info_x for "extras" and generation_info otherwise.
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..9aec3097 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button"
class ToolButtonTop(gr.Button, gr.components.FormComponent):
    """Single-emoji tool button with extra top margin, usable inside gradio forms."""

    def __init__(self, **kwargs):
        """Create the button with the "tool-top" variant; other keyword arguments are passed on."""
        super().__init__(variant="tool-top", **kwargs)

    def get_block_name(self):
        """Report the block name "button"."""
        return "button"
+
+
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
@@ -37,3 +47,4 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self):
return "colorpicker"
+
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..af2b8071 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,171 @@ +import os.path
+
+from modules import shared
+import gradio as gr
+import json
+
+from modules.generation_parameters_copypaste import image_from_url_text
+
+extra_pages = []
+
+
def register_page(page):
    """Add *page* to the module-wide list of extra networks pages shown in the UI.

    Extensions are advised to call this from their on_before_ui() callback.
    """
    extra_pages.append(page)
+
+
class ExtraNetworksPage:
    """Base class for one page (tab) of the extra networks panel.

    Subclasses must implement list_items() and may override refresh() and
    allowed_directories_for_previews().
    """

    def __init__(self, title):
        """Store the tab *title*, its lowercased name, and load the card template."""
        self.title = title
        self.name = title.lower()
        self.card_page = shared.html("extra-networks-card.html")
        self.allow_negative_prompt = False

    def refresh(self):
        """Reload the items behind this page; the base implementation does nothing."""

    def create_html(self, tabname):
        """Return the HTML for every card of this page, wrapped in the page's container div.

        When the page has no items, the "no cards" template is rendered instead,
        listing the directories returned by allowed_directories_for_previews().
        """
        # join once instead of growing a string card by card
        items_html = "".join(self.create_html_for_item(item, tabname) for item in self.list_items())

        if items_html == '':
            dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
            items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)

        res = f"""
<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
{items_html}
</div>
"""

        return res

    def list_items(self):
        """Yield one dict per item; read here are "name", "prompt", "local_preview" and optional "preview"."""
        raise NotImplementedError()

    def allowed_directories_for_previews(self):
        """Return the directories in which preview images for this page may be written."""
        return []

    def create_html_for_item(self, item, tabname):
        """Render the card template for a single *item* dict."""
        preview = item.get("preview", None)

        # NOTE(review): values are placed into the template as-is (JSON-encoded at most);
        # whether names/paths containing quotes can break the markup depends on the
        # card template, which is not visible here - confirm.
        args = {
            "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
            "prompt": item["prompt"],
            "tabname": json.dumps(tabname),
            "local_preview": json.dumps(item["local_preview"]),
            "name": item["name"],
            "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
        }

        return self.card_page.format(**args)
+
+
def intialize():
    """Forget every registered extra networks page.

    The function name is misspelled; it is left as is because it is the
    module's public name.
    """
    del extra_pages[:]
+
+
class ExtraNetworksUi:
    """Plain holder for the components and state of one extra networks panel."""

    def __init__(self):
        """Start with every field unset (None)."""
        # which tab the panel belongs to
        self.tabname = None

        # the pages shown in the panel and the components created for them
        self.stored_extra_pages = None
        self.pages = None

        # components used for saving a preview image
        self.button_save_preview = None
        self.preview_target_filename = None
+
+
def pages_in_preferred_order(pages):
    """Return *pages* sorted by the user's preferred tab order.

    The order comes from the comma-separated option
    shared.opts.ui_extra_networks_tab_reorder; an entry matches a page when it
    is a case-insensitive substring of the page name. Pages that match no
    entry come after all listed ones, and ties keep their original order.
    """
    # Drop empty entries (empty option, stray commas): '' is a substring of every
    # name and would otherwise match all pages, hiding the entries after it.
    tab_order = [entry for entry in (x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")) if entry]

    def tab_name_score(name):
        """Return the position of the first entry matching *name*, or a score past all entries."""
        name = name.lower()
        for i, possible_match in enumerate(tab_order):
            if possible_match in name:
                return i

        # must exceed every valid index so unlisted pages never outrank listed ones
        return len(tab_order)

    # sorted() is stable, so pages with equal scores stay in their original order
    return sorted(pages, key=lambda page: tab_name_score(page.name))
+
+
+def create_ui(container, button, tabname):
# Builds the extra networks panel for the tab `tabname` and returns an
# ExtraNetworksUi holding the created components. `button` shows `container`
# when clicked; the panel's Close button hides it again.
+ ui = ExtraNetworksUi()
+ ui.pages = []
# A copy of the registered pages, ordered by pages_in_preferred_order().
+ ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
+ ui.tabname = tabname
+
# One gr.Tab per page, each holding an HTML component with the page's cards.
# NOTE(review): the name `tabs` is not used again in this function.
+ with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ for page in ui.stored_extra_pages:
+ with gr.Tab(page.title):
+ page_elem = gr.HTML(page.create_html(ui.tabname))
+ ui.pages.append(page_elem)
+
# NOTE(review): `filter` shadows the builtin and is not referenced again here;
# the textbox is created invisible - presumably it is driven from JavaScript
# through its elem_id, confirm.
+ filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+ button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
# Invisible components; setup_ui() attaches the click handler that saves a preview.
+ ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+ button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
+ button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+
# Refreshes every stored page and returns fresh HTML for each one, in the same
# order as the components collected in ui.pages.
+ def refresh():
+ res = []
+
+ for pg in ui.stored_extra_pages:
+ pg.refresh()
+ res.append(pg.create_html(ui.tabname))
+
+ return res
+
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+ return ui
+
+
def path_is_parent(parent_path, child_path):
    """Return True when *child_path* is *parent_path* itself or lies somewhere below it.

    Both paths are made absolute (and thereby normalized) first, so "." and ".."
    components are resolved textually. Symbolic links are not resolved.
    """
    parent_path = os.path.abspath(parent_path)
    child_path = os.path.abspath(child_path)

    try:
        return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
    except ValueError:
        # commonpath() raises when the paths cannot share a prefix at all
        # (e.g. different drives on Windows); such a child is never inside parent
        return False
+
+
def setup_ui(ui, gallery):
    """Wire the panel's "save preview" button to *gallery*.

    Args:
        ui: the ExtraNetworksUi whose button_save_preview, preview_target_filename,
            pages, stored_extra_pages and tabname are used.
        gallery: the gallery component whose images can be saved as previews.
    """

    def save_preview(index, images, filename):
        """Save gallery image *index* to *filename* and return refreshed HTML for every page.

        Raises:
            PermissionError: *filename* is outside every directory the pages
                allow previews to be written to.
        """
        if not images:
            print("There is no image in gallery to save as a preview.")
            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

        # clamp the selected index into the valid range
        index = min(max(int(index), 0), len(images) - 1)

        image = image_from_url_text(images[index])

        is_allowed = any(
            path_is_parent(directory, filename)
            for extra_page in ui.stored_extra_pages
            for directory in extra_page.allowed_directories_for_previews()
        )

        # The filename comes from the browser, so this check guards against writing
        # to arbitrary paths. It must not be an assert: asserts are removed when
        # Python runs with -O, which would silently disable the check.
        if not is_allowed:
            raise PermissionError(f'writing to {filename} is not allowed')

        image.save(filename)

        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

    # The first input only pads the argument list: the _js snippet replaces it with
    # the index of the image selected in the gallery.
    ui.button_save_preview.click(
        fn=save_preview,
        _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
        outputs=[*ui.pages]
    )
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000..65d000cf --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,35 @@ +import json
+import os
+
+from modules import shared, ui_extra_networks
+
+
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
    """Extra networks page listing the entries of shared.hypernetworks."""

    def __init__(self):
        super().__init__('Hypernetworks')

    def refresh(self):
        """Reload the hypernetwork list."""
        shared.reload_hypernetworks()

    def list_items(self):
        """Yield one card dict per hypernetwork; the first existing preview image wins."""
        for name, hypernet_file in shared.hypernetworks.items():
            stem = os.path.splitext(hypernet_file)[0]

            preview = None
            for candidate in (stem + ".png", stem + ".preview.png"):
                if os.path.isfile(candidate):
                    # the mtime is appended to the preview URL as a query parameter
                    preview = "./file=" + candidate.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(candidate))
                    break

            yield {
                "name": name,
                "filename": stem,
                "preview": preview,
                "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                "local_preview": stem + ".png",
            }

    def allowed_directories_for_previews(self):
        """Return the single directory configured in shared.cmd_opts.hypernetwork_dir."""
        return [shared.cmd_opts.hypernetwork_dir]
+
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000..dbd23d2d --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,33 @@ +import json
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+     """Extra-networks page whose items come from sd_hijack.model_hijack.embedding_db."""
+
+     def __init__(self):
+         super().__init__('Textual Inversion')
+         self.allow_negative_prompt = True
+
+     def refresh(self):
+         """Reload the embeddings with force_reload=True."""
+         embedding_db = sd_hijack.model_hijack.embedding_db
+         embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+     def list_items(self):
+         """Yield one dict per embedding with name, filename, preview, prompt and local_preview."""
+         word_embeddings = sd_hijack.model_hijack.embedding_db.word_embeddings
+
+         for embedding in word_embeddings.values():
+             stem, _ = os.path.splitext(embedding.filename)
+             preview_file = stem + ".preview.png"
+
+             # A preview URL is produced only when "<stem>.preview.png" exists.
+             if os.path.isfile(preview_file):
+                 preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+             else:
+                 preview = None
+
+             yield {
+                 "name": embedding.name,
+                 "filename": embedding.filename,
+                 "preview": preview,
+                 "prompt": json.dumps(embedding.name),
+                 "local_preview": preview_file,
+             }
+
+     def allowed_directories_for_previews(self):
+         """Return a list copy of embedding_db.embedding_dirs."""
+         return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 00000000..b418d955 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr
+from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+
+
+ def create_ui():
+     """Create the postprocessing ("extras") components and register their events.
+
+     Builds three input tabs (single image, batch upload, batch directory), a
+     Generate button, the inputs returned by scripts.scripts_postproc.setup_ui()
+     and an output panel, then connects the button to
+     postprocessing.run_postprocessing. The function has no return statement.
+     """
+     # Holds the number of the currently selected input tab (0, 1 or 2).
+     tab_index = gr.State(value=0)
+
+     # NOTE(review): the nesting of the `with` blocks was reconstructed; the
+     # diff this code came from had lost its indentation - confirm the layout.
+     with gr.Row().style(equal_height=False, variant='compact'):
+         with gr.Column(variant='compact'):
+             with gr.Tabs(elem_id="mode_extras"):
+                 with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+                     extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+
+                 with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
+                     image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+
+                 with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+                     extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+                     extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+                     show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+             submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
+             script_inputs = scripts.scripts_postproc.setup_ui()
+
+         with gr.Column():
+             # html_log is unpacked here but not used anywhere in this function.
+             result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+
+     # Selecting a tab stores that tab's number in tab_index.
+     tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+     tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+     tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+
+     run_inputs = [
+         tab_index,
+         extras_image,
+         image_batch,
+         extras_batch_input_dir,
+         extras_batch_output_dir,
+         show_extras_results,
+         *script_inputs
+     ]
+     run_outputs = [
+         result_images,
+         html_info_x,
+         html_info,
+     ]
+
+     submit.click(
+         fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+         inputs=run_inputs,
+         outputs=run_outputs
+     )
+
+     parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+     # Notify the postprocessing scripts whenever the source image changes.
+     extras_image.change(
+         fn=scripts.scripts_postproc.image_changed,
+         inputs=[], outputs=[]
+     )
diff --git a/modules/upscaler.py b/modules/upscaler.py index 231680cb..a5bf5acb 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -95,6 +95,7 @@ class UpscalerData: def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): self.name = name self.data_path = path + self.local_data_path = path self.scaler = upscaler self.scale = scale self.model = model |