Diffstat (limited to 'modules')
-rw-r--r--  modules/codeformer_model.py    6
-rw-r--r--  modules/extras.py            107
-rw-r--r--  modules/images.py             54
-rw-r--r--  modules/img2img.py            53
-rw-r--r--  modules/interrogate.py         2
-rw-r--r--  modules/processing.py         31
-rw-r--r--  modules/prompt_parser.py     130
-rw-r--r--  modules/scripts.py             7
-rw-r--r--  modules/sd_hijack.py           2
-rw-r--r--  modules/sd_samplers.py        51
-rw-r--r--  modules/shared.py             17
-rw-r--r--  modules/styles.py             96
-rw-r--r--  modules/txt2img.py             4
-rw-r--r--  modules/ui.py                132
14 files changed, 466 insertions, 226 deletions
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 21c704f7..8fbdea24 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -47,13 +47,11 @@ def setup_codeformer():
def __init__(self):
self.net = None
self.face_helper = None
- if shared.device.type == 'mps': # CodeFormer currently does not support mps backend
- shared.device_codeformer = torch.device('cpu')
def create_models(self):
if self.net is not None and self.face_helper is not None:
- self.net.to(shared.device)
+ self.net.to(devices.device_codeformer)
return self.net, self.face_helper
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
@@ -66,7 +64,7 @@ def setup_codeformer():
self.net = net
self.face_helper = face_helper
- self.net.to(shared.device)
+ self.net.to(devices.device_codeformer)
return net, face_helper
diff --git a/modules/extras.py b/modules/extras.py
index cb083544..ffae7d67 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,71 +7,91 @@ import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model
import piexif
+import piexif.helper
cached_images = {}
-def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
+def run_extras(image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
devices.torch_gc()
- existing_pnginfo = image.info or {}
+ imageArr = []
- image = image.convert("RGB")
- info = ""
+ if image_folder != None:
+ if image != None:
+ print("Batch detected and single image detected, please only use one of the two. Aborting.")
+ return None
+ #convert file to pillow image
+ for img in image_folder:
+ image = Image.fromarray(np.array(Image.open(img)))
+ imageArr.append(image)
+
+ elif image != None:
+ if image_folder != None:
+ print("Batch detected and single image detected, please only use one of the two. Aborting.")
+ return None
+ else:
+ imageArr.append(image)
outpath = opts.outdir_samples or opts.outdir_extras_samples
- if gfpgan_visibility > 0:
- restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
- res = Image.fromarray(restored_img)
+ for image in imageArr:
+ existing_pnginfo = image.info or {}
+
+ image = image.convert("RGB")
+ info = ""
+
+ if gfpgan_visibility > 0:
+ restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
+ res = Image.fromarray(restored_img)
- if gfpgan_visibility < 1.0:
- res = Image.blend(image, res, gfpgan_visibility)
+ if gfpgan_visibility < 1.0:
+ res = Image.blend(image, res, gfpgan_visibility)
- info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
- image = res
+ info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
+ image = res
- if codeformer_visibility > 0:
- restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
- res = Image.fromarray(restored_img)
+ if codeformer_visibility > 0:
+ restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
+ res = Image.fromarray(restored_img)
- if codeformer_visibility < 1.0:
- res = Image.blend(image, res, codeformer_visibility)
+ if codeformer_visibility < 1.0:
+ res = Image.blend(image, res, codeformer_visibility)
- info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility)}\n"
- image = res
+ info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility)}\n"
+ image = res
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+ if upscaling_resize != 1.0:
+ def upscale(image, scaler_index, resize):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.upscale(image, image.width * resize, image.height * resize)
- cached_images[key] = c
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.upscale(image, image.width * resize, image.height * resize)
+ cached_images[key] = c
- return c
+ return c
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+ res = upscale(image, extras_upscaler_1, upscaling_resize)
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
- image = res
+ image = res
- while len(cached_images) > 2:
- del cached_images[next(iter(cached_images.keys()))]
+ while len(cached_images) > 2:
+ del cached_images[next(iter(cached_images.keys()))]
- images.save_image(image, outpath, "", None, info=info, extension=opts.samples_format, short_filename=True, no_prompt=True, pnginfo_section_name="extras", existing_info=existing_pnginfo)
+ images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo)
- return image, plaintext_to_html(info), ''
+ return imageArr, plaintext_to_html(info), ''
def run_pnginfo(image):
@@ -80,7 +100,12 @@ def run_pnginfo(image):
if "exif" in image.info:
exif = piexif.load(image.info["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
- exif_comment = exif_comment.decode("utf8", 'ignore')
+ try:
+ exif_comment = piexif.helper.UserComment.load(exif_comment)
+ except ValueError:
+ exif_comment = exif_comment.decode('utf8', errors="ignore")
+
+
items['exif comment'] = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif']:
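
Aside: the EXIF change above switches run_pnginfo() from a raw utf-8 decode to piexif.helper.UserComment, which handles the encoding marker that the EXIF spec prepends to user comments. A minimal sketch of that round trip, using only the piexif API (not the repo's code; the comment text is just an example value):

    import piexif.helper

    raw = piexif.helper.UserComment.dump("Steps: 20, Sampler: Euler a", encoding="unicode")  # bytes with an encoding prefix
    try:
        text = piexif.helper.UserComment.load(raw)            # strips the prefix back off
    except ValueError:
        text = raw.decode("utf8", errors="ignore")            # fallback mirrors run_pnginfo()
    print(text)  # Steps: 20, Sampler: Euler a
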
diff --git a/modules/images.py b/modules/images.py
index 50b0e099..f37f5f08 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -13,7 +13,7 @@ import string
import modules.shared
from modules import sd_samplers, shared
-from modules.shared import opts
+from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -252,7 +252,7 @@ def sanitize_filename_part(text, replace_spaces=True):
if replace_spaces:
text = text.replace(' ', '_')
- return text.translate({ord(x): '' for x in invalid_filename_chars})[:128]
+ return text.translate({ord(x): '_' for x in invalid_filename_chars})[:128]
def apply_filename_pattern(x, p, seed, prompt):
@@ -277,13 +277,33 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[model_hash]", shared.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
+ if cmd_opts.hide_ui_dir_config:
+ x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
+
return x
+def get_next_sequence_number(path, basename):
+ """
+ Determines and returns the next sequence number to use when saving an image in the specified directory.
+
+ The sequence starts at 0.
+ """
+ result = -1
+ if basename != '':
+ basename = basename + "-"
-def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters', p=None, existing_info=None):
- # would be better to add this as an argument in future, but will do for now
- is_a_grid = basename != ""
+ prefix_length = len(basename)
+ for p in os.listdir(path):
+ if p.startswith(basename):
+ l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ try:
+ result = max(int(l[0]), result)
+ except ValueError:
+ pass
+ return result + 1
+
+def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None):
if short_filename or prompt is None or seed is None:
file_decoration = ""
elif opts.save_to_dirs:
@@ -307,7 +327,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else:
pnginfo = None
- save_to_dirs = (is_a_grid and opts.grid_save_to_dirs) or (not is_a_grid and opts.save_to_dirs)
+ save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
if save_to_dirs:
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
@@ -315,26 +335,30 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
os.makedirs(path, exist_ok=True)
- filecount = len([x for x in os.listdir(path) if os.path.splitext(x)[1] == '.' + extension])
+ basecount = get_next_sequence_number(path, basename)
fullfn = "a.png"
fullfn_without_extension = "a"
for i in range(500):
- fn = f"{filecount+i:05}" if basename == '' else f"{basename}-{filecount+i:04}"
+ fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
break
- if extension.lower() in ("jpg", "jpeg"):
- exif_bytes = piexif.dump({
+ def exif_bytes():
+ return piexif.dump({
"Exif": {
- piexif.ExifIFD.UserComment: info.encode("utf8"),
- }
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
+ },
})
+
+ if extension.lower() in ("jpg", "jpeg", "webp"):
+ image.save(fullfn, quality=opts.jpeg_quality, exif_bytes=exif_bytes())
else:
- exif_bytes = None
+ image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
- image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo, exif=exif_bytes)
+ if extension.lower() == "webp":
+ piexif.insert(exif_bytes, fullfn)
target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length
@@ -346,7 +370,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif oversize:
image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
- image.save(fullfn, quality=opts.jpeg_quality, exif=exif_bytes)
+ image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality, exif_bytes=exif_bytes())
if opts.save_txt and info is not None:
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
diff --git a/modules/img2img.py b/modules/img2img.py
index 70c99e33..2dcabc6b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -11,10 +11,9 @@ from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
-def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init_img_with_mask, init_mask, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
+def img2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_mask, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1
- is_loopback = mode == 2
- is_upscale = mode == 3
+ is_upscale = mode == 2
if is_inpaint:
if mask_mode == 0:
@@ -38,7 +37,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
- prompt_style=prompt_style,
+ styles=[prompt_style, prompt_style2],
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
@@ -61,46 +60,10 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
denoising_strength=denoising_strength,
inpaint_full_res=inpaint_full_res,
inpainting_mask_invert=inpainting_mask_invert,
- extra_generation_params={
- "Denoising strength change factor": (denoising_strength_change_factor if is_loopback else None)
- }
)
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
- if is_loopback:
- output_images, info = None, None
- history = []
- initial_seed = None
- initial_info = None
-
- state.job_count = n_iter
-
- for i in range(n_iter):
- p.n_iter = 1
- p.batch_size = 1
- p.do_not_save_grid = True
-
- state.job = f"Batch {i + 1} out of {n_iter}"
- processed = process_images(p)
-
- if initial_seed is None:
- initial_seed = processed.seed
- initial_info = processed.info
-
- init_img = processed.images[0]
-
- p.init_images = [init_img]
- p.seed = processed.seed + 1
- p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
- history.append(processed.images[0])
-
- grid = images.image_grid(history, batch_size, rows=1)
-
- images.save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, p=p)
-
- processed = Processed(p, history, initial_seed, initial_info)
-
- elif is_upscale:
+ if is_upscale:
initial_info = None
processing.fix_seed(p)
@@ -113,6 +76,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
+ batch_size = p.batch_size
upscale_count = p.n_iter
p.n_iter = 1
p.do_not_save_grid = True
@@ -124,7 +88,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
for tiledata in row:
work.append(tiledata[2])
- batch_count = math.ceil(len(work) / p.batch_size)
+ batch_count = math.ceil(len(work) / batch_size)
state.job_count = batch_count * upscale_count
print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
@@ -136,9 +100,10 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
work_results = []
for i in range(batch_count):
- p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
+ p.batch_size = batch_size
+ p.init_images = work[i*batch_size:(i+1)*batch_size]
- state.job = f"Batch {i + 1} out of {state.job_count}"
+ state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
processed = process_images(p)
if initial_info is None:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 06862fcc..f62a4745 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -98,7 +98,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
+ text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
diff --git a/modules/processing.py b/modules/processing.py
index 5abdfd7c..798313ee 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -12,7 +12,7 @@ import cv2
from skimage import exposure
import modules.sd_hijack
-from modules import devices
+from modules import devices, prompt_parser
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -46,14 +46,14 @@ def apply_color_correction(correction, image):
class StableDiffusionProcessing:
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
- self.prompt_style: str = prompt_style
+ self.styles: str = styles
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
@@ -194,9 +194,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
- comments = []
+ comments = {}
- modules.styles.apply_style(p, shared.prompt_styles[p.prompt_style])
+ shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
all_prompts = p.prompt
@@ -261,11 +261,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
- uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
- c = p.sd_model.get_learned_conditioning(prompts)
+ #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
+ #c = p.sd_model.get_learned_conditioning(prompts)
+ uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
+ c = prompt_parser.get_learned_conditioning(prompts, p.steps)
if len(model_hijack.comments) > 0:
- comments += model_hijack.comments
+ for comment in model_hijack.comments:
+ comments[comment] = 1
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
@@ -326,12 +329,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
state.nextjob()
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
- if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
- return_grid = opts.return_grid
-
+ if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
grid = images.image_grid(output_images, p.batch_size)
- if return_grid:
+ if opts.return_grid:
output_images.insert(0, grid)
if opts.grid_save:
@@ -458,7 +459,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
- self.color_corrections = []
+ add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
+ if add_color_corrections:
+ self.color_corrections = []
imgs = []
for img in self.init_images:
image = img.convert("RGB")
@@ -480,7 +483,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_fill != 1:
image = fill(image, latent_mask)
- if opts.img2img_color_correction:
+ if add_color_corrections:
self.color_corrections.append(setup_color_correction(image))
image = np.array(image).astype(np.float32) / 255.0
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
new file mode 100644
index 00000000..0835f692
--- /dev/null
+++ b/modules/prompt_parser.py
@@ -0,0 +1,130 @@
+import re
+from collections import namedtuple
+import torch
+
+import modules.shared as shared
+
+re_prompt = re.compile(r'''
+(.*?)
+\[
+ ([^]:]+):
+ (?:([^]:]*):)?
+ ([0-9]*\.?[0-9]+)
+]
+|
+(.+)
+''', re.X)
+
+# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
+# will be represented with prompt_schedule like this (assuming steps=100):
+# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
+# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
+# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
+# [75, 'fantasy landscape with a lake and an oak in background masterful']
+# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
+
+
+def get_learned_conditioning_prompt_schedules(prompts, steps):
+ res = []
+ cache = {}
+
+ for prompt in prompts:
+ prompt_schedule: list[list[str | int]] = [[steps, ""]]
+
+ cached = cache.get(prompt, None)
+ if cached is not None:
+ res.append(cached)
+ continue
+
+ for m in re_prompt.finditer(prompt):
+ plaintext = m.group(1) if m.group(5) is None else m.group(5)
+ concept_from = m.group(2)
+ concept_to = m.group(3)
+ if concept_to is None:
+ concept_to = concept_from
+ concept_from = ""
+ swap_position = float(m.group(4)) if m.group(4) is not None else None
+
+ if swap_position is not None:
+ if swap_position < 1:
+ swap_position = swap_position * steps
+ swap_position = int(min(swap_position, steps))
+
+ swap_index = None
+ found_exact_index = False
+ for i in range(len(prompt_schedule)):
+ end_step = prompt_schedule[i][0]
+ prompt_schedule[i][1] += plaintext
+
+ if swap_position is not None and swap_index is None:
+ if swap_position == end_step:
+ swap_index = i
+ found_exact_index = True
+
+ if swap_position < end_step:
+ swap_index = i
+
+ if swap_index is not None:
+ if not found_exact_index:
+ prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
+
+ for i in range(len(prompt_schedule)):
+ end_step = prompt_schedule[i][0]
+ must_replace = swap_position < end_step
+
+ prompt_schedule[i][1] += concept_to if must_replace else concept_from
+
+ res.append(prompt_schedule)
+ cache[prompt] = prompt_schedule
+ #for t in prompt_schedule:
+ # print(t)
+
+ return res
+
+
+ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
+ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
+
+
+def get_learned_conditioning(prompts, steps):
+
+ res = []
+
+ prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
+ cache = {}
+
+ for prompt, prompt_schedule in zip(prompts, prompt_schedules):
+
+ cached = cache.get(prompt, None)
+ if cached is not None:
+ res.append(cached)
+ continue
+
+ texts = [x[1] for x in prompt_schedule]
+ conds = shared.sd_model.get_learned_conditioning(texts)
+
+ cond_schedule = []
+ for i, (end_at_step, text) in enumerate(prompt_schedule):
+ cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+
+ cache[prompt] = cond_schedule
+ res.append(cond_schedule)
+
+ return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+
+
+def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
+ res = torch.zeros(c.shape)
+ for i, cond_schedule in enumerate(c.schedules):
+ target_index = 0
+ for curret_index, (end_at, cond) in enumerate(cond_schedule):
+ if current_step <= end_at:
+ target_index = curret_index
+ break
+ res[i] = cond_schedule[target_index].cond
+
+ return res.to(shared.device)
+
+
+
+#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
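
Aside: a hedged usage sketch for the new [from:to:when] scheduling syntax. Assuming the module is importable inside the webui environment (it pulls in modules.shared at import time), the schedule for a prompt can be inspected directly; the expected output restates the example from the file's own comment block (steps=100):

    from modules import prompt_parser

    prompt = ("fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75]"
              "[ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]")
    schedules = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], steps=100)
    for end_step, text in schedules[0]:
        print(end_step, text)
    # 25  fantasy landscape with a mountain and an oak in foreground shoddy
    # 50  fantasy landscape with a lake and an oak in foreground in background shoddy
    # 60  fantasy landscape with a lake and an oak in foreground in background masterful
    # 75  fantasy landscape with a lake and an oak in background masterful
    # 100 fantasy landscape with a lake and a christmas tree in background masterful
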
diff --git a/modules/scripts.py b/modules/scripts.py
index 74591bab..9cc5a185 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -42,10 +42,10 @@ def load_scripts(basedir):
if not os.path.isfile(path):
continue
- with open(path, "r", encoding="utf8") as file:
- text = file.read()
-
try:
+ with open(path, "r", encoding="utf8") as file:
+ text = file.read()
+
from types import ModuleType
compiled = compile(text, path, 'exec')
module = ModuleType(filename)
@@ -92,6 +92,7 @@ class ScriptRunner:
for script in self.scripts:
script.args_from = len(inputs)
+ script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ec7d14cb..65414518 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -57,7 +57,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d478c5bc..02ffce0e 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,6 +7,7 @@ from PIL import Image
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
+from modules import prompt_parser
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -53,20 +54,6 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded)
-def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
- if sampler_wrapper.mask is not None:
- img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
- x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
-
- res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
-
- if sampler_wrapper.mask is not None:
- store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
- else:
- store_latent(res[1])
-
- return res
-
def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
@@ -94,10 +81,29 @@ class VanillaStableDiffusionSampler:
self.nmask = None
self.init_latent = None
self.sampler_noises = None
+ self.step = 0
def number_of_needed_noises(self, p):
return 0
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ if self.mask is not None:
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+ x_dec = img_orig * self.mask + self.nmask * x_dec
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ if self.mask is not None:
+ store_latent(self.init_latent * self.mask + self.nmask * res[1])
+ else:
+ store_latent(res[1])
+
+ self.step += 1
+ return res
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
@@ -109,10 +115,11 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
- self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+ self.sampler.p_sample_ddim = self.p_sample_ddim_hook
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
+ self.step = 0
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
@@ -121,10 +128,11 @@ class VanillaStableDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning):
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
# existing code fails with cetin step counts, like 9
try:
@@ -142,8 +150,12 @@ class CFGDenoiser(torch.nn.Module):
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
@@ -158,6 +170,8 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.step += 1
+
return denoised
@@ -191,7 +205,7 @@ class TorchHijack:
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
- self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
+ self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
@@ -228,6 +242,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
+ self.model_wrap.step = 0
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
@@ -241,6 +256,8 @@ class KDiffusionSampler:
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
+ self.model_wrap_cfg.step = 0
+
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
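
Aside: both samplers now keep a step counter and call prompt_parser.reconstruct_cond_batch() each step, so the conditioning can change mid-sampling. A toy illustration (plain Python, no tensors) of the selection rule used above: the first schedule entry whose end_at_step has not yet passed wins, falling back to the first entry otherwise:

    schedule = [(25, "cond_A"), (50, "cond_B"), (100, "cond_C")]  # (end_at_step, cond)

    def pick(schedule, current_step):
        target = 0
        for i, (end_at, cond) in enumerate(schedule):
            if current_step <= end_at:
                target = i
                break
        return schedule[target][1]

    assert pick(schedule, 10) == "cond_A"
    assert pick(schedule, 30) == "cond_B"
    assert pick(schedule, 99) == "cond_C"
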
diff --git a/modules/shared.py b/modules/shared.py
index ac870ec4..fa6a0e99 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -23,7 +23,7 @@ parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_f
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
-parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
+parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
@@ -45,6 +45,7 @@ parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
+parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
cmd_opts = parser.parse_args()
@@ -79,8 +80,8 @@ state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
-styles_filename = os.path.join(script_path, 'styles.csv')
-prompt_styles = modules.styles.load_styles(styles_filename)
+styles_filename = cmd_opts.styles_file
+prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
@@ -109,10 +110,11 @@ class Options:
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
- "samples_save": OptionInfo(True, "Save indiviual samples"),
+ "samples_save": OptionInfo(True, "Always save all generated images"),
+ "save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
"samples_format": OptionInfo('png', 'File format for individual samples'),
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
- "grid_save": OptionInfo(True, "Save image grids"),
+ "grid_save": OptionInfo(True, "Always save all generated image grids"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
@@ -123,6 +125,7 @@ class Options:
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"add_model_hash_to_info": OptionInfo(False, "Add model hash to generation information"),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+ "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"font": OptionInfo("", "Font for image grids that have text"),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
@@ -141,8 +144,8 @@ class Options:
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
- "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
- "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
+ "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
}
diff --git a/modules/styles.py b/modules/styles.py
index bc7f070f..eeedcd08 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -20,49 +20,67 @@ class PromptStyle(typing.NamedTuple):
negative_prompt: str
-def load_styles(path: str) -> dict[str, PromptStyle]:
- styles = {"None": PromptStyle("None", "", "")}
+def merge_prompts(style_prompt: str, prompt: str) -> str:
+ if "{prompt}" in style_prompt:
+ res = style_prompt.replace("{prompt}", prompt)
+ else:
+ parts = filter(None, (prompt.strip(), style_prompt.strip()))
+ res = ", ".join(parts)
- if os.path.exists(path):
- with open(path, "r", encoding="utf8", newline='') as file:
- reader = csv.DictReader(file)
- for row in reader:
- # Support loading old CSV format with "name, text"-columns
- prompt = row["prompt"] if "prompt" in row else row["text"]
- negative_prompt = row.get("negative_prompt", "")
- styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+ return res
- return styles
+def apply_styles_to_prompt(prompt, styles):
+ for style in styles:
+ prompt = merge_prompts(style, prompt)
-def merge_prompts(style_prompt: str, prompt: str) -> str:
- parts = filter(None, (prompt.strip(), style_prompt.strip()))
- return ", ".join(parts)
+ return prompt
-def apply_style(processing: StableDiffusionProcessing, style: PromptStyle) -> None:
- if isinstance(processing.prompt, list):
- processing.prompt = [merge_prompts(style.prompt, p) for p in processing.prompt]
- else:
- processing.prompt = merge_prompts(style.prompt, processing.prompt)
+class StyleDatabase:
+ def __init__(self, path: str):
+ self.no_style = PromptStyle("None", "", "")
+ self.styles = {"None": self.no_style}
- if isinstance(processing.negative_prompt, list):
- processing.negative_prompt = [merge_prompts(style.negative_prompt, p) for p in processing.negative_prompt]
- else:
- processing.negative_prompt = merge_prompts(style.negative_prompt, processing.negative_prompt)
-
-
-def save_styles(path: str, styles: abc.Iterable[PromptStyle]) -> None:
- # Write to temporary file first, so we don't nuke the file if something goes wrong
- fd, temp_path = tempfile.mkstemp(".csv")
- with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
- # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
- writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
- writer.writeheader()
- writer.writerows(style._asdict() for style in styles)
-
- # Always keep a backup file around
- if os.path.exists(path):
- shutil.move(path, path + ".bak")
- shutil.move(temp_path, path)
+ if not os.path.exists(path):
+ return
+
+ with open(path, "r", encoding="utf8", newline='') as file:
+ reader = csv.DictReader(file)
+ for row in reader:
+ # Support loading old CSV format with "name, text"-columns
+ prompt = row["prompt"] if "prompt" in row else row["text"]
+ negative_prompt = row.get("negative_prompt", "")
+ self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
+
+ def apply_styles_to_prompt(self, prompt, styles):
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
+
+ def apply_negative_styles_to_prompt(self, prompt, styles):
+ return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
+
+ def apply_styles(self, p: StableDiffusionProcessing) -> None:
+ if isinstance(p.prompt, list):
+ p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
+ else:
+ p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)
+
+ if isinstance(p.negative_prompt, list):
+ p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
+ else:
+ p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
+
+ def save_styles(self, path: str) -> None:
+ # Write to temporary file first, so we don't nuke the file if something goes wrong
+ fd, temp_path = tempfile.mkstemp(".csv")
+ with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
+ # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
+ # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+ writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
+ writer.writeheader()
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+
+ # Always keep a backup file around
+ if os.path.exists(path):
+ shutil.move(path, path + ".bak")
+ shutil.move(temp_path, path)
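
Aside: merge_prompts() gains a {prompt} placeholder so a style can wrap the user prompt instead of only being appended after a comma. A self-contained sketch copying the function above, with made-up style and prompt strings for illustration:

    def merge_prompts(style_prompt: str, prompt: str) -> str:
        if "{prompt}" in style_prompt:
            return style_prompt.replace("{prompt}", prompt)
        parts = filter(None, (prompt.strip(), style_prompt.strip()))
        return ", ".join(parts)

    print(merge_prompts("oil painting of {prompt}, trending on artstation", "a castle"))
    # oil painting of a castle, trending on artstation
    print(merge_prompts("highly detailed", "a castle"))
    # a castle, highly detailed
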
diff --git a/modules/txt2img.py b/modules/txt2img.py
index d60febfc..30d89849 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -6,13 +6,13 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
- prompt_style=prompt_style,
+ styles=[prompt_style, prompt_style2],
negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
diff --git a/modules/ui.py b/modules/ui.py
index d1aa7793..b6d5dcd8 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -80,7 +80,7 @@ def send_gradio_gallery_to_image(x):
return image_from_url_text(x[0])
-def save_files(js_data, images):
+def save_files(js_data, images, index):
import csv
os.makedirs(opts.outdir_save, exist_ok=True)
@@ -88,6 +88,10 @@ def save_files(js_data, images):
filenames = []
data = json.loads(js_data)
+
+ if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+ images = [images[index]]
+ data["seed"] += (index - 1 if opts.return_grid else index)
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0
@@ -233,13 +237,20 @@ def add_style(name: str, prompt: str, negative_prompt: str):
return [gr_show(), gr_show()]
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles[style.name] = style
+ shared.prompt_styles.styles[style.name] = style
# Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
# reserialize all styles every time we save them
- modules.styles.save_styles(shared.styles_filename, shared.prompt_styles.values())
+ shared.prompt_styles.save_styles(shared.styles_filename)
- update = {"visible": True, "choices": list(shared.prompt_styles), "__type__": "update"}
- return [update, update]
+ update = {"visible": True, "choices": list(shared.prompt_styles.styles), "__type__": "update"}
+ return [update, update, update, update]
+
+
+def apply_styles(prompt, prompt_neg, style1_name, style2_name):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name])
+ prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name])
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")]
def interrogate(image):
@@ -247,15 +258,46 @@ def interrogate(image):
return gr_show(True) if prompt is None else prompt
+
+def create_toprow(is_img2img):
+ with gr.Row(elem_id="toprow"):
+ with gr.Column(scale=4):
+ with gr.Row():
+ with gr.Column(scale=8):
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
+ roll = gr.Button('Roll', elem_id="roll", visible=len(shared.artist_db.artists) > 0)
+
+ with gr.Column(scale=1, elem_id="style_pos_col"):
+ prompt_style = gr.Dropdown(label="Style 1", elem_id="style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+
+ with gr.Row():
+ with gr.Column(scale=8):
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
+
+ with gr.Column(scale=1, elem_id="style_neg_col"):
+ prompt_style2 = gr.Dropdown(label="Style 2", elem_id="style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+
+ with gr.Column(scale=1):
+ with gr.Row():
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
+
+ with gr.Row():
+ if is_img2img:
+ interrogate = gr.Button('Interrogate', elem_id="interrogate")
+ else:
+ interrogate = None
+ prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
+ save_style = gr.Button('Create style', elem_id="style_create")
+
+ check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, check_progress
+
+
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- with gr.Row(elem_id="toprow"):
- txt2img_prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1)
- txt2img_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
- txt2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
- roll = gr.Button('Roll', elem_id="txt2img_roll", visible=len(shared.artist_db.artists) > 0)
- submit = gr.Button('Generate', elem_id="txt2img_generate", variant='primary')
- check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, check_progress = create_toprow(is_img2img=False)
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -286,7 +328,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery').style(grid=4)
-
with gr.Group():
with gr.Row():
save = gr.Button('Save')
@@ -294,7 +335,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras')
interrupt = gr.Button('Interrupt')
- txt2img_save_style = gr.Button('Save prompt as style')
progressbar = gr.HTML(elem_id="progressbar")
@@ -302,7 +342,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
html_info = gr.HTML()
generation_info = gr.Textbox(visible=False)
-
txt2img_args = dict(
fn=txt2img,
_js="submit",
@@ -310,6 +349,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
txt2img_prompt,
txt2img_negative_prompt,
txt2img_prompt_style,
+ txt2img_prompt_style2,
steps,
sampler_index,
restore_faces,
@@ -339,7 +379,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[progressbar, txt2img_preview, txt2img_preview],
)
-
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
@@ -348,9 +387,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
save.click(
fn=wrap_gradio_call(save_files),
+ _js = "(x, y, z) => [x, y, selected_gallery_index()]",
inputs=[
generation_info,
txt2img_gallery,
+ html_info
],
outputs=[
html_info,
@@ -370,18 +411,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- with gr.Row(elem_id="toprow"):
- img2img_prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1)
- img2img_negative_prompt = gr.Textbox(label="Negative prompt", elem_id="img2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1)
- img2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1)
- img2img_interrogate = gr.Button('Interrogate', elem_id="img2img_interrogate", variant='primary')
- submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary')
- check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, check_progress = create_toprow(is_img2img=True)
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Group():
- switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
+ switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False, image_mode="RGBA")
init_mask = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False)
@@ -415,7 +450,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Group():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
- denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, visible=False)
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
@@ -449,8 +483,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
def apply_mode(mode, uploadmask):
is_classic = mode == 0
is_inpaint = mode == 1
- is_loopback = mode == 2
- is_upscale = mode == 3
+ is_upscale = mode == 2
return {
init_img: gr_show(not is_inpaint or (is_inpaint and uploadmask == 1)),
@@ -460,12 +493,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
mask_mode: gr_show(is_inpaint),
mask_blur: gr_show(is_inpaint),
inpainting_fill: gr_show(is_inpaint),
- batch_size: gr_show(not is_loopback),
sd_upscale_upscaler_name: gr_show(is_upscale),
sd_upscale_overlap: gr_show(is_upscale),
inpaint_full_res: gr_show(is_inpaint),
inpainting_mask_invert: gr_show(is_inpaint),
- denoising_strength_change_factor: gr_show(is_loopback),
img2img_interrogate: gr_show(not is_inpaint),
}
@@ -480,12 +511,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
mask_mode,
mask_blur,
inpainting_fill,
- batch_size,
sd_upscale_upscaler_name,
sd_upscale_overlap,
inpaint_full_res,
inpainting_mask_invert,
- denoising_strength_change_factor,
img2img_interrogate,
]
)
@@ -511,6 +540,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
img2img_prompt,
img2img_negative_prompt,
img2img_prompt_style,
+ img2img_prompt_style2,
init_img,
init_img_with_mask,
init_mask,
@@ -526,7 +556,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
batch_size,
cfg_scale,
denoising_strength,
- denoising_strength_change_factor,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
height,
@@ -568,9 +597,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
save.click(
fn=wrap_gradio_call(save_files),
+ _js = "(x, y, z) => [x, y, selected_gallery_index()]",
inputs=[
generation_info,
img2img_gallery,
+ html_info
],
outputs=[
html_info,
@@ -579,22 +610,46 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
]
)
+ roll.click(
+ fn=roll_artist,
+ inputs=[
+ img2img_prompt,
+ ],
+ outputs=[
+ img2img_prompt,
+ ]
+ )
+
+ prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
+ style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
+
dummy_component = gr.Label(visible=False)
- for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]):
+ for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click(
fn=add_style,
_js="ask_for_style_name",
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_style, img2img_prompt_style],
+ outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
+ )
+
+ for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
+ button.click(
+ fn=apply_styles,
+ inputs=[prompt, negative_prompt, style1, style2],
+ outputs=[prompt, negative_prompt, style1, style2],
)
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
- with gr.Group():
- image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+ with gr.Tabs():
+ with gr.TabItem('Single Image'):
+ image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
+
+ with gr.TabItem('Batch Process'):
+ image_batch = gr.File(label="Batch Process", file_count="multiple", source="upload", interactive=True, type="file")
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
@@ -615,7 +670,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Column(variant='panel'):
- result_image = gr.Image(label="Result")
+ result_images = gr.Gallery(label="Result")
html_info_x = gr.HTML()
html_info = gr.HTML()
@@ -623,6 +678,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
fn=run_extras,
inputs=[
image,
+ image_batch,
gfpgan_visibility,
codeformer_visibility,
codeformer_weight,
@@ -632,7 +688,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
extras_upscaler_2_visibility,
],
outputs=[
- result_image,
+ result_images,
html_info_x,
html_info,
]