From b5253f0dab529707f1fe2e11211a10ce2f264617 Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Thu, 5 Jan 2023 21:21:48 +0000 Subject: allow img2img api to run scripts --- modules/api/api.py | 27 ++++++++++++++++++++++++--- modules/api/models.py | 2 +- modules/processing.py | 4 ++-- 3 files changed, 27 insertions(+), 6 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 2103709b..aa62a42e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras @@ -28,8 +28,13 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}") +def script_name_to_index(name, scripts): + try: + return [script.title().lower() for script in scripts].index(name.lower()) + except: + raise HTTPException(status_code=422, detail=f"Script '{name}' not found") def validate_sampler_name(name): config = sd_samplers.all_samplers_map.get(name, None) @@ -170,6 +175,14 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") + if img2imgreq.script_name is not None: + if scripts.scripts_img2img.scripts == []: + scripts.scripts_img2img.initialize_scripts(True) + ui.create_ui() + + script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts) + script = scripts.scripts_img2img.selectable_scripts[script_idx] + mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) @@ -186,13 +199,21 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. + args.pop('script_name', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - processed = process_images(p) + if 'script' in locals(): + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) + else: + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index d8198a27..862477e7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -106,7 +106,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() class TextToImageResponse(BaseModel): diff --git a/modules/processing.py b/modules/processing.py index a408d622..d5ac7eb1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -98,7 +98,7 @@ class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -149,7 +149,7 @@ class StableDiffusionProcessing(): self.seed_resize_from_w = 0 self.scripts = None - self.script_args = None + self.script_args = script_args self.all_prompts = None self.all_negative_prompts = None self.all_seeds = None -- cgit v1.2.1 From df3b31eb559ab9fabf7e513bdeddd5282c16f124 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 7 Jan 2023 07:04:59 -0500 Subject: In-place operations can break gradient calculation --- modules/sd_hijack_clip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 5520c9b2..852afc66 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z *= original_mean / new_mean + z = z * (original_mean / new_mean) return z -- cgit v1.2.1 From d38ede71d5330958f4bbac5f99c1be3c146b506a Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Sat, 7 Jan 2023 14:21:31 +0000 Subject: Added script support in txt2img endpoint --- modules/api/api.py | 22 +++++++++++++++++++--- modules/api/models.py | 2 +- 2 files changed, 20 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index aa62a42e..0e8ea263 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -149,6 +149,14 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + if txt2imgreq.script_name is not None: + if scripts.scripts_txt2img.scripts == []: + scripts.scripts_txt2img.initialize_scripts(True) + ui.create_ui() + + script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) + script = scripts.scripts_txt2img.selectable_scripts[script_idx] + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, @@ -158,11 +166,20 @@ class Api: if populate.sampler_name: populate.sampler_index = None # prevent a warning later on + args = vars(populate) + args.pop('script_name', None) + with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate)) + p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - processed = process_images(p) + if 'script' in locals(): + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) + else: + processed = process_images(p) shared.state.end() @@ -213,7 +230,6 @@ class Api: processed = scripts.scripts_img2img.run(p, *p.script_args) else: processed = process_images(p) - shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index c85eb94d..ce43c858 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,7 +100,7 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( -- cgit v1.2.1 From 984b86dd0abf0da7f6b116864c791a2bfe8859ef Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 7 Jan 2023 13:08:21 -0700 Subject: Add fallback for Protocol import --- modules/sub_quadratic_attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index fea7aaac..93381bae 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,7 +15,13 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math -from typing import Optional, NamedTuple, Protocol, List + +try: + from typing import Protocol +except: + from typing_extensions import Protocol + +from typing import Optional, NamedTuple, List def narrow_trunc( input: Tensor, -- cgit v1.2.1 From a0c87f1fdf2b76b2ae4ef6c4b01ddaede3afab06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 08:52:26 +0300 Subject: skip images in embeddings dir if they have a second .preview extension --- modules/textual_inversion/textual_inversion.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 45882ed6..e85dd549 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -109,6 +109,10 @@ class EmbeddingDatabase: ext = ext.upper() if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) -- cgit v1.2.1 From 085427de0efc9e9e7a6e9a5aebc6b5a69f0365e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 09:37:33 +0300 Subject: make it possible for extensions/scripts to add their own embedding directories --- modules/sd_hijack.py | 7 +- modules/textual_inversion/textual_inversion.py | 170 +++++++++++++++---------- 2 files changed, 108 insertions(+), 69 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfdb09d6..6b0d95af 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -83,10 +83,12 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() - def hijack(self, m): + def __init__(self): + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) + def hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) @@ -117,7 +119,6 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e85dd549..217fe9eb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -66,17 +66,41 @@ class Embedding: return self.cached_checksum +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) + + class EmbeddingDatabase: - def __init__(self, embeddings_dir): + def __init__(self): self.ids_lookup = {} self.word_embeddings = {} self.skipped_embeddings = {} - self.dir_mtime = None - self.embeddings_dir = embeddings_dir self.expected_shape = -1 + self.embedding_dirs = {} - def register_embedding(self, embedding, model): + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + + def clear_embedding_dirs(self): + self.embedding_dirs.clear() + def register_embedding(self, embedding, model): self.word_embeddings[embedding.name] = embedding ids = model.cond_stage_model.tokenize([embedding.name])[0] @@ -93,69 +117,62 @@ class EmbeddingDatabase: vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) return vec.shape[1] - def load_textual_inversion_embeddings(self, force_reload = False): - mt = os.path.getmtime(self.embeddings_dir) - if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime: - return + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - self.skipped_embeddings.clear() - self.expected_shape = self.get_expected_shape() - - def process_file(path, filename): - name, ext = os.path.splitext(filename) - ext = ext.upper() - - if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: - _, second_ext = os.path.splitext(name) - if second_ext.upper() == '.PREVIEW': - return - - embed_image = Image.open(path) - if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: - data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name', name) - else: - data = extract_image_data_embed(embed_image) - name = data.get('name', name) - elif ext in ['.BIN', '.PT']: - data = torch.load(path, map_location="cpu") - else: + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': return - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - vec = emb.detach().to(devices.device, dtype=torch.float32) - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vec.shape[0] - embedding.shape = vec.shape[-1] - - if self.expected_shape == -1 or self.expected_shape == embedding.shape: - self.register_embedding(embedding, shared.sd_model) + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) else: - self.skipped_embeddings[name] = embedding + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding - for root, dirs, fns in os.walk(self.embeddings_dir): + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, dirs, fns in os.walk(embdir.path): for fn in fns: try: fullfn = os.path.join(root, fn) @@ -163,12 +180,32 @@ class EmbeddingDatabase: if os.stat(fullfn).st_size == 0: continue - process_file(fullfn, fn) + self.load_from_file(fullfn, fn) except Exception: print(f"Error loading embedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for path, embdir in self.embedding_dirs.items(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for path, embdir in self.embedding_dirs.items(): + self.load_from_dir(embdir) + embdir.update() + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") @@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert os.path.isfile(template_file), "Prompt template file doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" - assert steps > 0 , "Max steps must be positive" + assert steps > 0, "Max steps must be positive" assert isinstance(save_model_every, int), "Save {name} must be integer" - assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert save_model_every >= 0, "Save {name} must be positive or 0" assert isinstance(create_image_every, int), "Create image must be integer" - assert create_image_every >= 0 , "Create image must be positive or 0" + assert create_image_every >= 0, "Create image must be positive or 0" if save_model_every or create_image_every: assert log_directory, "Log directory is empty" + def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 -- cgit v1.2.1 From 6d0cc1e239e0a43a2e6d696eae20c66fad0819bb Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Sun, 8 Jan 2023 11:03:48 +0000 Subject: Corrected is_img2img param --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 0e8ea263..1785a6b4 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -151,7 +151,7 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): if txt2imgreq.script_name is not None: if scripts.scripts_txt2img.scripts == []: - scripts.scripts_txt2img.initialize_scripts(True) + scripts.scripts_txt2img.initialize_scripts(False) ui.create_ui() script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) -- cgit v1.2.1 From 137ce534b2355a527cd1a50c192909161258b442 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 16:14:38 +0300 Subject: remove some code duplication remove calls to locals() add a test for img2img with script --- modules/api/api.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 1785a6b4..5b6125f8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -148,14 +148,20 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - if txt2imgreq.script_name is not None: - if scripts.scripts_txt2img.scripts == []: - scripts.scripts_txt2img.initialize_scripts(False) - ui.create_ui() + def get_script(self, script_name, script_runner): + if script_name is None: + return None, None + + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + + script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) + script = script_runner.selectable_scripts[script_idx] + return script, script_idx - script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) - script = scripts.scripts_txt2img.selectable_scripts[script_idx] + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -173,7 +179,7 @@ class Api: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - if 'script' in locals(): + if script is not None: p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args @@ -182,7 +188,6 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -192,13 +197,7 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - if img2imgreq.script_name is not None: - if scripts.scripts_img2img.scripts == []: - scripts.scripts_img2img.initialize_scripts(True) - ui.create_ui() - - script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts) - script = scripts.scripts_img2img.selectable_scripts[script_idx] + script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) mask = img2imgreq.mask if mask: @@ -223,7 +222,7 @@ class Api: p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - if 'script' in locals(): + if script is not None: p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args -- cgit v1.2.1 From cb255faec6e5f6b47b7632e6b7d450b9e2f6678b Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Sun, 8 Jan 2023 10:17:50 -0700 Subject: Add support for loading VAEs from safetensor files --- modules/sd_vae.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ac71d62d..9fcfd9db 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,4 +1,5 @@ import torch +import safetensors.torch import os import collections from collections import namedtuple @@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) @@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"): if os.path.isfile(vae_file_try): vae_file = vae_file_try print(f"Using VAE found similar to selected model: {vae_file}") + # if still not found, try look for ".vae.safetensors" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.safetensors" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None @@ -163,8 +172,14 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _, extension = os.path.splitext(vae_file) + if extension.lower() == ".safetensors": + vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) + else: + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + if "state_dict" in vae_ckpt: + vae_ckpt = vae_ckpt["state_dict"] + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) if cache_enabled: -- cgit v1.2.1 From d4fd2418efb0986a8226add0b800fb5c73ffb58c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 14:57:47 +0300 Subject: add an option to use old hiresfix width/height behavior add a visual effect to inactive hires fix elements --- modules/generation_parameters_copypaste.py | 17 +++++++++++------ modules/processing.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + modules/ui.py | 23 ++++++++++++++--------- 4 files changed, 50 insertions(+), 17 deletions(-) (limited to 'modules') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 12a9de3d..f7f68b67 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res): firstpass_width = res.get('First pass size-1', None) firstpass_height = res.get('First pass size-2', None) + if shared.opts.use_old_hires_fix_width_height: + hires_width = int(res.get("Hires resize-1", None)) + hires_height = int(res.get("Hires resize-2", None)) + + if hires_width is not None and hires_height is not None: + res['Size-1'] = hires_width + res['Size-2'] = hires_height + return + if firstpass_width is None or firstpass_height is None: return @@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - # old algorithm for auto-calculating first pass size - desired_pixel_count = 512 * 512 - actual_pixel_count = width * height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - firstpass_width = math.ceil(scale * width / 64) * 64 - firstpass_height = math.ceil(scale * height / 64) * 64 + from modules import processing + firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height diff --git a/modules/processing.py b/modules/processing.py index 1d23b15f..f04a0e1e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: return res +def old_hires_fix_first_pass_dimensions(width, height): + """old algorithm for auto-calculating first pass size""" + + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + width = math.ceil(scale * width / 64) * 64 + height = math.ceil(scale * height / 64) * 64 + + return width, height + + class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None @@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: - print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) - self.hr_scale = self.width / firstphase_width + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height self.width = firstphase_width self.height = firstphase_height self.truncate_x = 0 self.truncate_y = 0 + self.applied_old_hires_behavior_to = None def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): + self.hr_resize_x = self.width + self.hr_resize_y = self.height + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height + + self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height) + self.applied_old_hires_behavior_to = (self.width, self.height) + if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale self.hr_upscale_to_x = int(self.width * self.hr_scale) diff --git a/modules/shared.py b/modules/shared.py index a6712dae..a1e10201 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,6 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index 99483130..719c26b3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): @@ -745,15 +745,20 @@ def create_ui(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - hr_resolution_preview_args = dict( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False - ) - for input in hr_resolution_preview_inputs: - input.change(**hr_resolution_preview_args) + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) -- cgit v1.2.1 From 49c4509ce2302350210ff650fd26373518c46a79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 19:58:35 +0300 Subject: use existing function for loading VAE weights from file --- modules/sd_vae.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9fcfd9db..0a49daa1 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,7 +3,7 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks +from modules import shared, devices, script_callbacks, sd_models from modules.paths import models_path import glob from copy import deepcopy @@ -172,13 +172,8 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - _, extension = os.path.splitext(vae_file) - if extension.lower() == ".safetensors": - vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) - else: - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - if "state_dict" in vae_ckpt: - vae_ckpt = vae_ckpt["state_dict"] + + vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) @@ -210,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) + def clear_loaded_vae(): global loaded_vae_file loaded_vae_file = None + def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack -- cgit v1.2.1 From cdfcbd995932ffa728db0cc00a5f97665c752103 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 20:08:48 +0300 Subject: Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code add some whitespace between functions to be in line with other code in the repo --- modules/sub_quadratic_attention.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 93381bae..55052815 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,14 +15,9 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math - -try: - from typing import Protocol -except: - from typing_extensions import Protocol - from typing import Optional, NamedTuple, List + def narrow_trunc( input: Tensor, dim: int, @@ -31,12 +26,14 @@ def narrow_trunc( ) -> Tensor: return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + class AttnChunk(NamedTuple): exp_values: Tensor exp_weights_sum: Tensor max_score: Tensor -class SummarizeChunk(Protocol): + +class SummarizeChunk: @staticmethod def __call__( query: Tensor, @@ -44,7 +41,8 @@ class SummarizeChunk(Protocol): value: Tensor, ) -> AttnChunk: ... -class ComputeQueryChunkAttn(Protocol): + +class ComputeQueryChunkAttn: @staticmethod def __call__( query: Tensor, @@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol): value: Tensor, ) -> Tensor: ... + def _summarize_chunk( query: Tensor, key: Tensor, @@ -72,6 +71,7 @@ def _summarize_chunk( max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, @@ -112,6 +112,7 @@ def _query_chunk_attention( all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights + # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, @@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking( hidden_states_slice = torch.bmm(attn_probs, value) return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk + def efficient_dot_product_attention( query: Tensor, key: Tensor, -- cgit v1.2.1