Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 113 |
1 file changed, 88 insertions, 25 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 6f01c95f..7d46949f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,6 +30,7 @@ import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
+import modules.soft_inpainting as si
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
@@ -62,6 +63,16 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
+def uncrop(image, dest_size, paste_loc):
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', dest_size)
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ return image
+
+
def apply_overlay(image, paste_loc, index, overlays):
if overlays is None or index >= len(overlays):
return image
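For context, uncrop() inverts the "inpaint at full resolution" crop: the processed region is resized back to its original width and height and pasted into a transparent full-size canvas at the spot it was cut from. A minimal standalone sketch of the same idea in plain PIL (Image.resize standing in for images.resize_image):

    from PIL import Image

    def uncrop_sketch(image, dest_size, paste_loc):
        x, y, w, h = paste_loc                # where the crop was taken from
        base = Image.new('RGBA', dest_size)   # transparent full-size canvas
        base.paste(image.resize((w, h)), (x, y))
        return base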
@@ -69,11 +80,7 @@ def apply_overlay(image, paste_loc, index, overlays):
overlay = overlays[index]
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -81,9 +88,12 @@ def apply_overlay(image, paste_loc, index, overlays):
return image
-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
- image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ if round:
+ image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+ else:
+ image = image.split()[-1].convert("L")
else:
image = image.convert('L')
return image
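The new round flag is the crux of the change: hard inpainting thresholds the alpha channel into a binary mask, while soft inpainting keeps the grayscale alpha as continuous blend weights. Roughly (the input filename is hypothetical):

    from PIL import Image

    alpha = Image.open('mask_source.png').convert('RGBA').split()[-1].convert('L')

    hard_mask = alpha.point(lambda v: 255 if v > 128 else 0)  # round=True: only 0 or 255
    soft_mask = alpha                                         # round=False: 0..255 preserved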
@@ -140,6 +150,7 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
+ masks_for_overlay: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = None
@@ -308,7 +319,7 @@ class StableDiffusionProcessing:
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -320,8 +331,10 @@ class StableDiffusionProcessing:
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
- # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
- conditioning_mask = torch.round(conditioning_mask)
+ if round_image_mask:
+ # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+ conditioning_mask = torch.round(conditioning_mask)
+
else:
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
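Concretely, torch.round snaps every mask value to 0.0 or 1.0, which is what the stock inpainting model expects; with round_image_mask=False the fractional coverage survives into the conditioning:

    import torch

    conditioning_mask = torch.tensor([0.0, 0.2, 0.6, 1.0])
    torch.round(conditioning_mask)  # tensor([0., 0., 1., 1.]), a hard 0/1 mask
    # round_image_mask=False leaves 0.2 and 0.6 intact as partial coverage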
@@ -345,7 +358,7 @@ class StableDiffusionProcessing:
return image_conditioning
- def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
source_image = devices.cond_cast_float(source_image)
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,7 +370,7 @@ class StableDiffusionProcessing:
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
@@ -869,9 +882,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
+ # todo: generate adaptive masks based on pixel differences.
+ if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
+ si.apply_masks(soft_inpainting=p.soft_inpainting,
+ nmask=p.nmask,
+ overlay_images=p.overlay_images,
+ masks_for_overlay=p.masks_for_overlay,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+
+ # Generate the mask(s) based on similarity between the original and denoised latent vectors
+ if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
+ si.apply_adaptive_masks(latent_orig=p.init_latent,
+ latent_processed=samples_ddim,
+ overlay_images=p.overlay_images,
+ masks_for_overlay=p.masks_for_overlay,
+ width=p.width,
+ height=p.height,
+ paste_to=p.paste_to)
+
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
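Note the asymmetry between the two branches: when the sampler already returned decoded images, si.apply_masks works from the precomputed nmask, whereas si.apply_adaptive_masks can compare the original and denoised latents. modules/soft_inpainting.py is not part of this diff, but the idea behind an adaptive mask might look roughly like this (a hypothetical sketch, not the module's actual code):

    import torch

    def adaptive_mask_sketch(latent_orig, latent_processed):
        # Hypothetical: weight each latent pixel by how much the sampler
        # changed it, so barely-touched areas keep more of the original
        # image in the final composite.
        diff = (latent_processed - latent_orig).abs().mean(dim=1, keepdim=True)
        return (diff / diff.max().clamp(min=1e-6)).clamp(0.0, 1.0)  # 0 = keep original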
@@ -928,6 +961,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
+
+ if p.paste_to is not None:
+ original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to)
+
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if save_samples:
@@ -938,16 +980,24 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
+
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+ mask_for_overlay = p.mask_for_overlay
+ elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]:
+ mask_for_overlay = p.masks_for_overlay[i]
+ else:
+ mask_for_overlay = None
+
+ if mask_for_overlay is not None:
if opts.return_mask or opts.save_mask:
- image_mask = p.mask_for_overlay.convert('RGB')
+ image_mask = mask_for_overlay.convert('RGB')
if save_samples and opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
if opts.return_mask:
output_images.append(image_mask)
if opts.return_mask_composite or opts.save_mask_composite:
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if save_samples and opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
if opts.return_mask_composite:
output_images.append(image_mask_composite)
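The composite preview now uses original_denoised_image rather than the already-overlaid image, so the "-mask-composite" output shows raw model output inside the masked region. For reference, Image.composite(a, b, mask) takes pixels from a where the mask is white and from b where it is black, blending proportionally at intermediate gray values:

    from PIL import Image

    a = Image.new('RGB', (64, 64), 'red')
    b = Image.new('RGB', (64, 64), 'blue')
    mask = Image.new('L', (64, 64), 128)   # uniform 50% gray
    blended = Image.composite(a, b, mask)  # an even red/blue mix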
@@ -1351,6 +1401,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
+ soft_inpainting: si.SoftInpaintingParameters = si.default
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
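soft_inpainting defaults to si.default, and every new code path in this diff is gated on whether it is None. A hedged usage sketch (the fields of SoftInpaintingParameters are not shown in this diff):

    # assuming p is a StableDiffusionProcessingImg2Img instance:
    p.soft_inpainting = None  # None disables every soft-inpainting branch below
    # any si.SoftInpaintingParameters value (such as si.default) enables them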
@@ -1396,7 +1447,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if image_mask is not None:
# image_mask is passed in as RGBA by Gradio to support alpha masks,
# but we still want to support binary masks.
- image_mask = create_binary_mask(image_mask)
+ image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None))
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
@@ -1414,7 +1465,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
- self.mask_for_overlay = image_mask
+ self.mask_for_overlay = image_mask if self.soft_inpainting is None else None
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1425,10 +1476,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
- np_mask = np.array(image_mask)
- np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
+ if self.soft_inpainting is None:
+ np_mask = np.array(image_mask)
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+ self.mask_for_overlay = Image.fromarray(np_mask)
+
+ self.masks_for_overlay = [] if self.soft_inpainting is not None else None
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
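In the hard-mask branch the preview mask is brightened before storage: doubling the pixel values pushes anything above 50% gray to full white, so blurred mask edges still register as masked in the saved overlay preview. The soft branch skips mask_for_overlay entirely and collects per-image masks in masks_for_overlay instead. The brightening step, annotated:

    import numpy as np
    from PIL import Image

    np_mask = np.array(image_mask).astype(np.float32)        # 0..255 grayscale
    np_mask = np.clip(np_mask * 2, 0, 255).astype(np.uint8)  # >= 128 becomes 255
    mask_for_overlay = Image.fromarray(np_mask)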
@@ -1450,10 +1504,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+ if self.soft_inpainting is not None:
+ # We apply the masks AFTER to adjust mask based on changed content.
+ self.overlay_images.append(image.convert('RGBA'))
+ self.masks_for_overlay.append(image_mask)
+ else:
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
- self.overlay_images.append(image_masked.convert('RGBA'))
+ self.overlay_images.append(image_masked.convert('RGBA'))
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
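The deferral is the key behavioral change here: the hard path bakes the keep-region into a premultiplied-alpha overlay immediately, while the soft path stores the untouched image plus its grayscale mask so the si functions can build the overlay after sampling, once the final blend weights are known. The hard path in isolation (assuming image and mask_for_overlay as in the diff):

    from PIL import Image, ImageOps

    # mask_for_overlay is white where content will be regenerated, so
    # inverting it keeps only the untouched pixels in the overlay
    image_masked = Image.new('RGBa', image.size)
    image_masked.paste(image.convert('RGBA').convert('RGBa'),
                       mask=ImageOps.invert(mask_for_overlay.convert('L')))
    overlay = image_masked.convert('RGBA')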
@@ -1477,6 +1536,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+ if self.masks_for_overlay is not None:
+ self.masks_for_overlay = self.masks_for_overlay * self.batch_size
+
if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size
@@ -1503,7 +1565,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
latmask = latmask[0]
- latmask = np.around(latmask)
+ if self.soft_inpainting is None:
+ latmask = np.around(latmask)
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
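Skipping np.around here is what lets self.mask and self.nmask carry fractional per-pixel weights into the sampler instead of a hard 0/1 latent mask:

    import numpy as np

    latmask = np.array([0.1, 0.4, 0.6, 0.9], dtype=np.float32)
    np.around(latmask)  # array([0., 0., 1., 1.], dtype=float32): a hard mask
    # soft inpainting keeps 0.1, 0.4, 0.6, 0.9 as blend weights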
@@ -1515,7 +1578,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
- self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = self.rng.next()
@@ -1526,7 +1589,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
- if self.mask is not None:
+ if self.mask is not None and self.soft_inpainting is None:
samples = samples * self.nmask + self.init_latent * self.mask
del x
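The line guarded above is the one-shot latent composite that hard inpainting performs after sampling; soft inpainting skips it because blending is handled through the si module instead. In isolation it is just a linear interpolation, where mask is 1.0 over pixels to keep from the original:

    import torch

    def hard_latent_blend(samples, init_latent, mask):
        nmask = 1.0 - mask                        # 1.0 where new content goes
        return samples * nmask + init_latent * mask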