aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/images.py191
-rw-r--r--modules/img2img.py50
-rw-r--r--modules/processing.py113
-rw-r--r--modules/sd_samplers_cfg_denoiser.py24
-rw-r--r--modules/sd_samplers_common.py1
-rw-r--r--modules/soft_inpainting.py308
-rw-r--r--modules/ui.py7
7 files changed, 666 insertions, 28 deletions
diff --git a/modules/images.py b/modules/images.py
index daf4eebe..94953498 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -791,3 +791,194 @@ def flatten(img, bgcolor):
img = background
return img.convert('RGB')
+
+
+def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
+ """
+ Generalization convolution filter capable of applying
+ weighted mean, median, maximum, and minimum filters
+ parametrically using an arbitrary kernel.
+
+ Args:
+ img (nparray):
+ The image, a 2-D array of floats, to which the filter is being applied.
+ kernel (nparray):
+ The kernel, a 2-D array of floats.
+ kernel_center (nparray):
+ The kernel center coordinate, a 1-D array with two elements.
+ percentile_min (float):
+ The lower bound of the histogram window used by the filter,
+ from 0 to 1.
+ percentile_max (float):
+ The upper bound of the histogram window used by the filter,
+ from 0 to 1.
+ min_width (float):
+ The minimum size of the histogram window bounds, in weight units.
+ Must be greater than 0.
+
+ Returns:
+ (nparray): A filtered copy of the input image "img", a 2-D array of floats.
+ """
+
+ # Converts an index tuple into a vector.
+ def vec(x):
+ return np.array(x)
+
+ kernel_min = -kernel_center
+ kernel_max = vec(kernel.shape) - kernel_center
+
+ def weighted_histogram_filter_single(idx):
+ idx = vec(idx)
+ min_index = np.maximum(0, idx + kernel_min)
+ max_index = np.minimum(vec(img.shape), idx + kernel_max)
+ window_shape = max_index - min_index
+
+ class WeightedElement:
+ """
+ An element of the histogram, its weight
+ and bounds.
+ """
+ def __init__(self, value, weight):
+ self.value: float = value
+ self.weight: float = weight
+ self.window_min: float = 0.0
+ self.window_max: float = 1.0
+
+ # Collect the values in the image as WeightedElements,
+ # weighted by their corresponding kernel values.
+ values = []
+ for window_tup in np.ndindex(tuple(window_shape)):
+ window_index = vec(window_tup)
+ image_index = window_index + min_index
+ centered_kernel_index = image_index - idx
+ kernel_index = centered_kernel_index + kernel_center
+ element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
+ values.append(element)
+
+ def sort_key(x: WeightedElement):
+ return x.value
+
+ values.sort(key=sort_key)
+
+ # Calculate the height of the stack (sum)
+ # and each sample's range they occupy in the stack
+ sum = 0
+ for i in range(len(values)):
+ values[i].window_min = sum
+ sum += values[i].weight
+ values[i].window_max = sum
+
+ # Calculate what range of this stack ("window")
+ # we want to get the weighted average across.
+ window_min = sum * percentile_min
+ window_max = sum * percentile_max
+ window_width = window_max - window_min
+
+ # Ensure the window is within the stack and at least a certain size.
+ if window_width < min_width:
+ window_center = (window_min + window_max) / 2
+ window_min = window_center - min_width / 2
+ window_max = window_center + min_width / 2
+
+ if window_max > sum:
+ window_max = sum
+ window_min = sum - min_width
+
+ if window_min < 0:
+ window_min = 0
+ window_max = min_width
+
+ value = 0
+ value_weight = 0
+
+ # Get the weighted average of all the samples
+ # that overlap with the window, weighted
+ # by the size of their overlap.
+ for i in range(len(values)):
+ if window_min >= values[i].window_max:
+ continue
+ if window_max <= values[i].window_min:
+ break
+
+ s = max(window_min, values[i].window_min)
+ e = min(window_max, values[i].window_max)
+ w = e - s
+
+ value += values[i].value * w
+ value_weight += w
+
+ return value / value_weight if value_weight != 0 else 0
+
+ img_out = img.copy()
+
+ # Apply the kernel operation over each pixel.
+ for index in np.ndindex(img.shape):
+ img_out[index] = weighted_histogram_filter_single(index)
+
+ return img_out
+
+def smoothstep(x):
+ """
+ The smoothstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * (3 - 2 * x)
+
+def smootherstep(x):
+ """
+ The smootherstep function, input should be clamped to 0-1 range.
+ Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
+ """
+ return x * x * x * (x * (6 * x - 15) + 10)
+
+
+def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
+ """
+ Creates a Gaussian kernel with thresholded edges.
+
+ Args:
+ stddev_radius (float):
+ Standard deviation of the gaussian kernel, in pixels.
+ max_radius (int):
+ The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
+ The kernel is thresholded so that any values one pixel beyond this radius
+ is weighted at 0.
+
+ Returns:
+ (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
+ """
+ # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
+ def gaussian(sqr_mag):
+ return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
+
+ # Helper function for converting a tuple to an array.
+ def vec(x):
+ return np.array(x)
+
+ """
+ Since a gaussian is unbounded, we need to limit ourselves
+ to a finite range.
+ We taper the ends off at the end of that range so they equal zero
+ while preserving the maximum value of 1 at the mean.
+ """
+ zero_radius = max_radius + 1.0
+ gauss_zero = gaussian(zero_radius * zero_radius)
+ gauss_kernel_scale = 1 / (1 - gauss_zero)
+
+ def gaussian_kernel_func(coordinate):
+ x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
+ x = gaussian(x)
+ x -= gauss_zero
+ x *= gauss_kernel_scale
+ x = max(0.0, x)
+ return x
+
+ size = max_radius * 2 + 1
+ kernel_center = max_radius
+ kernel = np.zeros((size, size))
+
+ for index in np.ndindex(kernel.shape):
+ kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)
+
+ return kernel, kernel_center
+
diff --git a/modules/img2img.py b/modules/img2img.py
index c583290a..3aa8a9ce 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -15,6 +15,7 @@ import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.scripts
+import modules.soft_inpainting as si
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
@@ -146,7 +147,48 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
+def img2img(id_task: str,
+ mode: int,
+ prompt: str,
+ negative_prompt: str,
+ prompt_styles,
+ init_img,
+ sketch,
+ init_img_with_mask,
+ inpaint_color_sketch,
+ inpaint_color_sketch_orig,
+ init_img_inpaint,
+ init_mask_inpaint,
+ steps: int,
+ sampler_name: str,
+ mask_blur: int,
+ mask_alpha: float,
+ mask_blend_enabled: bool,
+ mask_blend_power: float,
+ mask_blend_scale: float,
+ inpaint_detail_preservation: float,
+ inpainting_fill: int,
+ n_iter: int,
+ batch_size: int,
+ cfg_scale: float,
+ image_cfg_scale: float,
+ denoising_strength: float,
+ selected_scale_tab: int,
+ height: int,
+ width: int,
+ scale_by: float,
+ resize_mode: int,
+ inpaint_full_res: bool,
+ inpaint_full_res_padding: int,
+ inpainting_mask_invert: int,
+ img2img_batch_input_dir: str,
+ img2img_batch_output_dir: str,
+ img2img_batch_inpaint_mask_dir: str,
+ override_settings_texts,
+ img2img_batch_use_png_info: bool,
+ img2img_batch_png_info_props: list,
+ img2img_batch_png_info_dir: str,
+ request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -187,6 +229,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \
+ if mask_blend_enabled else None
+
p = StableDiffusionProcessingImg2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
@@ -204,6 +249,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
init_images=[image],
mask=mask,
mask_blur=mask_blur,
+ soft_inpainting=soft_inpainting,
inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
@@ -224,6 +270,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
+ if soft_inpainting is not None:
+ soft_inpainting.add_generation_params(p.extra_generation_params)
with closing(p):
if is_batch:
diff --git a/modules/processing.py b/modules/processing.py
index 6f01c95f..7d46949f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,6 +30,7 @@ import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+import modules.soft_inpainting as si
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
@@ -62,6 +63,16 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
+def uncrop(image, dest_size, paste_loc):
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', dest_size)
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ return image
+
+
def apply_overlay(image, paste_loc, index, overlays):
if overlays is None or index >= len(overlays):
return image
@@ -69,11 +80,7 @@ def apply_overlay(image, paste_loc, index, overlays):
overlay = overlays[index]
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -81,9 +88,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ if round:
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
@@ -140,6 +150,7 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
+ masks_for_overlay: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = None
@@ -308,7 +319,7 @@ class StableDiffusionProcessing:
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -320,8 +331,10 @@ class StableDiffusionProcessing:
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if round_image_mask:
+ # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+
else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +358,7 @@ class StableDiffusionProcessing:
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,7 +370,7 @@ class StableDiffusionProcessing:
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
@@ -869,9 +882,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
+ # todo: generate adaptive masks based on pixel differences.
+ if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
+ si.apply_masks(soft_inpainting=p.soft_inpainting,
+ nmask=p.nmask,
+ overlay_images=p.overlay_images,
+ masks_for_overlay=p.masks_for_overlay,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+
+ # Generate the mask(s) based on similarity between the original and denoised latent vectors
+ if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
+ si.apply_adaptive_masks(latent_orig=p.init_latent,
+ latent_processed=samples_ddim,
+ overlay_images=p.overlay_images,
+ masks_for_overlay=p.masks_for_overlay,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
+
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -928,6 +961,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
+
+ if p.paste_to is not None:
+ original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to)
+
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if save_samples:
@@ -938,16 +980,24 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
+
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+ mask_for_overlay = p.mask_for_overlay
+ elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]:
+ mask_for_overlay = p.masks_for_overlay[i]
+ else:
+ mask_for_overlay = None
+
+ if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask:
- image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask = mask_for_overlay.convert('RGB')
if save_samples and opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.return_mask:
output_images.append(image_mask)
if opts.return_mask_composite or opts.save_mask_composite:
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite:
@@ -1351,6 +1401,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
+ soft_inpainting: si.SoftInpaintingParameters = si.default
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
@@ -1396,7 +1447,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
- image_mask = create_binary_mask(image_mask)
+ image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None))
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1414,7 +1465,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
- self.mask_for_overlay = image_mask
+ self.mask_for_overlay = image_mask if self.soft_inpainting is None else None
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1425,10 +1476,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
- np_mask = np.array(image_mask)
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
+ if self.soft_inpainting is None:
+ np_mask = np.array(image_mask)
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+ self.mask_for_overlay = Image.fromarray(np_mask)
+
+ self.masks_for_overlay = [] if self.soft_inpainting is not None else None
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1450,10 +1504,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+ if self.soft_inpainting is not None:
+ # We apply the masks AFTER to adjust mask based on changed content.
+ self.overlay_images.append(image.convert('RGBA'))
+ self.masks_for_overlay.append(image_mask)
+ else:
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
- self.overlay_images.append(image_masked.convert('RGBA'))
+ self.overlay_images.append(image_masked.convert('RGBA'))
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
@@ -1477,6 +1536,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+ if self.masks_for_overlay is not None:
+ self.masks_for_overlay = self.masks_for_overlay * self.batch_size
+
if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size
@@ -1503,7 +1565,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- latmask = np.around(latmask)
+ if self.soft_inpainting is None:
+ latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
@@ -1515,7 +1578,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
@@ -1526,7 +1589,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
- if self.mask is not None:
+ if self.mask is not None and self.soft_inpainting is None:
samples = samples * self.nmask + self.init_latent * self.mask
del x
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..a700e692 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -6,6 +6,7 @@ import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+import modules.soft_inpainting as si
def catenate_conds(conds):
@@ -43,6 +44,7 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.mask = None
self.nmask = None
+ self.soft_inpainting: si.SoftInpaintingParameters = None
self.init_latent = None
self.steps = None
"""number of steps as specified by user in UI"""
@@ -56,6 +58,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -89,6 +94,7 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -105,8 +111,15 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ if self.soft_inpainting is None:
+ x = self.init_latent * self.mask + self.nmask * x
+ else:
+ x = si.latent_blend(self.soft_inpainting,
+ self.init_latent,
+ x,
+ si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +220,15 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ if self.soft_inpainting is None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+ else:
+ denoised = si.latent_blend(self.soft_inpainting,
+ self.init_latent,
+ denoised,
+ si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 58efcad2..9682bee3 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -277,6 +277,7 @@ class Sampler:
self.model_wrap_cfg.p = p
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py
new file mode 100644
index 00000000..b36ac8fa
--- /dev/null
+++ b/modules/soft_inpainting.py
@@ -0,0 +1,308 @@
+class SoftInpaintingSettings:
+ def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation):
+ self.mask_blend_power = mask_blend_power
+ self.mask_blend_scale = mask_blend_scale
+ self.inpaint_detail_preservation = inpaint_detail_preservation
+
+ def add_generation_params(self, dest):
+ dest[enabled_gen_param_label] = True
+ dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
+ dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
+ dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
+
+
+# ------------------- Methods -------------------
+
+
+def latent_blend(soft_inpainting, a, b, t):
+ """
+ Interpolates two latent image representations according to the parameter t,
+ where the interpolated vectors' magnitudes are also interpolated separately.
+ The "detail_preservation" factor biases the magnitude interpolation towards
+ the larger of the two magnitudes.
+ """
+ import torch
+
+ # NOTE: We use inplace operations wherever possible.
+
+ # [4][w][h] to [1][4][w][h]
+ t2 = t.unsqueeze(0)
+ # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
+ t3 = t[0].unsqueeze(0).unsqueeze(0)
+
+ one_minus_t2 = 1 - t2
+ one_minus_t3 = 1 - t3
+
+ # Linearly interpolate the image vectors.
+ a_scaled = a * one_minus_t2
+ b_scaled = b * t2
+ image_interp = a_scaled
+ image_interp.add_(b_scaled)
+ result_type = image_interp.dtype
+ del a_scaled, b_scaled, t2, one_minus_t2
+
+ # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
+ # 64-bit operations are used here to allow large exponents.
+ current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
+
+ # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
+ a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
+ b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
+ desired_magnitude = a_magnitude
+ desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
+ del a_magnitude, b_magnitude, t3, one_minus_t3
+
+ # Change the linearly interpolated image vectors' magnitudes to the value we want.
+ # This is the last 64-bit operation.
+ image_interp_scaling_factor = desired_magnitude
+ image_interp_scaling_factor.div_(current_magnitude)
+ image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
+ image_interp_scaled = image_interp
+ image_interp_scaled.mul_(image_interp_scaling_factor)
+ del current_magnitude
+ del desired_magnitude
+ del image_interp
+ del image_interp_scaling_factor
+ del result_type
+
+ return image_interp_scaled
+
+
+def get_modified_nmask(soft_inpainting, nmask, sigma):
+ """
+ Converts a negative mask representing the transparency of the original latent vectors being overlayed
+ to a mask that is scaled according to the denoising strength for this step.
+
+ Where:
+ 0 = fully opaque, infinite density, fully masked
+ 1 = fully transparent, zero density, fully unmasked
+
+ We bring this transparency to a power, as this allows one to simulate N number of blending operations
+ where N can be any positive real value. Using this one can control the balance of influence between
+ the denoiser and the original latents according to the sigma value.
+
+ NOTE: "mask" is not used
+ """
+ import torch
+ # todo: Why is sigma 2D? Both values are the same.
+ return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
+
+
+def apply_adaptive_masks(
+ latent_orig,
+ latent_processed,
+ overlay_images,
+ masks_for_overlay,
+ width, height,
+ paste_to):
+ import torch
+ import numpy as np
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
+ # latent_mask = p.nmask[0].float().cpu()
+ # convert the original mask into a form we use to scale distances for thresholding
+ # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
+ # mask_scalar = mask_scalar / (1.00001-mask_scalar)
+ # mask_scalar = mask_scalar.numpy()
+
+ latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
+
+ kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+ for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
+ converted_mask = distance_map.float().cpu().numpy()
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.9, percentile_max=1, min_width=1)
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+ # The distance at which opacity of original decreases to 50%
+ # half_weighted_distance = 1 # * mask_scalar
+ # converted_mask = converted_mask / half_weighted_distance
+
+ converted_mask = 1 / (1 + converted_mask ** 2)
+ converted_mask = images.smootherstep(converted_mask)
+ converted_mask = 1 - converted_mask
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc. uncrop(converted_mask,
+ (overlay_image.width, overlay_image.height),
+ paste_to)
+
+ masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
+def apply_masks(
+ soft_inpainting,
+ nmask,
+ overlay_images,
+ masks_for_overlay,
+ width, height,
+ paste_to):
+ import torch
+ import numpy as np
+ import modules.processing as proc
+ import modules.images as images
+ from PIL import Image, ImageOps, ImageFilter
+
+ converted_mask = nmask[0].float()
+ converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, width, height)
+ converted_mask = proc.create_binary_mask(converted_mask, round=False)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if paste_to is not None:
+ converted_mask = proc.uncrop(converted_mask,
+ (width, height),
+ paste_to)
+
+ for i, overlay_image in enumerate(overlay_images):
+ masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ overlay_images[i] = image_masked.convert('RGBA')
+
+
+# ------------------- Constants -------------------
+
+
+default = SoftInpaintingSettings(1, 0.5, 4)
+
+enabled_ui_label = "Soft inpainting"
+enabled_gen_param_label = "Soft inpainting enabled"
+enabled_el_id = "soft_inpainting_enabled"
+
+ui_labels = SoftInpaintingSettings(
+ "Schedule bias",
+ "Preservation strength",
+ "Transition contrast boost")
+
+ui_info = SoftInpaintingSettings(
+ "Shifts when preservation of original content occurs during denoising.",
+ "How strongly partially masked content should be preserved.",
+ "Amplifies the contrast that may be lost in partially masked regions.")
+
+gen_param_labels = SoftInpaintingSettings(
+ "Soft inpainting schedule bias",
+ "Soft inpainting preservation strength",
+ "Soft inpainting transition contrast boost")
+
+el_ids = SoftInpaintingSettings(
+ "mask_blend_power",
+ "mask_blend_scale",
+ "inpaint_detail_preservation")
+
+
+# ------------------- UI -------------------
+
+
+def gradio_ui():
+ import gradio as gr
+ from modules.ui_components import InputAccordion
+
+ with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
+ with gr.Group():
+ gr.Markdown(
+ """
+ Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
+ **High _Mask blur_** values are recommended!
+ """)
+
+ result = SoftInpaintingSettings(
+ gr.Slider(label=ui_labels.mask_blend_power,
+ info=ui_info.mask_blend_power,
+ minimum=0,
+ maximum=8,
+ step=0.1,
+ value=default.mask_blend_power,
+ elem_id=el_ids.mask_blend_power),
+ gr.Slider(label=ui_labels.mask_blend_scale,
+ info=ui_info.mask_blend_scale,
+ minimum=0,
+ maximum=8,
+ step=0.05,
+ value=default.mask_blend_scale,
+ elem_id=el_ids.mask_blend_scale),
+ gr.Slider(label=ui_labels.inpaint_detail_preservation,
+ info=ui_info.inpaint_detail_preservation,
+ minimum=1,
+ maximum=32,
+ step=0.5,
+ value=default.inpaint_detail_preservation,
+ elem_id=el_ids.inpaint_detail_preservation))
+
+ with gr.Accordion("Help", open=False):
+ gr.Markdown(
+ f"""
+ ### {ui_labels.mask_blend_power}
+
+ The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
+ This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
+ This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.
+
+ - **Below 1**: Stronger preservation near the end (with low sigma)
+ - **1**: Balanced (proportional to sigma)
+ - **Above 1**: Stronger preservation in the beginning (with high sigma)
+ """)
+ gr.Markdown(
+ f"""
+ ### {ui_labels.mask_blend_scale}
+
+ Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
+ This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.
+
+ - **Low values**: Favors generated content.
+ - **High values**: Favors original content.
+ """)
+ gr.Markdown(
+ f"""
+ ### {ui_labels.inpaint_detail_preservation}
+
+ This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
+ With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
+ This can prevent the loss of contrast that occurs with linear interpolation.
+
+ - **Low values**: Softer blending, details may fade.
+ - **High values**: Stronger contrast, may over-saturate colors.
+ """)
+
+ return (
+ [
+ soft_inpainting_enabled,
+ result.mask_blend_power,
+ result.mask_blend_scale,
+ result.inpaint_detail_preservation
+ ],
+ [
+ (soft_inpainting_enabled, enabled_gen_param_label),
+ (result.mask_blend_power, gen_param_labels.mask_blend_power),
+ (result.mask_blend_scale, gen_param_labels.mask_blend_scale),
+ (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)
+ ]
+ )
diff --git a/modules/ui.py b/modules/ui.py
index d80486dd..bd2091e1 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -29,6 +29,7 @@ import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.generation_parameters_copypaste import image_from_url_text
+import modules.soft_inpainting as si
create_setting_component = ui_settings.create_setting_component
@@ -680,6 +681,9 @@ def create_ui():
mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha")
with FormRow():
+ soft_inpainting = si.gradio_ui()
+
+ with FormRow():
inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")
with FormRow():
@@ -733,6 +737,7 @@ def create_ui():
sampler_name,
mask_blur,
mask_alpha,
+ *(soft_inpainting[0]),
inpainting_fill,
batch_count,
batch_size,
@@ -831,8 +836,10 @@ def create_ui():
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
+ *(soft_inpainting[1]),
*scripts.scripts_img2img.infotext_fields
]
+
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(