aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py205
-rw-r--r--modules/api/models.py24
-rw-r--r--modules/codeformer/vqgan_arch.py4
-rw-r--r--modules/deepbooru.py4
-rw-r--r--modules/devices.py35
-rw-r--r--modules/errors.py25
-rw-r--r--modules/extras.py42
-rw-r--r--modules/generation_parameters_copypaste.py137
-rw-r--r--modules/hypernetworks/hypernetwork.py61
-rw-r--r--modules/hypernetworks/ui.py31
-rw-r--r--modules/images.py74
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/import_hook.py5
-rw-r--r--modules/interrogate.py8
-rw-r--r--modules/lowvram.py12
-rw-r--r--modules/memmon.py3
-rw-r--r--modules/modelloader.py20
-rw-r--r--modules/ngrok.py4
-rw-r--r--modules/processing.py264
-rw-r--r--modules/safe.py63
-rw-r--r--modules/script_callbacks.py42
-rw-r--r--modules/scripts.py15
-rw-r--r--modules/sd_hijack.py56
-rw-r--r--modules/sd_hijack_clip.py371
-rw-r--r--modules/sd_hijack_clip_old.py81
-rw-r--r--modules/sd_hijack_inpainting.py241
-rw-r--r--modules/sd_hijack_optimizations.py135
-rw-r--r--modules/sd_hijack_unet.py2
-rw-r--r--modules/sd_hijack_xlmr.py34
-rw-r--r--modules/sd_models.py76
-rw-r--r--modules/sd_samplers.py86
-rw-r--r--modules/sd_vae.py33
-rw-r--r--modules/sd_vae_approx.py58
-rw-r--r--modules/shared.py70
-rw-r--r--modules/sub_quadratic_attention.py205
-rw-r--r--modules/textual_inversion/dataset.py10
-rw-r--r--modules/textual_inversion/learn_schedule.py11
-rw-r--r--modules/textual_inversion/logging.py24
-rw-r--r--modules/textual_inversion/preprocess.py1
-rw-r--r--modules/textual_inversion/textual_inversion.py149
-rw-r--r--modules/txt2img.py11
-rw-r--r--modules/ui.py719
-rw-r--r--modules/ui_components.py33
-rw-r--r--modules/ui_extensions.py59
-rw-r--r--modules/ui_tempdir.py24
-rw-r--r--modules/upscaler.py6
-rw-r--r--modules/xlmr.py137
47 files changed, 2561 insertions, 1151 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 14d0baaa..2103709b 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -1,22 +1,27 @@
import base64
import io
import time
+import datetime
import uvicorn
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
-from fastapi import APIRouter, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru
+from modules import sd_samplers, deepbooru, sd_hijack, images
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.extras import run_extras, run_pnginfo
+from modules.extras import run_extras
+from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
+from modules.textual_inversion.preprocess import preprocess
+from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list
+from modules.sd_models import checkpoints_list, find_checkpoint_config
from modules.realesrgan_model import get_realesrgan_models
+from modules import devices
from typing import List
def upscaler_to_index(name: str):
@@ -63,18 +68,40 @@ def encode_pil_to_base64(image):
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
+def api_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def log_and_time(req: Request, call_next):
+ ts = time.time()
+ res: Response = await call_next(req)
+ duration = str(round(time.time() - ts, 4))
+ res.headers["X-Process-Time"] = duration
+ endpoint = req.scope.get('path', 'err')
+ if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
+ print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
+ t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
+ code = res.status_code,
+ ver = req.scope.get('http_version', '0.0'),
+ cli = req.scope.get('client', ('0:0.0.0', 0))[0],
+ prot = req.scope.get('scheme', 'err'),
+ method = req.scope.get('method', 'err'),
+ endpoint = endpoint,
+ duration = duration,
+ ))
+ return res
+
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth:
- self.credenticals = dict()
+ self.credentials = dict()
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
- self.credenticals[user] = password
+ self.credentials[user] = password
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
+ api_middleware(self.app)
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
@@ -93,26 +120,31 @@ class Api:
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)
- def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
- if credenticals.username in self.credenticals:
- if compare_digest(credenticals.password, self.credenticals[credenticals.username]):
+ def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
+ if credentials.username in self.credentials:
+ if compare_digest(credentials.password, self.credentials[credentials.username]):
return True
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True
@@ -120,15 +152,14 @@ class Api:
)
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
- p = StableDiffusionProcessingTxt2Img(**vars(populate))
- # Override object param
-
- shared.state.begin()
with self.queue_lock:
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+
+ shared.state.begin()
processed = process_images(p)
+ shared.state.end()
- shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -144,7 +175,6 @@ class Api:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
- "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True,
"do_not_save_grid": True,
@@ -156,16 +186,14 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
- p = StableDiffusionProcessingImg2Img(**args)
-
- p.init_images = [decode_base64_to_image(x) for x in init_images]
-
- shared.state.begin()
with self.queue_lock:
- processed = process_images(p)
+ p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
- shared.state.end()
+ shared.state.begin()
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -181,7 +209,7 @@ class Api:
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
- result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
+ result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
@@ -197,7 +225,7 @@ class Api:
reqDict.pop('imageList')
with self.queue_lock:
- result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
+ result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
@@ -205,9 +233,17 @@ class Api:
if(not req.image.strip()):
return PNGInfoResponse(info="")
- result = run_pnginfo(decode_base64_to_image(req.image.strip()))
+ image = decode_base64_to_image(req.image.strip())
+ if image is None:
+ return PNGInfoResponse(info="")
+
+ geninfo, items = images.read_info_from_image(image)
+ if geninfo is None:
+ geninfo = ""
- return PNGInfoResponse(info=result[1])
+ items = {**{'parameters': geninfo}, **items}
+
+ return PNGInfoResponse(info=geninfo, items=items)
def progressapi(self, req: ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
@@ -240,7 +276,7 @@ class Api:
def interrogateapi(self, interrogatereq: InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
- raise HTTPException(status_code=404, detail="Image not found")
+ raise HTTPException(status_code=404, detail="Image not found")
img = decode_base64_to_image(image_b64)
img = img.convert('RGB')
@@ -253,7 +289,7 @@ class Api:
processed = deepbooru.model.tag(img)
else:
raise HTTPException(status_code=404, detail="Model not found")
-
+
return InterrogateResponse(caption=processed)
def interruptapi(self):
@@ -298,7 +334,7 @@ class Api:
return upscalers
def get_sd_models(self):
- return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
+ return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -309,7 +345,7 @@ class Api:
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
- def get_promp_styles(self):
+ def get_prompt_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
@@ -322,10 +358,113 @@ class Api:
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
-
+
+ def get_embeddings(self):
+ db = sd_hijack.model_hijack.embedding_db
+
+ def convert_embedding(embedding):
+ return {
+ "step": embedding.step,
+ "sd_checkpoint": embedding.sd_checkpoint,
+ "sd_checkpoint_name": embedding.sd_checkpoint_name,
+ "shape": embedding.shape,
+ "vectors": embedding.vectors,
+ }
+
+ def convert_embeddings(embeddings):
+ return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
+
+ return {
+ "loaded": convert_embeddings(db.word_embeddings),
+ "skipped": convert_embeddings(db.skipped_embeddings),
+ }
+
def refresh_checkpoints(self):
shared.refresh_checkpoints()
+ def create_embedding(self, args: dict):
+ try:
+ shared.state.begin()
+ filename = create_embedding(**args) # create empty embedding
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
+ shared.state.end()
+ return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+ except AssertionError as e:
+ shared.state.end()
+ return TrainResponse(info = "create embedding error: {error}".format(error = e))
+
+ def create_hypernetwork(self, args: dict):
+ try:
+ shared.state.begin()
+ filename = create_hypernetwork(**args) # create empty embedding
+ shared.state.end()
+ return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
+ except AssertionError as e:
+ shared.state.end()
+ return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
+
+ def preprocess(self, args: dict):
+ try:
+ shared.state.begin()
+ preprocess(**args) # quick operation unless blip/booru interrogation is enabled
+ shared.state.end()
+ return PreprocessResponse(info = 'preprocess complete')
+ except KeyError as e:
+ shared.state.end()
+ return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
+ except AssertionError as e:
+ shared.state.end()
+ return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
+ except FileNotFoundError as e:
+ shared.state.end()
+ return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
+
+ def train_embedding(self, args: dict):
+ try:
+ shared.state.begin()
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ error = None
+ filename = ''
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ try:
+ embedding, filename = train_embedding(**args) # can take a long time to complete
+ except Exception as e:
+ error = e
+ finally:
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+ shared.state.end()
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ except AssertionError as msg:
+ shared.state.end()
+ return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+
+ def train_hypernetwork(self, args: dict):
+ try:
+ shared.state.begin()
+ initial_hypernetwork = shared.loaded_hypernetwork
+ apply_optimizations = shared.opts.training_xattention_optimizations
+ error = None
+ filename = ''
+ if not apply_optimizations:
+ sd_hijack.undo_optimizations()
+ try:
+ hypernetwork, filename = train_hypernetwork(*args)
+ except Exception as e:
+ error = e
+ finally:
+ shared.loaded_hypernetwork = initial_hypernetwork
+ shared.sd_model.cond_stage_model.to(devices.device)
+ shared.sd_model.first_stage_model.to(devices.device)
+ if not apply_optimizations:
+ sd_hijack.apply_optimizations()
+ shared.state.end()
+ return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+ except AssertionError as msg:
+ shared.state.end()
+ return TrainResponse(info = "train embedding error: {error}".format(error = error))
+
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
diff --git a/modules/api/models.py b/modules/api/models.py
index 22b88c59..5fa63774 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -128,7 +128,7 @@ class ExtrasBaseRequest(BaseModel):
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
- upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
+ upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
@@ -157,7 +157,8 @@ class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
- info: str = Field(title="Image info", description="A string with all the info the image had")
+ info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
+ items: dict = Field(title="Items", description="An object containing all the info the image had")
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
@@ -175,6 +176,15 @@ class InterrogateRequest(BaseModel):
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
+class TrainResponse(BaseModel):
+ info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
+
+class CreateResponse(BaseModel):
+ info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
+
+class PreprocessResponse(BaseModel):
+ info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
+
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)
@@ -240,3 +250,13 @@ class ArtistItem(BaseModel):
score: float = Field(title="Score")
category: str = Field(title="Category")
+class EmbeddingItem(BaseModel):
+ step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
+ sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
+ sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
+ shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
+ vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
+
+class EmbeddingsResponse(BaseModel):
+ loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+ skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index c06c590c..e7293683 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
- raise ValueError(f'Wrong params!')
+ raise ValueError('Wrong params!')
def forward(self, x):
@@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
- raise ValueError(f'Wrong params!')
+ raise ValueError('Wrong params!')
def forward(self, x):
return self.main(x) \ No newline at end of file
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index dfc83357..122fce7f 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -79,7 +79,9 @@ class DeepDanbooru:
res = []
- for tag in tags:
+ filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+
+ for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
tag_outformat = tag
if use_spaces:
diff --git a/modules/devices.py b/modules/devices.py
index f8cffae1..caeb0276 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -125,7 +125,34 @@ def layer_norm_fix(*args, **kwargs):
return orig_layer_norm(*args, **kwargs)
-# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
+# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+orig_tensor_numpy = torch.Tensor.numpy
+def numpy_fix(self, *args, **kwargs):
+ if self.requires_grad:
+ self = self.detach()
+ return orig_tensor_numpy(self, *args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+orig_cumsum = torch.cumsum
+orig_Tensor_cumsum = torch.Tensor.cumsum
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+ if input.device.type == 'mps':
+ output_dtype = kwargs.get('dtype', input.dtype)
+ if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+ return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps():
+ if version.parse(torch.__version__) < version.parse("1.13"):
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+ torch.Tensor.to = tensor_to_fix
+ torch.nn.functional.layer_norm = layer_norm_fix
+ torch.Tensor.numpy = numpy_fix
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
+ if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
+ torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+ torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+ orig_narrow = torch.narrow
+ torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
import traceback
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
def run(code, task):
try:
code()
except Exception as e:
- print(f"{task}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ display(task, e)
diff --git a/modules/extras.py b/modules/extras.py
index 0ad8deec..7407bfe3 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -19,8 +19,6 @@ from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
-import piexif
-import piexif.helper
import gradio as gr
import safetensors.torch
@@ -55,9 +53,12 @@ class LruCache(OrderedDict):
cached_images: LruCache = LruCache(max_size=5)
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
+ shared.state.begin()
+ shared.state.job = 'extras'
+
imageArr = []
# Also keep track of original file names
imageNameArr = []
@@ -94,6 +95,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Extra operation definitions
def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-gfpgan'
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
@@ -104,6 +106,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
+ shared.state.job = 'extras-codeformer'
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
@@ -114,6 +117,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return (res, info)
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
+ shared.state.job = 'extras-upscale'
upscaler = shared.sd_upscalers[scaler_index]
res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
@@ -180,6 +184,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
+
+ shared.state.textinfo = f'Processing image {image_name}'
+
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -188,18 +195,25 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for op in extras_ops:
image, info = op(image, info)
- if opts.use_original_name_batch and image_name != None:
+ if opts.use_original_name_batch and image_name is not None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
basename = ''
- images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
- no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
-
- if opts.enable_pnginfo:
+ if opts.enable_pnginfo: # append info before save
image.info = existing_pnginfo
image.info["extras"] = info
+ if save_output:
+ # Add upscaler name as a suffix.
+ suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
+ # Add second upscaler if applicable.
+ if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
+ suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
+
+ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+ no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
+
if extras_mode != 2 or show_extras_results :
outputs.append(image)
@@ -235,6 +249,9 @@ def run_pnginfo(image):
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+ shared.state.begin()
+ shared.state.job = 'model-merge'
+
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -256,8 +273,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_func1, theta_func2 = theta_funcs[interp_method]
if theta_func1 and not tertiary_model_info:
+ shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
+ shared.state.end()
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+ shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
@@ -274,6 +294,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2
+ shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -284,6 +305,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
a = theta_0[key]
b = theta_1[key]
+ shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -323,6 +345,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename)
+ shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -334,4 +357,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
sd_models.list_models()
print("Checkpoint saved.")
+ shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ shared.state.end()
+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 565e342d..12a9de3d 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,12 +1,13 @@
import base64
import io
+import math
import os
import re
from pathlib import Path
import gradio as gr
from modules.shared import script_path
-from modules import shared
+from modules import shared, ui_tempdir
import tempfile
from PIL import Image
@@ -14,6 +15,7 @@ re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
+re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
@@ -35,9 +37,12 @@ def quote(text):
def image_from_url_text(filedata):
- if type(filedata) == dict and filedata["is_file"]:
+ if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ filedata = filedata[0]
+
+ if type(filedata) == dict and filedata.get("is_file", False):
filename = filedata["name"]
- is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
+ is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
return Image.open(filename)
@@ -92,7 +97,7 @@ def integrate_settings_paste_fields(component_dict):
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
- buttons[tab] = gr.Button(f"Send to {tab}")
+ buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
return buttons
@@ -101,35 +106,57 @@ def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
+def send_image_and_dimensions(x):
+ if isinstance(x, Image.Image):
+ img = x
+ else:
+ img = image_from_url_text(x)
+
+ if shared.opts.send_size and isinstance(img, Image.Image):
+ w = img.width
+ h = img.height
+ else:
+ w = gr.update()
+ h = gr.update()
+
+ return img, w, h
+
+
def run_bind():
- for buttons, send_image, send_generate_info in bind_list:
+ for buttons, source_image_component, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
- if send_image and paste_fields[tab]["init_img"]:
- if type(send_image) == gr.Gallery:
- button.click(
- fn=lambda x: image_from_url_text(x),
- _js="extract_image_from_gallery",
- inputs=[send_image],
- outputs=[paste_fields[tab]["init_img"]],
- )
+ destination_image_component = paste_fields[tab]["init_img"]
+ fields = paste_fields[tab]["fields"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if source_image_component and destination_image_component:
+ if isinstance(source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
else:
- button.click(
- fn=lambda x: x,
- inputs=[send_image],
- outputs=[paste_fields[tab]["init_img"]],
- )
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
- if send_generate_info and paste_fields[tab]["fields"] is not None:
+ if send_generate_info and fields is not None:
if send_generate_info in paste_fields:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
)
else:
- connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
+ connect_paste(button, fields, send_generate_info)
button.click(
fn=None,
@@ -139,6 +166,58 @@ def run_bind():
)
+def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
+ """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
+
+ Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
+ parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
+
+ If the infotext has no hash, then a hypernet with the same name will be selected instead.
+ """
+ hypernet_name = hypernet_name.lower()
+ if hypernet_hash is not None:
+ # Try to match the hash in the name
+ for hypernet_key in shared.hypernetworks.keys():
+ result = re_hypernet_hash.search(hypernet_key)
+ if result is not None and result[1] == hypernet_hash:
+ return hypernet_key
+ else:
+ # Fall back to a hypernet with the same name
+ for hypernet_key in shared.hypernetworks.keys():
+ if hypernet_key.lower().startswith(hypernet_name):
+ return hypernet_key
+
+ return None
+
+
+def restore_old_hires_fix_params(res):
+ """for infotexts that specify old First pass size parameter, convert it into
+ width, height, and hr scale"""
+
+ firstpass_width = res.get('First pass size-1', None)
+ firstpass_height = res.get('First pass size-2', None)
+
+ if firstpass_width is None or firstpass_height is None:
+ return
+
+ firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
+ width = int(res.get("Size-1", 512))
+ height = int(res.get("Size-2", 512))
+
+ if firstpass_width == 0 or firstpass_height == 0:
+ # old algorithm for auto-calculating first pass size
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ firstpass_width = math.ceil(scale * width / 64) * 64
+ firstpass_height = math.ceil(scale * height / 64) * 64
+
+ res['Size-1'] = firstpass_width
+ res['Size-2'] = firstpass_height
+ res['Hires resize-1'] = width
+ res['Hires resize-2'] = height
+
+
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -188,6 +267,20 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Clip skip" not in res:
res["Clip skip"] = "1"
+ if "Hypernet strength" not in res:
+ res["Hypernet strength"] = "1"
+
+ if "Hypernet" in res:
+ hypernet_name = res["Hypernet"]
+ hypernet_hash = res.get("Hypernet hash", None)
+ res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+
+ if "Hires resize-1" not in res:
+ res["Hires resize-1"] = 0
+ res["Hires resize-2"] = 0
+
+ restore_old_hires_fix_params(res)
+
return res
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c406ffb3..b0cfbe71 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -13,7 +13,7 @@ import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers
-from modules.textual_inversion import textual_inversion
+from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
@@ -277,7 +277,7 @@ def load_hypernetwork(filename):
print(traceback.format_exc(), file=sys.stderr)
else:
if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
+ print("Unloading hypernetwork")
shared.loaded_hypernetwork = None
@@ -378,8 +378,32 @@ def report_statistics(loss_info:dict):
print(e)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ weight_init=weight_init,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ )
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -391,6 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
+ shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps
@@ -417,10 +442,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
initial_step = hypernetwork.step or 0
if initial_step >= steps:
- shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -428,7 +457,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
-
+
+ if shared.opts.save_training_settings_to_txt:
+ saved_params = dict(
+ model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds),
+ **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
+ )
+ logging.save_settings_to_file(log_directory, {**saved_params, **locals()})
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
@@ -439,7 +475,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
-
+
weights = hypernetwork.weights()
hypernetwork.train_mode()
@@ -498,6 +534,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
if shared.state.interrupted:
break
+ if clip_grad:
+ clip_grad_sched.step(hypernetwork.step)
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
@@ -512,14 +551,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
_loss_step += loss.item()
scaler.scale(loss).backward()
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
- # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
- # scaler.unscale_(optimizer)
- # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
- # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
- # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
+
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
+
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index c2d4b51c..e7f9e593 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -3,39 +3,16 @@ import os
import re
import gradio as gr
-import modules.textual_inversion.preprocess
-import modules.textual_inversion.textual_inversion
+import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
-from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
- # Remove illegal characters from name.
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- if not overwrite_old:
- assert not os.path.exists(fn), f"file {fn} already exists"
-
- if type(layer_structure) == str:
- layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
-
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
- name=name,
- enable_sizes=[int(x) for x in enable_sizes],
- layer_structure=layer_structure,
- activation_func=activation_func,
- weight_init=weight_init,
- add_layer_norm=add_layer_norm,
- use_dropout=use_dropout,
- )
- hypernet.save(fn)
-
- shared.reload_hypernetworks()
-
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py
index 8146f580..c3a5fc8b 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -39,11 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None):
cols = math.ceil(len(imgs) / rows)
+ params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
+ script_callbacks.image_grid_callback(params)
+
w, h = imgs[0].size
- grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+ grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black')
- for i, img in enumerate(imgs):
- grid.paste(img, box=(i % cols * w, i // cols * h))
+ for i, img in enumerate(params.imgs):
+ grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
return grid
@@ -136,8 +139,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
lines.append(word)
return lines
- def draw_texts(drawing, draw_x, draw_y, lines):
+ def get_font(fontsize):
+ try:
+ return ImageFont.truetype(opts.font or Roboto, fontsize)
+ except Exception:
+ return ImageFont.truetype(Roboto, fontsize)
+
+ def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
for i, line in enumerate(lines):
+ fnt = initial_fnt
+ fontsize = initial_fontsize
+ while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
+ fontsize -= 1
+ fnt = get_font(fontsize)
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
@@ -148,10 +162,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
fontsize = (width + height) // 25
line_spacing = fontsize // 2
- try:
- fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
- except Exception:
- fnt = ImageFont.truetype(Roboto, fontsize)
+ fnt = get_font(fontsize)
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
@@ -178,6 +189,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
for line in texts:
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+ line.allowed_width = allowed_width
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
@@ -194,13 +206,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_heights[col] / 2
- draw_texts(d, x, y, hor_texts[col])
+ draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
for row in range(rows):
x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
- draw_texts(d, x, y, ver_texts[row])
+ draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
return result
@@ -218,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
-def resize_image(resize_mode, im, width, height):
+def resize_image(resize_mode, im, width, height, upscaler_name=None):
+ """
+ Resizes an image with the specified resize_mode, width, and height.
+
+ Args:
+ resize_mode: The mode to use when resizing the image.
+ 0: Resize the image to the specified width and height.
+ 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+ 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+ im: The image to resize.
+ width: The width to resize the image to.
+ height: The height to resize the image to.
+ upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.
+ """
+
+ upscaler_name = upscaler_name or opts.upscaler_for_img2img
+
def resize(im, w, h):
- if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
+ if upscaler_name is None or upscaler_name == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
scale = max(w / im.width, h / im.height)
if scale > 1.0:
- upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
- assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
+ upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
+ assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
@@ -429,7 +457,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
basename (`str`):
The base filename which will be applied to `filename pattern`.
- seed, prompt, short_filename,
+ seed, prompt, short_filename,
extension (`str`):
Image file extension, default is `png`.
pngsectionname (`str`):
@@ -516,6 +544,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ if image_to_save.mode == 'RGBA':
+ image_to_save = image_to_save.convert("RGB")
+
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
if opts.enable_pnginfo and info is not None:
@@ -590,7 +621,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr)
+ print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return geninfo, items
@@ -613,3 +644,14 @@ def image_data(data):
pass
return '', None
+
+
+def flatten(img, bgcolor):
+ """replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
+
+ if img.mode == "RGBA":
+ background = Image.new('RGBA', img.size, bgcolor)
+ background.paste(img, mask=img)
+ img = background
+
+ return img.convert('RGB')
diff --git a/modules/img2img.py b/modules/img2img.py
index 81da4b13..ca58b5d8 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/import_hook.py b/modules/import_hook.py
new file mode 100644
index 00000000..28c67dfa
--- /dev/null
+++ b/modules/import_hook.py
@@ -0,0 +1,5 @@
+import sys
+
+# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
+if "--xformers" not in "".join(sys.argv):
+ sys.modules["xformers"] = None
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 0068b81c..738d8ff7 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -135,8 +135,9 @@ class InterrogateModels:
return caption[0]
def interrogate(self, pil_image):
- res = None
-
+ res = ""
+ shared.state.begin()
+ shared.state.job = 'interrogate'
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -172,10 +173,11 @@ class InterrogateModels:
res += ", " + match
except Exception:
- print(f"Error interrogating", file=sys.stderr)
+ print("Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
res += "<error>"
self.unload()
+ shared.state.end()
return res
diff --git a/modules/lowvram.py b/modules/lowvram.py
index aa464a95..042a0254 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -55,18 +55,20 @@ def setup_for_low_vram(sd_model, use_medvram):
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
- # remove three big modules, cond, first_stage, and unet from the model and then
+ # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
- stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
sd_model.to(devices.device)
- sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
- # register hooks for those the first two models
+ # register hooks for those the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
+ if sd_model.depth_model:
+ sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if hasattr(sd_model.cond_stage_model, 'model'):
diff --git a/modules/memmon.py b/modules/memmon.py
index 9fb9b687..a7060f58 100644
--- a/modules/memmon.py
+++ b/modules/memmon.py
@@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread):
def read(self):
if not self.disabled:
free, total = torch.cuda.mem_get_info()
+ self.data["free"] = free
self.data["total"] = total
torch_stats = torch.cuda.memory_stats(self.device)
+ self.data["active"] = torch_stats["active.all.current"]
self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+ self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
self.data["system_peak"] = total - self.data["min_free"]
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e647f6fa..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 64c9a3c2..3df2c06b 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -2,7 +2,7 @@ from pyngrok import ngrok, conf, exception
def connect(token, port, region):
account = None
- if token == None:
+ if token is None:
token = 'None'
else:
if ':' in token:
@@ -14,7 +14,7 @@ def connect(token, port, region):
auth_token=token, region=region
)
try:
- if account == None:
+ if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
else:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
diff --git a/modules/processing.py b/modules/processing.py
index 24c537d1..82157bc9 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -27,6 +27,7 @@ from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
+from blendmodes.blend import blendLayers, BlendType
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -39,17 +40,19 @@ def setup_color_correction(image):
return correction_target
-def apply_color_correction(correction, image):
+def apply_color_correction(correction, original_image):
logging.info("Applying color correction.")
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
- np.asarray(image),
+ np.asarray(original_image),
cv2.COLOR_RGB2LAB
),
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
+ image = blendLayers(image, original_image, BlendType.LUMINOSITY)
+
return image
@@ -73,11 +76,29 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -118,6 +139,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
+ self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
if not seed_enable_extras:
@@ -132,28 +154,12 @@ class StableDiffusionProcessing():
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
+ self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
@@ -199,7 +205,7 @@ class StableDiffusionProcessing():
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
-
+
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
@@ -235,7 +241,7 @@ class StableDiffusionProcessing():
class Processed:
- def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
self.images = images_list
self.prompt = p.prompt
self.negative_prompt = p.negative_prompt
@@ -243,6 +249,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
+ self.comments = comments
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name
@@ -314,7 +321,7 @@ class Processed:
return json.dumps(obj)
- def infotext(self, p: StableDiffusionProcessing, index):
+ def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
@@ -334,13 +341,14 @@ def slerp(val, low, high):
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
+ eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
xs = []
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
+ if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
@@ -380,8 +388,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
- if opts.eta_noise_seed_delta > 0:
- torch.manual_seed(seed + opts.eta_noise_seed_delta)
+ if eta_noise_seed_delta > 0:
+ torch.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -414,7 +422,7 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed)
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@@ -429,6 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
+ "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
@@ -446,7 +455,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
+ negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -457,18 +466,26 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
- if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ if k == 'sd_hypernetwork':
+ shared.reload_hypernetworks() # make onchange call for changing hypernet
+
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights() # make onchange call for changing SD model
+ p.sd_model = shared.sd_model
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p)
- finally: # restore opts to original state
- for k, v in stored_opts.items():
- setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ finally:
+ # restore opts to original state
+ if p.override_settings_restore_afterwards:
+ for k, v in stored_opts.items():
+ setattr(opts, k, v)
+ if k == 'sd_hypernetwork': shared.reload_hypernetworks()
+ if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
+ if k == 'sd_vae': sd_vae.reload_vae_weights()
return res
@@ -527,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
+ def get_conds_with_caching(function, required_prompts, steps, cache):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+ """
+
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ return cache[1]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps)
+ return cache[1]
+
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -535,9 +575,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
+ p.iteration = n
+
if state.skipped:
state.skipped = False
-
+
if state.interrupted:
break
@@ -552,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
- with devices.autocast():
- uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps)
- c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
+ uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
+ c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -612,7 +653,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
- del x_samples_ddim
+ del x_samples_ddim
devices.torch_gc()
@@ -638,7 +679,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
- res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
+ res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
@@ -649,78 +690,115 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
- self.firstphase_width = firstphase_width
- self.firstphase_height = firstphase_height
+ self.hr_scale = hr_scale
+ self.hr_upscaler = hr_upscaler
+ self.hr_second_pass_steps = hr_second_pass_steps
+ self.hr_resize_x = hr_resize_x
+ self.hr_resize_y = hr_resize_y
+ self.hr_upscale_to_x = hr_resize_x
+ self.hr_upscale_to_y = hr_resize_y
+
+ if firstphase_width != 0 or firstphase_height != 0:
+ print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
+ self.hr_scale = self.width / firstphase_width
+ self.width = firstphase_width
+ self.height = firstphase_height
+
self.truncate_x = 0
self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if state.job_count == -1:
- state.job_count = self.n_iter * 2
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
else:
- state.job_count = state.job_count * 2
-
- self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
+ self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
+
+ if self.hr_resize_y == 0:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ elif self.hr_resize_x == 0:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+ else:
+ target_w = self.hr_resize_x
+ target_h = self.hr_resize_y
+ src_ratio = self.width / self.height
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
+
+ if src_ratio < dst_ratio:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ else:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+
+ # special case: the user has chosen to do nothing
+ if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
+ self.enable_hr = False
+ self.denoising_strength = None
+ self.extra_generation_params.pop("Hires upscale", None)
+ self.extra_generation_params.pop("Hires resize", None)
+ return
- if self.firstphase_width == 0 or self.firstphase_height == 0:
- desired_pixel_count = 512 * 512
- actual_pixel_count = self.width * self.height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- self.firstphase_width = math.ceil(scale * self.width / 64) * 64
- self.firstphase_height = math.ceil(scale * self.height / 64) * 64
- firstphase_width_truncated = int(scale * self.width)
- firstphase_height_truncated = int(scale * self.height)
+ if not state.processing_has_refined_job_count:
+ if state.job_count == -1:
+ state.job_count = self.n_iter
- else:
-
- width_ratio = self.width / self.firstphase_width
- height_ratio = self.height / self.firstphase_height
+ shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
+ state.job_count = state.job_count * 2
+ state.processing_has_refined_job_count = True
- if width_ratio > height_ratio:
- firstphase_width_truncated = self.firstphase_width
- firstphase_height_truncated = self.firstphase_width * self.height / self.width
- else:
- firstphase_width_truncated = self.firstphase_height * self.width / self.height
- firstphase_height_truncated = self.firstphase_height
+ if self.hr_second_pass_steps:
+ self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
- self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
- self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
+ if self.hr_upscaler is not None:
+ self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and latent_scale_mode is None:
+ assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+
+ x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+
if not self.enable_hr:
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
- x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
-
- samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ target_width = self.hr_upscale_to_x
+ target_height = self.hr_upscale_to_y
- """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index):
+ """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
+
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
- image = sd_samplers.sample_to_image(image, index)
+ image = sd_samplers.sample_to_image(image, index, approximation=0)
- images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
- if opts.use_scale_latent_for_hires_fix:
+ if latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
- # Avoid making the inpainting conditioning unless necessary as
+ # Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
@@ -738,7 +816,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
save_intermediate(image, i)
- image = images.resize_image(0, image, self.width, self.height)
+ image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
@@ -755,13 +833,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
+
+ noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
+ samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
return samples
@@ -829,9 +909,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
- image = img.convert("RGB")
+ image = images.flatten(img, opts.img2img_background_color)
- if crop_region is None:
+ if crop_region is None and self.resize_mode != 3:
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
@@ -840,6 +920,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.overlay_images.append(image_masked.convert('RGBA'))
+ # crop_region is not None if we are doing inpaint full res
if crop_region is not None:
image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height)
@@ -876,6 +957,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ if self.resize_mode == 3:
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
if image_mask is not None:
init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
diff --git a/modules/safe.py b/modules/safe.py
index 10460ad0..82d44be3 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler):
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
- if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
- if module == 'numpy.core.multiarray' and name == 'scalar':
- return numpy.core.multiarray.scalar
- if module == 'numpy' and name == 'dtype':
- return numpy.dtype
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+ return getattr(numpy.core.multiarray, name)
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
+ return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
@@ -80,7 +80,7 @@ def check_pt(filename, extra_handler):
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
-
+
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
@@ -103,12 +103,12 @@ def check_pt(filename, extra_handler):
def load(filename, *args, **kwargs):
- return load_with_extra(filename, *args, **kwargs)
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
"""
- this functon is intended to be used by extensions that want to load models with
+ this function is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
@@ -137,19 +137,56 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
+class Extra:
+ """
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+ (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+ if module == 'torch' and name in ['float64', 'float16']:
+ return getattr(torch, name)
+
+ return None
+
+with safe.Extra(handler):
+ x = torch.load('model.pt')
+```
+ """
+
+ def __init__(self, handler):
+ self.handler = handler
+
+ def __enter__(self):
+ global global_extra_handler
+
+ assert global_extra_handler is None, 'already inside an Extra() block'
+ global_extra_handler = self.handler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ global global_extra_handler
+
+ global_extra_handler = None
+
+
unsafe_torch_load = torch.load
torch.load = load
+global_extra_handler = None
+
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 8e22f875..608c5300 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -51,6 +51,13 @@ class UiTrainTabParams:
self.txt2img_preview_params = txt2img_preview_params
+class ImageGridLoopParams:
+ def __init__(self, imgs, cols, rows):
+ self.imgs = imgs
+ self.cols = cols
+ self.rows = rows
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
callbacks_app_started=[],
@@ -63,6 +70,8 @@ callback_map = dict(
callbacks_cfg_denoiser=[],
callbacks_before_component=[],
callbacks_after_component=[],
+ callbacks_image_grid=[],
+ callbacks_script_unloaded=[],
)
@@ -155,6 +164,22 @@ def after_component_callback(component, **kwargs):
report_exception(c, 'after_component_callback')
+def image_grid_callback(params: ImageGridLoopParams):
+ for c in callback_map['callbacks_image_grid']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'image_grid')
+
+
+def script_unloaded_callback():
+ for c in reversed(callback_map['callbacks_script_unloaded']):
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'script_unloaded')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -186,7 +211,7 @@ def on_app_started(callback):
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
- passed as an argument"""
+ passed as an argument; this function is also called when the script is reloaded. """
add_callback(callback_map['callbacks_model_loaded'], callback)
@@ -255,3 +280,18 @@ def on_before_component(callback):
def on_after_component(callback):
"""register a function to be called after a component is created. See on_before_component for more."""
add_callback(callback_map['callbacks_after_component'], callback)
+
+
+def on_image_grid(callback):
+ """register a function to be called before making an image grid.
+ The callback is called with one argument:
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
+ """
+ add_callback(callback_map['callbacks_image_grid'], callback)
+
+
+def on_script_unloaded(callback):
+ """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
+ the script did should be reverted here"""
+
+ add_callback(callback_map['callbacks_script_unloaded'], callback)
diff --git a/modules/scripts.py b/modules/scripts.py
index 23ca195d..35164093 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,4 +1,5 @@
import os
+import re
import sys
import traceback
from collections import namedtuple
@@ -36,7 +37,7 @@ class Script:
def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
- Values of those returned componenbts will be passed to run() and process() functions.
+ Values of those returned components will be passed to run() and process() functions.
"""
pass
@@ -47,7 +48,7 @@ class Script:
This function should return:
- False if the script should not be shown in UI at all
- - True if the script should be shown in UI if it's scelected in the scripts drowpdown
+ - True if the script should be shown in UI if it's selected in the scripts dropdown
- script.AlwaysVisible if the script should be shown in UI at all times
"""
@@ -128,6 +129,15 @@ class Script:
"""unused"""
return ""
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
+
+ need_tabname = self.show(True) == self.show(False)
+ tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
+
+ return f'script_{tabname}{title}_{item_id}'
+
current_basedir = paths.script_path
@@ -280,7 +290,6 @@ class ScriptRunner:
script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 690a9ec2..cfdb09d6 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -5,9 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
-
-from modules.sd_hijack_optimizations import invokeAI_mps_available
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
@@ -35,26 +33,34 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
+ optimization_method = None
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+ optimization_method = 'xformers'
+ elif cmd_opts.opt_sub_quad_attention:
+ print("Applying sub-quadratic cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+ optimization_method = 'sub-quadratic'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
- if not invokeAI_mps_available and shared.device.type == 'mps':
- print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- else:
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'V1'
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
+ optimization_method = 'Doggettx'
+
+ return optimization_method
def undo_optimizations():
@@ -68,27 +74,37 @@ def fix_checkpoint():
ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
+
class StableDiffusionModelHijack:
fixes = None
comments = []
layers = None
circular_enabled = False
clip = None
+ optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
- if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ model_embeddings = m.cond_stage_model.roberta.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+ m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
+
+ elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
- self.clip = m.cond_stage_model
+ self.optimization_method = apply_optimizations()
- apply_optimizations()
+ self.clip = m.cond_stage_model
+
fix_checkpoint()
def flatten(el):
@@ -101,7 +117,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
- if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
+
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ m.cond_stage_model = m.cond_stage_model.wrapped
+
+ elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -127,10 +147,10 @@ class StableDiffusionModelHijack:
def clear_comments(self):
self.comments = []
- def tokenize(self, text):
- _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
- return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+ def get_prompt_lengths(self, text):
+ _, token_count = self.clip.process_texts([text])
+ return token_count, self.clip.get_target_prompt_token_count(token_count)
class EmbeddingsWithFixes(torch.nn.Module):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index b451d1cf..5520c9b2 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -1,31 +1,89 @@
import math
+from collections import namedtuple
import torch
-from modules import prompt_parser, devices
+from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts
-def get_target_prompt_token_count(token_count):
- return math.ceil(max(token_count, 1) / 75) * 75
+class PromptChunk:
+ """
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
+ Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
+ so just 75 tokens from prompt.
+ """
+
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
+chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+ """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+ have unlimited prompt length and assign weights to tokens in prompt.
+ """
+
def __init__(self, wrapped, hijack):
super().__init__()
+
self.wrapped = wrapped
- self.hijack = hijack
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+ depending on model."""
+
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+
raise NotImplementedError
def encode_with_transformers(self, tokens):
+ """
+ converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
+ All python lists with tokens are assumed to have same length, usually 77.
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
+ model - can be 768 and 1024.
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
+ """
+
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+ transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
+
raise NotImplementedError
- def tokenize_line(self, line, used_custom_terms, hijack_comments):
+ def tokenize_line(self, line):
+ """
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
@@ -33,205 +91,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
tokenized = self.tokenize([text for text, _ in parsed])
- fixes = []
- remade_tokens = []
- multipliers = []
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
last_comma = -1
- for tokens, (text, weight) in zip(tokenized, parsed):
- i = 0
- while i < len(tokens):
- token = tokens[i]
+ def next_chunk():
+ """puts current chunk into the list of results and produces the next one - empty"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
+
+ token_count += len(chunk.tokens)
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
if token == self.comma_token:
- last_comma = len(remade_tokens)
- elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
- last_comma += 1
- reloc_tokens = remade_tokens[last_comma:]
- reloc_mults = multipliers[last_comma:]
+ last_comma = len(chunk.tokens)
+
+ # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
+
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
- remade_tokens = remade_tokens[:last_comma]
- length = len(remade_tokens)
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
- rem = int(math.ceil(length / 75)) * 75 - length
- remade_tokens += [self.id_end] * rem + reloc_tokens
- multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
if embedding is None:
- remade_tokens.append(token)
- multipliers.append(weight)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- iteration = len(remade_tokens) // 75
- if (len(remade_tokens) + emb_len) // 75 != iteration:
- rem = (75 * (iteration + 1) - len(remade_tokens))
- remade_tokens += [self.id_end] * rem
- multipliers += [1.0] * rem
- iteration += 1
- fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
- remade_tokens += [0] * emb_len
- multipliers += [weight] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- token_count = len(remade_tokens)
- prompt_target_length = get_target_prompt_token_count(token_count)
- tokens_to_add = prompt_target_length - len(remade_tokens)
-
- remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
- multipliers = multipliers + [1.0] * tokens_to_add
-
- return remade_tokens, fixes, multipliers, token_count
-
- def process_text(self, texts):
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
+ next_chunk()
+
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+ length, in tokens, of all texts.
+ """
+
token_count = 0
cache = {}
- batch_multipliers = []
+ batch_chunks = []
for line in texts:
if line in cache:
- remade_tokens, fixes, multipliers = cache[line]
+ chunks = cache[line]
else:
- remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+ chunks, current_token_count = self.tokenize_line(line)
token_count = max(current_token_count, token_count)
- cache[line] = (remade_tokens, fixes, multipliers)
+ cache[line] = chunks
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
+ batch_chunks.append(chunks)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+ return batch_chunks, token_count
- def process_text_old(self, texts):
- id_start = self.id_start
- id_end = self.id_end
- maxlen = self.wrapped.max_length # you get to stay at 77
- used_custom_terms = []
- remade_batch_tokens = []
- hijack_comments = []
- hijack_fixes = []
- token_count = 0
+ def forward(self, texts):
+ """
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ An example shape returned by this function can be: (2, 77, 768).
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
- cache = {}
- batch_tokens = self.tokenize(texts)
- batch_multipliers = []
- for tokens in batch_tokens:
- tuple_tokens = tuple(tokens)
+ if opts.use_old_emphasis_implementation:
+ import modules.sd_hijack_clip_old
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
- if tuple_tokens in cache:
- remade_tokens, fixes, multipliers = cache[tuple_tokens]
- else:
- fixes = []
- remade_tokens = []
- multipliers = []
- mult = 1.0
-
- i = 0
- while i < len(tokens):
- token = tokens[i]
-
- embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
-
- mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
- if mult_change is not None:
- mult *= mult_change
- i += 1
- elif embedding is None:
- remade_tokens.append(token)
- multipliers.append(mult)
- i += 1
- else:
- emb_len = int(embedding.vec.shape[0])
- fixes.append((len(remade_tokens), embedding))
- remade_tokens += [0] * emb_len
- multipliers += [mult] * emb_len
- used_custom_terms.append((embedding.name, embedding.checksum()))
- i += embedding_length_in_tokens
-
- if len(remade_tokens) > maxlen - 2:
- vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
- ovf = remade_tokens[maxlen - 2:]
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
- hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
- token_count = len(remade_tokens)
- remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
- remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
-
- multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
- multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
-
- remade_batch_tokens.append(remade_tokens)
- hijack_fixes.append(fixes)
- batch_multipliers.append(multipliers)
- return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
-
- def forward(self, text):
- use_old = opts.use_old_emphasis_implementation
- if use_old:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
- else:
- batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
-
- self.hijack.comments += hijack_comments
-
- if len(used_custom_terms) > 0:
- self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
-
- if use_old:
- self.hijack.fixes = hijack_fixes
- return self.process_tokens(remade_batch_tokens, batch_multipliers)
-
- z = None
- i = 0
- while max(map(len, remade_batch_tokens)) != 0:
- rem_tokens = [x[75:] for x in remade_batch_tokens]
- rem_multipliers = [x[75:] for x in batch_multipliers]
-
- self.hijack.fixes = []
- for unfiltered in hijack_fixes:
- fixes = []
- for fix in unfiltered:
- if fix[0] == i:
- fixes.append(fix[1])
- self.hijack.fixes.append(fixes)
-
- tokens = []
- multipliers = []
- for j in range(len(remade_batch_tokens)):
- if len(remade_batch_tokens[j]) > 0:
- tokens.append(remade_batch_tokens[j][:75])
- multipliers.append(batch_multipliers[j][:75])
- else:
- tokens.append([self.id_end] * 75)
- multipliers.append([1.0] * 75)
-
- z1 = self.process_tokens(tokens, multipliers)
- z = z1 if z is None else torch.cat((z, z1), axis=-2)
-
- remade_batch_tokens = rem_tokens
- batch_multipliers = rem_multipliers
- i += 1
+ batch_chunks, token_count = self.process_texts(texts)
- return z
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
- def process_tokens(self, remade_batch_tokens, batch_multipliers):
- if not opts.use_old_emphasis_implementation:
- remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
- batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+
+ for fixes in self.hijack.fixes:
+ for position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+ if len(used_embeddings) > 0:
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+ sends one single prompt chunk to be encoded by transformers neural network.
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+ corresponds to one token.
+ """
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
@@ -240,8 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
- batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
- batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
@@ -254,10 +258,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
- self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
+ vocab = self.tokenizer.get_vocab()
+
+ self.comma_token = vocab.get(',</w>', None)
self.token_mults = {}
- tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
@@ -296,6 +303,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
return embedded
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
new file mode 100644
index 00000000..6d9fbbe6
--- /dev/null
+++ b/modules/sd_hijack_clip_old.py
@@ -0,0 +1,81 @@
+from modules import sd_hijack_clip
+from modules import shared
+
+
+def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ id_start = self.id_start
+ id_end = self.id_end
+ maxlen = self.wrapped.max_length # you get to stay at 77
+ used_custom_terms = []
+ remade_batch_tokens = []
+ hijack_comments = []
+ hijack_fixes = []
+ token_count = 0
+
+ cache = {}
+ batch_tokens = self.tokenize(texts)
+ batch_multipliers = []
+ for tokens in batch_tokens:
+ tuple_tokens = tuple(tokens)
+
+ if tuple_tokens in cache:
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
+ else:
+ fixes = []
+ remade_tokens = []
+ multipliers = []
+ mult = 1.0
+
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+
+ mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
+ if mult_change is not None:
+ mult *= mult_change
+ i += 1
+ elif embedding is None:
+ remade_tokens.append(token)
+ multipliers.append(mult)
+ i += 1
+ else:
+ emb_len = int(embedding.vec.shape[0])
+ fixes.append((len(remade_tokens), embedding))
+ remade_tokens += [0] * emb_len
+ multipliers += [mult] * emb_len
+ used_custom_terms.append((embedding.name, embedding.checksum()))
+ i += embedding_length_in_tokens
+
+ if len(remade_tokens) > maxlen - 2:
+ vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+ ovf = remade_tokens[maxlen - 2:]
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+ hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+ token_count = len(remade_tokens)
+ remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+ remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+ remade_batch_tokens.append(remade_tokens)
+ hijack_fixes.append(fixes)
+ batch_multipliers.append(multipliers)
+ return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):
+ batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)
+
+ self.hijack.comments += hijack_comments
+
+ if len(used_custom_terms) > 0:
+ self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+ self.hijack.fixes = hijack_fixes
+ return self.process_tokens(remade_batch_tokens, batch_multipliers)
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 938f9a58..31d2c898 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,3 +1,4 @@
+import os
import torch
from einops import repeat
@@ -11,191 +12,6 @@ from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
-# =================================================================================================
-# Monkey patch DDIMSampler methods from RunwayML repo directly.
-# Adapted from:
-# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
-# =================================================================================================
-@torch.no_grad()
-def sample_ddim(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
- samples, intermediates = self.ddim_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
-@torch.no_grad()
-def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None):
- b, *_, device = *x.shape, x.device
-
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- if isinstance(c, dict):
- assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
- for k in c:
- if isinstance(c[k], list):
- c_in[k] = [
- torch.cat([unconditional_conditioning[k][i], c[k][i]])
- for i in range(len(c[k]))
- ]
- else:
- c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
- else:
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
-
-# =================================================================================================
-# Monkey patch PLMSSampler methods.
-# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
-# Adapted from:
-# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
-# =================================================================================================
-@torch.no_grad()
-def sample_plms(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list):
- ctmp = ctmp[0]
- cbs = ctmp.shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for PLMS sampling is {size}')
-
- samples, intermediates = self.plms_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -209,7 +25,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
-
+
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
@@ -278,61 +94,18 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
-
-# =================================================================================================
-# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
-# Adapted from:
-# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
-# =================================================================================================
-@torch.no_grad()
-def get_unconditional_conditioning(self, batch_size, null_label=None):
- if null_label is not None:
- xc = null_label
- if isinstance(xc, ListConfig):
- xc = list(xc)
- if isinstance(xc, dict) or isinstance(xc, list):
- c = self.get_learned_conditioning(xc)
- else:
- if hasattr(xc, "to"):
- xc = xc.to(self.device)
- c = self.get_learned_conditioning(xc)
- else:
- # todo: get null label from cond_stage_model
- raise NotImplementedError()
- c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
- return c
+def should_hijack_inpainting(checkpoint_info):
+ from modules import sd_models
-class LatentInpaintDiffusion(LatentDiffusion):
- def __init__(
- self,
- concat_keys=("mask", "masked_image"),
- masked_image_key="masked_image",
- *args,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- self.masked_image_key = masked_image_key
- assert self.masked_image_key in concat_keys
- self.concat_keys = concat_keys
+ ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
+ cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
-
-def should_hijack_inpainting(checkpoint_info):
- return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
+ return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
def do_inpainting_hijack():
- # most of this stuff seems to no longer be needed because it is already included into SD2.0
- # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
- # this file should be cleaned up later if weverything tuens out to work fine
-
- # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
- ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
-
- # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
- # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
- # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 98123fbf..cdc63ed7 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,7 @@
import math
import sys
import traceback
-import importlib
+import psutil
import torch
from torch import einsum
@@ -12,6 +12,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
+from .sub_quadratic_attention import efficient_dot_product_attention
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
print(traceback.format_exc(), file=sys.stderr)
+def get_available_vram():
+ if shared.device.type == 'cuda':
+ stats = torch.cuda.memory_stats(shared.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+ return mem_free_total
+ else:
+ return psutil.virtual_memory().available
+
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
+ mem_free_total = get_available_vram()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
-def check_for_psutil():
- try:
- spec = importlib.util.find_spec('psutil')
- return spec is not None
- except ModuleNotFoundError:
- return False
-
-invokeAI_mps_available = check_for_psutil()
-
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
-if invokeAI_mps_available:
- import psutil
- mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
+mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
@@ -152,14 +151,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
return r
def einsum_op_mps_v1(q, k, v):
- if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ if slice_size % 4096 == 0:
+ slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[1] <= 4096:
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +189,7 @@ def einsum_op(q, k, v):
return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
- if mem_total_gb >= 32:
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)
@@ -213,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# -- End of code from https://github.com/invoke-ai/InvokeAI --
+
+# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
+# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
+def sub_quad_attention_forward(self, x, context=None, mask=None):
+ assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
+
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, context_k, context_v, x
+
+ q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+ v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
+
+ x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+
+ x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
+
+ out_proj, dropout = self.to_out
+ x = out_proj(x)
+ x = dropout(x)
+
+ return x
+
+def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
+ bytes_per_token = torch.finfo(q.dtype).bits//8
+ batch_x_heads, q_tokens, _ = q.shape
+ _, k_tokens, _ = k.shape
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
+
+ if chunk_threshold is None:
+ chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+ elif chunk_threshold == 0:
+ chunk_threshold_bytes = None
+ else:
+ chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
+
+ if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
+ kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
+ elif kv_chunk_size_min == 0:
+ kv_chunk_size_min = None
+
+ if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
+ # the big matmul fits into our memory limit; do everything in 1 chunk,
+ # i.e. send it down the unchunked fast-path
+ query_chunk_size = q_tokens
+ kv_chunk_size = k_tokens
+
+ return efficient_dot_product_attention(
+ q,
+ k,
+ v,
+ query_chunk_size=q_chunk_size,
+ kv_chunk_size=kv_chunk_size,
+ kv_chunk_size_min = kv_chunk_size_min,
+ use_checkpoint=use_checkpoint,
+ )
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@@ -250,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):
h_ = torch.zeros_like(k, device=q.device)
- stats = torch.cuda.memory_stats(q.device)
- mem_active = stats['active_bytes.all.current']
- mem_reserved = stats['reserved_bytes.all.current']
- mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
- mem_free_torch = mem_reserved - mem_active
- mem_free_total = mem_free_cuda + mem_free_torch
+ mem_free_total = get_available_vram()
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
@@ -310,3 +371,19 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
+
+def sub_quad_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+ b, c, h, w = q.shape
+ q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+ out = self.proj_out(out)
+ return x + out
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index 1b9d7757..18daf8c1 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -4,7 +4,7 @@ import torch
class TorchHijackForUnet:
"""
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
- this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64
+ this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
"""
def __getattr__(self, item):
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
new file mode 100644
index 00000000..4ac51c38
--- /dev/null
+++ b/modules/sd_hijack_xlmr.py
@@ -0,0 +1,34 @@
+import open_clip.tokenizer
+import torch
+
+from modules import sd_hijack_clip, devices
+from modules.shared import opts
+
+
+class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+
+ self.id_start = wrapped.config.bos_token_id
+ self.id_end = wrapped.config.eos_token_id
+ self.id_pad = wrapped.config.pad_token_id
+
+ self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have </w> bits for comma
+
+ def encode_with_transformers(self, tokens):
+ # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
+ # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
+ # layer to work with - you have to use the last
+
+ attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
+ features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
+ z = features['projection_state']
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.roberta.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+
+ return embedded
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 5b37f3fe..76a89e88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
@@ -48,6 +48,14 @@ def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
+def find_checkpoint_config(info):
+ config = os.path.splitext(info.filename)[0] + ".yaml"
+ if os.path.exists(config):
+ return config
+
+ return shared.cmd_opts.config
+
+
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
- basename, _ = os.path.splitext(filename)
- config = basename + ".yaml"
- if not os.path.exists(config):
- config = shared.cmd_opts.config
-
- checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
def get_closet_checkpoint_match(searchString):
@@ -111,18 +114,19 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
+
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
if len(checkpoints_list) == 0:
- print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+ print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
checkpoint_info = next(iter(checkpoints_list.values()))
@@ -167,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+ device = map_location or shared.weight_load_location
+ if device is None:
+ device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -227,6 +234,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ model.logvar = model.logvar.to(devices.device) # fix for training
+
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
@@ -275,12 +284,14 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ checkpoint_config = find_checkpoint_config(checkpoint_info)
- if checkpoint_info.config != shared.cmd_opts.config:
- print(f"Loading config from: {checkpoint_info.config}")
+ if checkpoint_config != shared.cmd_opts.config:
+ print(f"Loading config from: {checkpoint_config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -288,17 +299,17 @@ def load_model(checkpoint_info=None):
gc.collect()
devices.torch_gc()
- sd_config = OmegaConf.load(checkpoint_info.config)
+ sd_config = OmegaConf.load(checkpoint_config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
- sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
+ sd_config.model.params.finetune_keys = None
- # Create a "fake" config with a different name so that we know to unload it when switching models.
- checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+ if not hasattr(sd_config.model.params, "use_ema"):
+ sd_config.model.params.use_ema = False
do_inpainting_hijack()
@@ -306,6 +317,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -318,23 +330,29 @@ def load_model(checkpoint_info=None):
sd_model.eval()
shared.sd_model = sd_model
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
+
script_callbacks.model_loaded_callback(sd_model)
- print(f"Model loaded.")
+ print("Model loaded.")
+
return sd_model
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
-
+
if not sd_model:
sd_model = shared.sd_model
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+ checkpoint_config = find_checkpoint_config(current_checkpoint_info)
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
- if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
@@ -347,13 +365,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ print("Weights loaded.")
- print(f"Weights loaded.")
return sd_model
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 4c123d3b..01221b89 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -9,7 +9,7 @@ import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images
+from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -23,16 +23,16 @@ samplers_k_diffusion = [
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
@@ -97,8 +97,9 @@ sampler_extra_params = {
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
- steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
- t_enc = p.steps - 1
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
else:
steps = p.steps
t_enc = int(min(p.denoising_strength, 0.999) * steps)
@@ -106,20 +107,32 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-def single_sample_to_image(sample):
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
-def sample_to_image(samples, index=0):
- return single_sample_to_image(samples[index])
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
-def samples_to_image_grid(samples):
- return images.image_grid([single_sample_to_image(sample) for sample in samples])
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def store_latent(decoded):
@@ -288,6 +301,16 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@@ -329,12 +352,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -444,16 +462,33 @@ class KDiffusionSampler:
return extra_params_kwargs
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
+ def get_sigmas(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
+ if discard_next_to_last_sigma:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+ return sigmas
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = setup_img2img_steps(p, steps)
+
+ sigmas = self.get_sigmas(p, steps)
+
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
@@ -485,12 +520,7 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
+ sigmas = self.get_sigmas(p, steps)
x = x * sigmas[0]
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 25638a83..ac71d62d 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,5 +1,6 @@
import torch
import os
+import collections
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
@@ -30,6 +31,7 @@ base_vae = None
loaded_vae_file = None
checkpoint_info = None
+checkpoints_loaded = collections.OrderedDict()
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
@@ -149,13 +151,30 @@ def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
+ cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
+
if vae_file:
- assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
- print(f"Loading VAE weights from: {vae_file}")
- store_base_vae(model)
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
- _load_vae_dict(model, vae_dict_1)
+ if cache_enabled and vae_file in checkpoints_loaded:
+ # use vae checkpoint cache
+ print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+ store_base_vae(model)
+ _load_vae_dict(model, checkpoints_loaded[vae_file])
+ else:
+ assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
+ print(f"Loading VAE weights from: {vae_file}")
+ store_base_vae(model)
+ vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+ _load_vae_dict(model, vae_dict_1)
+
+ if cache_enabled:
+ # cache newly loaded vae
+ checkpoints_loaded[vae_file] = vae_dict_1.copy()
+
+ # clean up cache if limit is reached
+ if cache_enabled:
+ while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
+ checkpoints_loaded.popitem(last=False) # LRU
# If vae used is not in dict, update it
# It will be removed on refresh though
@@ -208,5 +227,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
- print(f"VAE Weights loaded.")
+ print("VAE Weights loaded.")
return sd_model
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
new file mode 100644
index 00000000..0a58542d
--- /dev/null
+++ b/modules/sd_vae_approx.py
@@ -0,0 +1,58 @@
+import os
+
+import torch
+from torch import nn
+from modules import devices, paths
+
+sd_vae_approx_model = None
+
+
+class VAEApprox(nn.Module):
+ def __init__(self):
+ super(VAEApprox, self).__init__()
+ self.conv1 = nn.Conv2d(4, 8, (7, 7))
+ self.conv2 = nn.Conv2d(8, 16, (5, 5))
+ self.conv3 = nn.Conv2d(16, 32, (3, 3))
+ self.conv4 = nn.Conv2d(32, 64, (3, 3))
+ self.conv5 = nn.Conv2d(64, 32, (3, 3))
+ self.conv6 = nn.Conv2d(32, 16, (3, 3))
+ self.conv7 = nn.Conv2d(16, 8, (3, 3))
+ self.conv8 = nn.Conv2d(8, 3, (3, 3))
+
+ def forward(self, x):
+ extra = 11
+ x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+ x = nn.functional.pad(x, (extra, extra, extra, extra))
+
+ for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
+ x = layer(x)
+ x = nn.functional.leaky_relu(x, 0.1)
+
+ return x
+
+
+def model():
+ global sd_vae_approx_model
+
+ if sd_vae_approx_model is None:
+ sd_vae_approx_model = VAEApprox()
+ sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+ sd_vae_approx_model.eval()
+ sd_vae_approx_model.to(devices.device, devices.dtype)
+
+ return sd_vae_approx_model
+
+
+def cheap_approximation(sample):
+ # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
+
+ coefs = torch.tensor([
+ [0.298, 0.207, 0.208],
+ [0.187, 0.286, 0.173],
+ [-0.158, 0.189, 0.264],
+ [-0.184, -0.271, -0.473],
+ ]).to(sample.device)
+
+ x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
+
+ return x_sample
diff --git a/modules/shared.py b/modules/shared.py
index 272267c1..a6712dae 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -5,6 +5,7 @@ import os
import sys
import time
+from PIL import Image
import gradio as gr
import tqdm
@@ -13,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path
@@ -22,7 +23,7 @@ demo = None
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
-parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
+parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
@@ -55,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
+parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
+parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
@@ -81,6 +86,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
@@ -108,6 +114,17 @@ restricted_opts = {
"outdir_save",
}
+ui_reorder_categories = [
+ "sampler",
+ "dimensions",
+ "cfg",
+ "seed",
+ "checkboxes",
+ "hires_fix",
+ "batch",
+ "scripts",
+]
+
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@@ -140,6 +157,7 @@ class State:
job = ""
job_no = 0
job_count = 0
+ processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
@@ -167,9 +185,10 @@ class State:
def dict(self):
obj = {
"skipped": self.skipped,
- "interrupted": self.skipped,
+ "interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
+ "job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
@@ -180,6 +199,7 @@ class State:
def begin(self):
self.sampling_step = 0
self.job_count = -1
+ self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
@@ -200,12 +220,13 @@ class State:
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0:
self.do_set_current_image()
def do_set_current_image(self):
- if not parallel_processing_allowed:
- return
if self.current_latent is None:
return
@@ -217,6 +238,7 @@ class State:
self.current_image_sampling_step = self.sampling_step
+
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -293,6 +315,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
@@ -325,7 +348,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
- "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@@ -343,7 +365,8 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
- "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
+ "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
+ "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -354,6 +377,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
@@ -362,15 +386,20 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
+ "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
- "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
+options_templates.update(options_section(('compatibility', "Compatibility"), {
+ "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
+ "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+}))
+
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
@@ -383,11 +412,13 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
+ "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
+ "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
@@ -400,7 +431,10 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
+ 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
}))
@@ -413,6 +447,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
+ 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
}))
options_templates.update(options_section((None, "Hidden options"), {
@@ -471,7 +506,12 @@ class Options:
return False
if self.data_labels[key].onchange is not None:
- self.data_labels[key].onchange()
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
return True
@@ -534,6 +574,16 @@ opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+latent_upscale_default_mode = "Latent"
+latent_upscale_modes = {
+ "Latent": {"mode": "bilinear", "antialias": False},
+ "Latent (antialiased)": {"mode": "bilinear", "antialias": True},
+ "Latent (bicubic)": {"mode": "bicubic", "antialias": False},
+ "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True},
+ "Latent (nearest)": {"mode": "nearest", "antialias": False},
+ "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False},
+}
+
sd_upscalers = []
sd_model = None
@@ -567,7 +617,7 @@ class TotalTQDM:
return
if self._tqdm is None:
self.reset()
- self._tqdm.total=new_total
+ self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
new file mode 100644
index 00000000..fea7aaac
--- /dev/null
+++ b/modules/sub_quadratic_attention.py
@@ -0,0 +1,205 @@
+# original source:
+# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
+# license:
+# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
+# credit:
+# Amin Rezaei (original author)
+# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
+# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
+# implementation of:
+# Self-attention Does Not Need O(n2) Memory":
+# https://arxiv.org/abs/2112.05682v2
+
+from functools import partial
+import torch
+from torch import Tensor
+from torch.utils.checkpoint import checkpoint
+import math
+from typing import Optional, NamedTuple, Protocol, List
+
+def narrow_trunc(
+ input: Tensor,
+ dim: int,
+ start: int,
+ length: int
+) -> Tensor:
+ return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
+class AttnChunk(NamedTuple):
+ exp_values: Tensor
+ exp_weights_sum: Tensor
+ max_score: Tensor
+
+class SummarizeChunk(Protocol):
+ @staticmethod
+ def __call__(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ ) -> AttnChunk: ...
+
+class ComputeQueryChunkAttn(Protocol):
+ @staticmethod
+ def __call__(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ ) -> Tensor: ...
+
+def _summarize_chunk(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ scale: float,
+) -> AttnChunk:
+ attn_weights = torch.baddbmm(
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+ query,
+ key.transpose(1,2),
+ alpha=scale,
+ beta=0,
+ )
+ max_score, _ = torch.max(attn_weights, -1, keepdim=True)
+ max_score = max_score.detach()
+ exp_weights = torch.exp(attn_weights - max_score)
+ exp_values = torch.bmm(exp_weights, value)
+ max_score = max_score.squeeze(-1)
+ return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
+def _query_chunk_attention(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ summarize_chunk: SummarizeChunk,
+ kv_chunk_size: int,
+) -> Tensor:
+ batch_x_heads, k_tokens, k_channels_per_head = key.shape
+ _, _, v_channels_per_head = value.shape
+
+ def chunk_scanner(chunk_idx: int) -> AttnChunk:
+ key_chunk = narrow_trunc(
+ key,
+ 1,
+ chunk_idx,
+ kv_chunk_size
+ )
+ value_chunk = narrow_trunc(
+ value,
+ 1,
+ chunk_idx,
+ kv_chunk_size
+ )
+ return summarize_chunk(query, key_chunk, value_chunk)
+
+ chunks: List[AttnChunk] = [
+ chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
+ ]
+ acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
+ chunk_values, chunk_weights, chunk_max = acc_chunk
+
+ global_max, _ = torch.max(chunk_max, 0, keepdim=True)
+ max_diffs = torch.exp(chunk_max - global_max)
+ chunk_values *= torch.unsqueeze(max_diffs, -1)
+ chunk_weights *= max_diffs
+
+ all_values = chunk_values.sum(dim=0)
+ all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
+ return all_values / all_weights
+
+# TODO: refactor CrossAttention#get_attention_scores to share code with this
+def _get_attention_scores_no_kv_chunking(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ scale: float,
+) -> Tensor:
+ attn_scores = torch.baddbmm(
+ torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+ query,
+ key.transpose(1,2),
+ alpha=scale,
+ beta=0,
+ )
+ attn_probs = attn_scores.softmax(dim=-1)
+ del attn_scores
+ hidden_states_slice = torch.bmm(attn_probs, value)
+ return hidden_states_slice
+
+class ScannedChunk(NamedTuple):
+ chunk_idx: int
+ attn_chunk: AttnChunk
+
+def efficient_dot_product_attention(
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ query_chunk_size=1024,
+ kv_chunk_size: Optional[int] = None,
+ kv_chunk_size_min: Optional[int] = None,
+ use_checkpoint=True,
+):
+ """Computes efficient dot-product attention given query, key, and value.
+ This is efficient version of attention presented in
+ https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
+ Args:
+ query: queries for calculating attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ key: keys for calculating attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ value: values to be used in attention with shape of
+ `[batch * num_heads, tokens, channels_per_head]`.
+ query_chunk_size: int: query chunks size
+ kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
+ kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
+ Returns:
+ Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
+ """
+ batch_x_heads, q_tokens, q_channels_per_head = query.shape
+ _, k_tokens, _ = key.shape
+ scale = q_channels_per_head ** -0.5
+
+ kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
+ if kv_chunk_size_min is not None:
+ kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
+
+ def get_query_chunk(chunk_idx: int) -> Tensor:
+ return narrow_trunc(
+ query,
+ 1,
+ chunk_idx,
+ min(query_chunk_size, q_tokens)
+ )
+
+ summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
+ summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
+ compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
+ _get_attention_scores_no_kv_chunking,
+ scale=scale
+ ) if k_tokens <= kv_chunk_size else (
+ # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
+ partial(
+ _query_chunk_attention,
+ kv_chunk_size=kv_chunk_size,
+ summarize_chunk=summarize_chunk,
+ )
+ )
+
+ if q_tokens <= query_chunk_size:
+ # fast-path for when there's just 1 query chunk
+ return compute_query_chunk_attn(
+ query=query,
+ key=key,
+ value=value,
+ )
+
+ # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
+ # and pass slices to be mutated, instead of torch.cat()ing the returned slices
+ res = torch.cat([
+ compute_query_chunk_attn(
+ query=get_query_chunk(i * query_chunk_size),
+ key=key,
+ value=value,
+ ) for i in range(math.ceil(q_tokens / query_chunk_size))
+ ], dim=1)
+ return res
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 2dc64c3c..88d68c76 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -28,9 +28,9 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
-
+
self.placeholder_token = placeholder_token
self.width = width
@@ -50,14 +50,14 @@ class PersonalizedBase(Dataset):
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
+
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
- raise Exception("inturrupted")
+ raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader):
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
-
+
class BatchLoader:
def __init__(self, data):
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index dd0c0ad1..f63fc72f 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -58,14 +58,19 @@ class LearnRateScheduler:
self.finished = False
- def apply(self, optimizer, step_number):
+ def step(self, step_number):
if step_number < self.end_step:
- return
+ return False
try:
(self.learn_rate, self.end_step) = next(self.schedules)
- except Exception:
+ except StopIteration:
self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
return
if self.verbose:
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
new file mode 100644
index 00000000..8b1981d5
--- /dev/null
+++ b/modules/textual_inversion/logging.py
@@ -0,0 +1,24 @@
+import datetime
+import json
+import os
+
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
+saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
+saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+
+
+def save_settings_to_file(log_directory, all_params):
+ now = datetime.datetime.now()
+ params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
+
+ keys = saved_params_all
+ if all_params.get('preview_from_txt2img'):
+ keys = keys | saved_params_previews
+
+ params.update({k: v for k, v in all_params.items() if k in keys})
+
+ filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
+ with open(os.path.join(log_directory, filename), "w") as file:
+ json.dump(params, file, indent=4)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e28c357a..45882ed6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,7 @@
import os
import sys
import traceback
+import inspect
import torch
import tqdm
@@ -17,15 +18,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
+from modules.textual_inversion.logging import save_settings_to_file
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
+ self.shape = None
+ self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -39,6 +45,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -57,14 +70,15 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
+ self.expected_shape = -1
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
- # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
@@ -75,21 +89,26 @@ class EmbeddingDatabase:
return embedding
- def load_textual_inversion_embeddings(self):
+ def get_expected_shape(self):
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
+
+ def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
+ if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- data = []
-
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -97,8 +116,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -122,23 +143,31 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- self.register_embedding(embedding, shared.sd_model)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
- for fn in os.listdir(self.embeddings_dir):
- try:
- fullfn = os.path.join(self.embeddings_dir, fn)
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
- if os.stat(fullfn).st_size == 0:
- continue
+ for root, dirs, fns in os.walk(self.embeddings_dir):
+ for fn in fns:
+ try:
+ fullfn = os.path.join(root, fn)
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
+ if os.stat(fullfn).st_size == 0:
+ continue
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
- print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -194,7 +223,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = (step - 1) // epoch_len
- epoch_step = (step - 1) % epoch_len
+ epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
"step": step,
@@ -203,6 +232,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
@@ -225,11 +255,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -263,18 +294,26 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
initial_step = embedding.step or 0
if initial_step >= steps:
- shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
- # dataset loading may take a while, so input validations and early returns should be done before this
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
-
+
pin_memory = shared.opts.pin_memory
-
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
@@ -285,6 +324,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -295,12 +347,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
-
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
-
+
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -318,22 +372,34 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+
with devices.autocast():
- # c = stack_conds(batch.cond).to(devices.device)
- # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
- # print(mask)
- # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
- loss = shared.sd_model(x, c)[0] / gradient_step
+
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
+
+ loss = shared.sd_model(x, cond)[0] / gradient_step
del x
-
+
_loss_step += loss.item()
scaler.scale(loss).backward()
-
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
+
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
scaler.step(optimizer)
scaler.update()
embedding.step += 1
@@ -352,9 +418,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- #if shared.opts.save_optimizer_state:
- #embedding.optimizer_state_dict = optimizer.state_dict()
- save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -444,7 +508,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
@@ -456,7 +520,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -467,6 +531,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint
diff --git a/modules/txt2img.py b/modules/txt2img.py
index c8f81176..38b5f591 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -8,7 +8,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -33,8 +33,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
tiling=tiling,
enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None,
- firstphase_width=firstphase_width if enable_hr else None,
- firstphase_height=firstphase_height if enable_hr else None,
+ hr_scale=hr_scale,
+ hr_upscaler=hr_upscaler,
+ hr_second_pass_steps=hr_second_pass_steps,
+ hr_resize_x=hr_resize_x,
+ hr_resize_y=hr_resize_y,
)
p.scripts = modules.scripts.scripts_txt2img
@@ -59,4 +62,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if opts.do_not_show_images:
processed.images = []
- return processed.images, generation_info_js, plaintext_to_html(processed.info)
+ return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments)
diff --git a/modules/ui.py b/modules/ui.py
index 28481e33..6c765262 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,6 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts
@@ -49,10 +50,14 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
-if cmd_opts.ngrok != None:
+if cmd_opts.ngrok is not None:
import modules.ngrok as ngrok
print('ngrok authtoken detected, trying to connect...')
- ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
+ ngrok.connect(
+ cmd_opts.ngrok,
+ cmd_opts.port if cmd_opts.port is not None else 7860,
+ cmd_opts.ngrok_region
+ )
def gr_show(visible=True):
@@ -76,7 +81,6 @@ css_hide_progressbar = """
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
-art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
@@ -155,19 +159,17 @@ def save_files(js_data, images, do_make_zip, index):
zip_file.writestr(filenames[i], f.read())
fullfns.insert(0, zip_filepath)
- return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
-
+ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-def calc_time_left(progress, threshold, label, force_display):
+def calc_time_left(progress, threshold, label, force_display, show_eta):
if progress == 0:
return ""
else:
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
- if (eta_relative > threshold and progress > 0.02) or force_display:
+ if (eta_relative > threshold and show_eta) or force_display:
if eta_relative > 3600:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
elif eta_relative > 60:
@@ -189,7 +191,10 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
- time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display )
+ # Show progress percentage and time left at the same moment, and base it also on steps done
+ show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
+
+ time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
if time_left != "":
shared.state.time_left_force_display = True
@@ -197,7 +202,7 @@ def check_progress_call(id_part):
progressbar = ""
if opts.show_progressbar:
- progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}</div></div>"""
+ progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
image = gr_show(False)
preview_visibility = gr_show(False)
@@ -230,13 +235,6 @@ def check_progress_call_initial(id_part):
return check_progress_call(id_part)
-def roll_artist(prompt):
- allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
- artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
-
- return prompt + ", " + artist.name if prompt != '' else artist.name
-
-
def visit(x, func, path=""):
if hasattr(x, 'children'):
for c in x.children:
@@ -258,6 +256,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)]
+def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
+ from modules import processing, devices
+
+ if not enable:
+ return ""
+
+ p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
+
+ with devices.autocast():
+ p.init([""], [0], [0])
+
+ return f"resize to: <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+
+
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
@@ -266,7 +278,7 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image):
- prompt = shared.interrogator.interrogate(image)
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt
@@ -276,35 +288,31 @@ def interrogate_deepbooru(image):
return gr_show(True) if prompt is None else prompt
-def create_seed_inputs():
- with gr.Row():
- with gr.Box():
- with gr.Row(elem_id='seed_row'):
- seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1)
- seed.style(container=False)
- random_seed = gr.Button(random_symbol, elem_id='random_seed')
- reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed')
+def create_seed_inputs(target_interface):
+ with FormRow(elem_id=target_interface + '_seed_row'):
+ seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+ seed.style(container=False)
+ random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
+ reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
- with gr.Box(elem_id='subseed_show_box'):
- seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False)
+ with gr.Group(elem_id=target_interface + '_subseed_show_box'):
+ seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
- with gr.Row(visible=False) as seed_extra_row_1:
+ with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
- with gr.Box():
- with gr.Row(elem_id='subseed_row'):
- subseed = gr.Number(label='Variation seed', value=-1)
- subseed.style(container=False)
- random_subseed = gr.Button(random_symbol, elem_id='random_subseed')
- reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed')
- subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01)
-
- with gr.Row(visible=False) as seed_extra_row_2:
+ subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+ subseed.style(container=False)
+ random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
+ reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
+ subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+
+ with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
- seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0)
- seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0)
+ seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+ seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -374,7 +382,7 @@ def update_token_counter(text, steps):
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
- tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
+ token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
@@ -399,7 +407,6 @@ def create_toprow(is_img2img):
)
with gr.Column(scale=1, elem_id="roll_col"):
- roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
@@ -442,13 +449,11 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=1, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
- prompt_style.save_to_config = True
with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
- prompt_style2.save_to_config = True
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
@@ -496,7 +501,7 @@ def apply_setting(key, value):
return
valtype = type(opts.data_labels[key].default)
- oldval = opts.data[key]
+ oldval = opts.data.get(key, None)
opts.data[key] = valtype(value) if valtype != type(None) else value
if oldval != value and opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange()
@@ -528,7 +533,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
return gr.update(**(args or {}))
- refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
@@ -557,6 +562,8 @@ Requested path was: {f}
os.startfile(path)
elif platform.system() == "Darwin":
sp.Popen(["open", path])
+ elif "microsoft-standard-WSL2" in platform.uname().release:
+ sp.Popen(["wsl-open", path])
else:
sp.Popen(["xdg-open", path])
@@ -566,13 +573,14 @@ Requested path was: {f}
generation_info = None
with gr.Column():
- with gr.Row():
+ with gr.Row(elem_id=f"image_buttons_{tabname}"):
+ open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}')
+ save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
- button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
- open_folder_button = gr.Button(folder_symbol, elem_id=button_id)
open_folder_button.click(
fn=lambda: open_folder(opts.outdir_samples or outdir),
@@ -582,14 +590,13 @@ Requested path was: {f}
if tabname != "extras":
with gr.Row():
- do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False)
-
- with gr.Row():
- download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False)
+ download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
- html_info = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+ generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
@@ -602,25 +609,61 @@ Requested path was: {f}
save.click(
fn=wrap_gradio_call(save_files),
- _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]",
+ _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[
generation_info,
result_gallery,
- do_make_zip,
+ html_info,
html_info,
],
outputs=[
download_files,
+ html_log,
+ ]
+ )
+
+ save_zip.click(
+ fn=wrap_gradio_call(save_files),
+ _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+ inputs=[
+ generation_info,
+ result_gallery,
html_info,
html_info,
- html_info,
+ ],
+ outputs=[
+ download_files,
+ html_log,
]
)
+
else:
- html_info_x = gr.HTML()
- html_info = gr.HTML()
+ html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+ html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+ html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
- return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
+ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_sampler_and_steps_selection(choices, tabname):
+ if opts.samplers_in_dropdown:
+ with FormRow(elem_id=f"sampler_selection_{tabname}"):
+ sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ else:
+ with FormGroup(elem_id=f"sampler_selection_{tabname}"):
+ steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
+ sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
+
+ return steps, sampler_index
+
+
+def ordered_ui_categories():
+ user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+
+ for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)):
+ yield category
def create_ui():
@@ -635,14 +678,11 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
-
-
-
with gr.Row(elem_id='txt2img_progress_row'):
with gr.Column(scale=1):
pass
@@ -653,43 +693,76 @@ def create_ui():
setup_progressbar(progressbar, txt2img_preview, 'txt2img')
with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
- steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
-
- with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
-
- with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
- tiling = gr.Checkbox(label='Tiling', value=False)
- enable_hr = gr.Checkbox(label='Highres. fix', value=False)
-
- with gr.Row(visible=False) as hr_options:
- firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0)
- firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0)
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
-
- with gr.Row(equal_height=True):
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
-
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
-
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
+ with gr.Column(variant='panel', elem_id="txt2img_settings"):
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="txt2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "cfg":
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="txt2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
+ enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
+
+ elif category == "hires_fix":
+ with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options:
+ with FormRow(elem_id="txt2img_hires_fix_row1"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+
+ with FormRow(elem_id="txt2img_hires_fix_row2"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="txt2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="txt2img_script_container"):
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+
+ hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
+ hr_resolution_preview_args = dict(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False
+ )
- with gr.Group():
- custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ for input in hr_resolution_preview_inputs:
+ input.change(**hr_resolution_preview_args)
- txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
+ txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
+ fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
txt2img_prompt,
@@ -709,14 +782,18 @@ def create_ui():
width,
enable_hr,
denoising_strength,
- firstphase_width,
- firstphase_height,
+ hr_scale,
+ hr_upscaler,
+ hr_second_pass_steps,
+ hr_resize_x,
+ hr_resize_y,
] + custom_inputs,
outputs=[
txt2img_gallery,
generation_info,
- html_info
+ html_info,
+ html_log,
],
show_progress=False,
)
@@ -739,17 +816,7 @@ def create_ui():
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
- )
-
- roll.click(
- fn=roll_artist,
- _js="update_txt2img_tokens",
- inputs=[
- txt2img_prompt,
- ],
- outputs=[
- txt2img_prompt,
- ]
+ show_progress = False,
)
txt2img_paste_fields = [
@@ -770,8 +837,11 @@ def create_ui():
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
- (firstphase_width, "First pass size-1"),
- (firstphase_height, "First pass size-2"),
+ (hr_scale, "Hires upscale"),
+ (hr_upscaler, "Hires upscaler"),
+ (hr_second_pass_steps, "Hires steps"),
+ (hr_resize_x, "Hires resize-1"),
+ (hr_resize_y, "Hires resize-2"),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
@@ -793,8 +863,7 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
-
+ img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
@@ -807,14 +876,14 @@ def create_ui():
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
setup_progressbar(progressbar, img2img_preview, 'img2img')
- with gr.Row().style(equal_height=False):
- with gr.Column(variant='panel'):
+ with FormRow().style(equal_height=False):
+ with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
- with gr.TabItem('img2img', id='img2img'):
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
+ with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"):
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
- with gr.TabItem('Inpaint', id='inpaint'):
+ with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
init_img_with_mask_orig = gr.State(None)
@@ -832,54 +901,72 @@ def create_ui():
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
- with gr.Row():
- mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
- mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch)
+ with FormRow():
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
+ mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha")
+
+ with FormRow():
+ mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
+ inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
- with gr.Row():
- mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
- inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
+ with FormRow():
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill")
- inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
+ with FormRow():
+ with gr.Column():
+ inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res")
- with gr.Row():
- inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
- inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32)
+ with gr.Column(scale=4):
+ inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
- with gr.TabItem('Batch img2img', id='batch'):
+ with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
- img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
- img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
-
- with gr.Row():
- resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
-
- steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
- sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
-
- with gr.Group():
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
-
- with gr.Row():
- restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
- tiling = gr.Checkbox(label='Tiling', value=False)
-
- with gr.Row():
- batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
- batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
-
- with gr.Group():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
-
- seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
-
- with gr.Group():
- custom_inputs = modules.scripts.scripts_img2img.setup_ui()
-
- img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
+ img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
+ img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
+
+ with FormRow():
+ resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+
+ for category in ordered_ui_categories():
+ if category == "sampler":
+ steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
+
+ elif category == "dimensions":
+ with FormRow():
+ with gr.Column(elem_id="img2img_column_size", scale=4):
+ width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+
+ if opts.dimensions_and_batch_together:
+ with gr.Column(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "cfg":
+ with FormGroup():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
+
+ elif category == "seed":
+ seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
+
+ elif category == "checkboxes":
+ with FormRow(elem_id="img2img_checkboxes"):
+ restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
+ tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
+
+ elif category == "batch":
+ if not opts.dimensions_and_batch_together:
+ with FormRow(elem_id="img2img_column_batch"):
+ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
+ batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+
+ elif category == "scripts":
+ with FormGroup(elem_id="img2img_script_container"):
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+
+ img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
@@ -911,7 +998,7 @@ def create_ui():
)
img2img_args = dict(
- fn=wrap_gradio_gpu_call(modules.img2img.img2img),
+ fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img",
inputs=[
dummy_component,
@@ -950,7 +1037,8 @@ def create_ui():
outputs=[
img2img_gallery,
generation_info,
- html_info
+ html_info,
+ html_log,
],
show_progress=False,
)
@@ -970,18 +1058,6 @@ def create_ui():
outputs=[img2img_prompt],
)
-
- roll.click(
- fn=roll_artist,
- _js="update_img2img_tokens",
- inputs=[
- img2img_prompt,
- ],
- outputs=[
- img2img_prompt,
- ]
- )
-
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
@@ -1034,50 +1110,50 @@ def create_ui():
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Tabs(elem_id="mode_extras"):
- with gr.TabItem('Single Image'):
- extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+ with gr.TabItem('Single Image', elem_id="extras_single_tab"):
+ extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
- with gr.TabItem('Batch Process'):
- image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
+ with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
+ image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
- with gr.TabItem('Batch from Directory'):
- extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.")
- extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
- show_extras_results = gr.Checkbox(label='Show result images', value=True)
+ with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
+ extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+ extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+ show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"):
- with gr.TabItem('Scale by'):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
- with gr.TabItem('Scale to'):
+ with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+ with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
with gr.Group():
with gr.Row():
- upscaling_resize_w = gr.Number(label="Width", value=512, precision=0)
- upscaling_resize_h = gr.Number(label="Height", value=512, precision=0)
- upscaling_crop = gr.Checkbox(label='Crop to fit', value=True)
+ upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+ upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+ upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.Group():
extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
with gr.Group():
extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
- extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
+ extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
with gr.Group():
- gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan)
+ gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
with gr.Group():
- codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
- codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
+ codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
+ codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
with gr.Group():
- upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
+ upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
- result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
+ result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
submit.click(
- fn=wrap_gradio_gpu_call(modules.extras.run_extras),
+ fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
_js="get_extras_tab_index",
inputs=[
dummy_component,
@@ -1119,7 +1195,7 @@ def create_ui():
with gr.Column(variant='panel'):
html = gr.HTML()
- generation_info = gr.Textbox(visible=False)
+ generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info")
html2 = gr.HTML()
with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
@@ -1138,23 +1214,27 @@ def create_ui():
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+ create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+ create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
- custom_name = gr.Textbox(label="Custom Name (Optional)")
- interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
- interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
+ create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+ custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+ interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+ interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
with gr.Row():
- checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
- save_as_half = gr.Checkbox(value=False, label="Save as float16")
+ checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+ save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
-
with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False):
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
@@ -1163,58 +1243,58 @@ def create_ui():
with gr.Tabs(elem_id="train_tabs"):
with gr.Tab(label="Create embedding"):
- new_embedding_name = gr.Textbox(label="Name")
- initialization_text = gr.Textbox(label="Initialization text", value="*")
- nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
- overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
+ new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
+ initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
+ nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
+ overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_embedding = gr.Button(value="Create embedding", variant='primary')
+ create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
with gr.Tab(label="Create hypernetwork"):
- new_hypernetwork_name = gr.Textbox(label="Name")
- new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
- new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
- new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
- new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
- new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
- new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
- overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
+ new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
+ new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
+ new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
+ new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+ new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
+ new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
+ new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
+ overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork")
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
- create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary')
+ create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
with gr.Tab(label="Preprocess images"):
- process_src = gr.Textbox(label='Source directory')
- process_dst = gr.Textbox(label='Destination directory')
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
- preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
+ process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
+ process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+ preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
- process_flip = gr.Checkbox(label='Create flipped copies')
- process_split = gr.Checkbox(label='Split oversized images')
- process_focal_crop = gr.Checkbox(label='Auto focal point crop')
- process_caption = gr.Checkbox(label='Use BLIP for caption')
- process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True)
+ process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
+ process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
+ process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
+ process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
+ process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
with gr.Row(visible=False) as process_split_extra_row:
- process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
- process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
+ process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
+ process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
with gr.Row(visible=False) as process_focal_crop_row:
- process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
- process_focal_crop_debug = gr.Checkbox(label='Create debug image')
+ process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
+ process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
+ process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
+ process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
with gr.Row():
with gr.Column(scale=3):
@@ -1222,8 +1302,8 @@ def create_ui():
with gr.Column():
with gr.Row():
- interrupt_preprocessing = gr.Button("Interrupt")
- run_preprocess = gr.Button(value="Preprocess", variant='primary')
+ interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
+ run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
process_split.change(
fn=lambda show: gr_show(show),
@@ -1239,38 +1319,48 @@ def create_ui():
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
- with gr.Row():
+ with FormRow():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
- with gr.Row():
+
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
- with gr.Row():
- embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
- hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
-
- batch_size = gr.Number(label='Batch size', value=1, precision=0)
- gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
- dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
- log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
- training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512)
- training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512)
- steps = gr.Number(label='Max steps', value=100000, precision=0)
- create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
- save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
- save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
- preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
- with gr.Row():
- shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
- tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
- with gr.Row():
- latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
+
+ with FormRow():
+ embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
+ hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
+
+ with FormRow():
+ clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
+ clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
+
+ with FormRow():
+ batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
+ gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
+
+ dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
+ log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
+ template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+ training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
+
+ with FormRow():
+ create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
+ save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+
+ save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
+ preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
+
+ shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags")
+ tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out")
+
+ latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method")
with gr.Row():
- interrupt_training = gr.Button(value="Interrupt")
- train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
- train_embedding = gr.Button(value="Train Embedding", variant='primary')
+ train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding")
+ interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training")
+ train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork")
params = script_callbacks.UiTrainTabParams(txt2img_preview_params)
@@ -1360,6 +1450,8 @@ def create_ui():
training_width,
training_height,
steps,
+ clip_grad_mode,
+ clip_grad_value,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
@@ -1389,6 +1481,8 @@ def create_ui():
training_width,
training_height,
steps,
+ clip_grad_mode,
+ clip_grad_value,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
@@ -1443,7 +1537,7 @@ def create_ui():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
- with gr.Row(variant="compact"):
+ with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else:
@@ -1488,41 +1582,36 @@ def create_ui():
return gr.update(value=value), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
- settings_submit = gr.Button(value="Apply settings", variant='primary')
- result = gr.HTML()
+ with gr.Row():
+ with gr.Column(scale=6):
+ settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
- settings_cols = 3
- items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
+ result = gr.HTML(elem_id="settings_result")
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
- quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
+ quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
quicksettings_list = []
- cols_displayed = 0
- items_displayed = 0
previous_section = None
- column = None
- with gr.Row(elem_id="settings").style(equal_height=False):
+ current_tab = None
+ with gr.Tabs(elem_id="settings"):
for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
if previous_section != item.section and not section_must_be_skipped:
- if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
- if column is not None:
- column.__exit__()
+ elem_id, text = item.section
- column = gr.Column(variant='panel')
- column.__enter__()
+ if current_tab is not None:
+ current_tab.__exit__()
- items_displayed = 0
- cols_displayed += 1
+ current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab.__enter__()
previous_section = item.section
- elem_id, text = item.section
- gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='<h1 class="gr-button-lg">{}</h1>'.format(text))
-
if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item))
components.append(dummy_component)
@@ -1532,15 +1621,21 @@ def create_ui():
component = create_setting_component(k)
component_dict[k] = component
components.append(component)
- items_displayed += 1
- with gr.Row():
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ if current_tab is not None:
+ current_tab.__exit__()
- with gr.Row():
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
- restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
+ with gr.TabItem("Actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+
+ if os.path.exists("html/licenses.html"):
+ with open("html/licenses.html", encoding="utf8") as file:
+ with gr.TabItem("Licenses"):
+ gr.HTML(file.read(), elem_id="licenses")
+
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
request_notifications.click(
fn=lambda: None,
@@ -1577,9 +1672,6 @@ def create_ui():
outputs=[],
)
- if column is not None:
- column.__exit__()
-
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
(img2img_interface, "img2img", "img2img"),
@@ -1613,7 +1705,7 @@ def create_ui():
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"):
- for i, k, item in quicksettings_list:
+ for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
@@ -1628,6 +1720,12 @@ def create_ui():
if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ if os.path.exists("html/footer.html"):
+ with open("html/footer.html", encoding="utf8") as file:
+ footer = file.read()
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
+
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
@@ -1662,7 +1760,7 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list
- return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+ return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
return results
modelmerger_merge.click(
@@ -1720,7 +1818,7 @@ def create_ui():
if init_field is not None:
init_field(saved_value)
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible:
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@@ -1741,16 +1839,14 @@ def create_ui():
if type(x) == gr.Number:
apply_field(x, 'value')
- # Since there are many dropdowns that shouldn't be saved,
- # we only mark dropdowns that should be saved.
- if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False):
+ if type(x) == gr.Dropdown:
apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
- apply_field(x, 'visible')
visit(txt2img_interface, loadsave, "txt2img")
visit(img2img_interface, loadsave, "img2img")
visit(extras_interface, loadsave, "extras")
visit(modelmerger_interface, loadsave, "modelmerger")
+ visit(train_interface, loadsave, "train")
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
with open(ui_config_file, "w", encoding="utf8") as file:
@@ -1786,3 +1882,30 @@ def reload_javascript():
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
+
+
+def versions_html():
+ import torch
+ import launch
+
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+ commit = launch.commit_hash()
+ short_commit = commit[0:8]
+
+ if shared.xformers_available:
+ import xformers
+ xformers_version = xformers.__version__
+ else:
+ xformers_version = "N/A"
+
+ return f"""
+python: <span title="{sys.version}">{python_version}</span>
+ • 
+torch: {torch.__version__}
+ • 
+xformers: {xformers_version}
+ • 
+gradio: {gr.__version__}
+ • 
+commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
+"""
diff --git a/modules/ui_components.py b/modules/ui_components.py
new file mode 100644
index 00000000..cac001dc
--- /dev/null
+++ b/modules/ui_components.py
@@ -0,0 +1,33 @@
+import gradio as gr
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+class FormRow(gr.Row, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "row"
+
+
+class FormGroup(gr.Group, gr.components.FormComponent):
+ """Same as gr.Row but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "group"
+
+
+class FormHTML(gr.HTML, gr.components.FormComponent):
+ """Same as gr.HTML but fits inside gradio forms"""
+
+ def get_block_name(self):
+ return "html"
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 1434f25f..742e745e 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -9,6 +9,8 @@ import git
import gradio as gr
import html
+import shutil
+import errno
from modules import extensions, shared, paths
@@ -138,7 +140,18 @@ def install_extension_from_url(dirname, url):
repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
- os.rename(tmpdir, target_dir)
+ try:
+ os.rename(tmpdir, target_dir)
+ except OSError as err:
+ # TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
+ # Shouldn't cause any new issues at least but we probably want to handle it there too.
+ if err.errno == errno.EXDEV:
+ # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
+ # Since we can't use a rename, do the slower but more versitile shutil.move()
+ shutil.move(tmpdir, target_dir)
+ else:
+ # Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
+ raise(err)
import launch
launch.run_extension_installer(target_dir)
@@ -149,15 +162,15 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
-def install_extension_from_index(url, hide_tags):
+def install_extension_from_index(url, hide_tags, sort_column):
ext_table, message = install_extension_from_url(None, url)
- code, _ = refresh_available_extensions_from_data(hide_tags)
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
return code, ext_table, message
-def refresh_available_extensions(url, hide_tags):
+def refresh_available_extensions(url, hide_tags, sort_column):
global available_extensions
import urllib.request
@@ -166,18 +179,28 @@ def refresh_available_extensions(url, hide_tags):
available_extensions = json.loads(text)
- code, tags = refresh_available_extensions_from_data(hide_tags)
+ code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
return url, code, gr.CheckboxGroup.update(choices=tags), ''
-def refresh_available_extensions_for_tags(hide_tags):
- code, _ = refresh_available_extensions_from_data(hide_tags)
+def refresh_available_extensions_for_tags(hide_tags, sort_column):
+ code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
return code, ''
-def refresh_available_extensions_from_data(hide_tags):
+sort_ordering = [
+ # (reverse, order_by_function)
+ (True, lambda x: x.get('added', 'z')),
+ (False, lambda x: x.get('added', 'z')),
+ (False, lambda x: x.get('name', 'z')),
+ (True, lambda x: x.get('name', 'z')),
+ (False, lambda x: 'z'),
+]
+
+
+def refresh_available_extensions_from_data(hide_tags, sort_column):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -197,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags):
<tbody>
"""
- for ext in extlist:
+ sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]
+
+ for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
+ added = ext.get('added', 'unknown')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@@ -220,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags):
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
- <td>{html.escape(description)}</td>
+ <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
<td>{install_code}</td>
</tr>
@@ -278,25 +304,32 @@ def create_ui():
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
- inputs=[available_extensions_index, hide_tags],
+ inputs=[available_extensions_index, hide_tags, sort_column],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
- inputs=[extension_to_install, hide_tags],
+ inputs=[extension_to_install, hide_tags, sort_column],
outputs=[available_extensions_table, extensions_table, install_result],
)
hide_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
- inputs=[hide_tags],
+ inputs=[hide_tags, sort_column],
+ outputs=[available_extensions_table, install_result]
+ )
+
+ sort_column.change(
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
+ inputs=[hide_tags, sort_column],
outputs=[available_extensions_table, install_result]
)
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 07210d14..21945235 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -1,6 +1,7 @@
import os
import tempfile
from collections import namedtuple
+from pathlib import Path
import gradio as gr
@@ -12,10 +13,29 @@ from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
+def register_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
+ gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
+
+ if hasattr(gradio, 'temp_dirs'): # gradio 3.9
+ gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
+
+
+def check_tmp_file(gradio, filename):
+ if hasattr(gradio, 'temp_file_sets'):
+ return any([filename in fileset for fileset in gradio.temp_file_sets])
+
+ if hasattr(gradio, 'temp_dirs'):
+ return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
+
+ return False
+
+
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
- shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
+ register_tmp_file(shared.demo, already_saved_as)
+
file_obj = Savedfile(already_saved_as)
return file_obj
@@ -44,7 +64,7 @@ def on_tmpdir_changed():
os.makedirs(shared.opts.temp_dir, exist_ok=True)
- shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
+ register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
def cleanup_tmpdr():
diff --git a/modules/upscaler.py b/modules/upscaler.py
index c4e6e6bd..231680cb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,10 +53,10 @@ class Upscaler:
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
- def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
+ def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = img.width * scale
- dest_h = img.height * scale
+ dest_w = int(img.width * scale)
+ dest_h = int(img.height * scale)
for i in range(3):
shape = (img.width, img.height)
diff --git a/modules/xlmr.py b/modules/xlmr.py
new file mode 100644
index 00000000..beab3fdf
--- /dev/null
+++ b/modules/xlmr.py
@@ -0,0 +1,137 @@
+from transformers import BertPreTrainedModel,BertModel,BertConfig
+import torch.nn as nn
+import torch
+from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
+from transformers import XLMRobertaModel,XLMRobertaTokenizer
+from typing import Optional
+
+class BertSeriesConfig(BertConfig):
+ def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
+
+ super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+
+
+class BertSeriesModelWithTransformation(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ config_class = BertSeriesConfig
+
+ def __init__(self, config=None, **kargs):
+ # modify initialization for autoloading
+ if config is None:
+ config = XLMRobertaConfig()
+ config.attention_probs_dropout_prob= 0.1
+ config.bos_token_id=0
+ config.eos_token_id=2
+ config.hidden_act='gelu'
+ config.hidden_dropout_prob=0.1
+ config.hidden_size=1024
+ config.initializer_range=0.02
+ config.intermediate_size=4096
+ config.layer_norm_eps=1e-05
+ config.max_position_embeddings=514
+
+ config.num_attention_heads=16
+ config.num_hidden_layers=24
+ config.output_past=True
+ config.pad_token_id=1
+ config.position_embedding_type= "absolute"
+
+ config.type_vocab_size= 1
+ config.use_cache=True
+ config.vocab_size= 250002
+ config.project_dim = 768
+ config.learn_encoder = False
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size,config.project_dim)
+ self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
+ self.pooler = lambda x: x[:,0]
+ self.post_init()
+
+ def encode(self,c):
+ device = next(self.parameters()).device
+ text = self.tokenizer(c,
+ truncation=True,
+ max_length=77,
+ return_length=False,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
+ text["attention_mask"] = torch.tensor(
+ text['attention_mask']).to(device)
+ features = self(**text)
+ return features['projection_state']
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) :
+ r"""
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+
+ outputs = self.roberta(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=True,
+ return_dict=return_dict,
+ )
+
+ # last module outputs
+ sequence_output = outputs[0]
+
+
+ # project every module
+ sequence_output_ln = self.pre_LN(sequence_output)
+
+ # pooler
+ pooler_output = self.pooler(sequence_output_ln)
+ pooler_output = self.transformation(pooler_output)
+ projection_state = self.transformation(outputs.last_hidden_state)
+
+ return {
+ 'pooler_output':pooler_output,
+ 'last_hidden_state':outputs.last_hidden_state,
+ 'hidden_states':outputs.hidden_states,
+ 'attentions':outputs.attentions,
+ 'projection_state':projection_state,
+ 'sequence_out': sequence_output
+ }
+
+
+class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
+ base_model_prefix = 'roberta'
+ config_class= RobertaSeriesConfig \ No newline at end of file