author     random_thoughtss <random_thoughtss@proton.me>  2022-10-29 10:35:51 -0700
committer  random_thoughtss <random_thoughtss@proton.me>  2022-10-29 10:35:51 -0700
commit     6e2ce4e735db64afcd0fe637327ca4ec78335706 (patch)
tree       428309a1e52a5dfd2d7ce9f70652cb4cdfee9bab /modules
parent     44ab954fabb9c1273366ebdca47f8da394d61aab (diff)
Added image conditioning to latent upscale.
Only computed if the mask weight is not 1.0, to avoid extra memory use. Also includes some code cleanup.
Diffstat (limited to 'modules')
-rw-r--r--  modules/processing.py | 29
1 file changed, 11 insertions(+), 18 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index f18b7db2..ee0e9e34 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
             # Dummy zero conditioning if we're not using inpainting model.
             # Still takes up a bit of memory, but no encoder call.
             # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
-            return torch.zeros(
-                x.shape[0], 5, 1, 1,
-                dtype=x.dtype,
-                device=x.device
-            )
+            return x.new_zeros(x.shape[0], 5, 1, 1)
 
         height = height or self.height
         width = width or self.width
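
A note on the cleanup above: torch.Tensor.new_zeros inherits dtype and device from the tensor it is called on, so the one-liner matches the removed explicit torch.zeros call. A minimal standalone check (shapes here are illustrative):

    import torch

    x = torch.ones(4, 3, dtype=torch.float16)
    old = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
    new = x.new_zeros(x.shape[0], 5, 1, 1)
    assert new.dtype == old.dtype and new.device == old.device
    assert torch.equal(new, old)
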
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
     def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
         if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
             # Dummy zero conditioning if we're not using inpainting model.
-            return torch.zeros(
-                latent_image.shape[0], 5, 1, 1,
-                dtype=latent_image.dtype,
-                device=latent_image.device
-            )
+            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
 
         # Handle the different mask inputs
         if image_mask is not None:
@@ -174,11 +166,10 @@ class StableDiffusionProcessing():
             # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
             conditioning_mask = torch.round(conditioning_mask)
         else:
-            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
 
         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(source_image.device)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
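
A note on the interpolation above: torch.lerp(start, end, weight) computes start + weight * (end - start); with end = source_image * (1.0 - conditioning_mask), a weight of 0.0 keeps the unmasked conditioning image and a weight of 1.0 fully zeroes the masked pixels. The weight argument is cut off by the hunk's context window, but it is presumably the inpainting mask weight. A small sketch with an illustrative weight of 0.5:

    import torch

    src = torch.full((1, 1, 2, 2), 2.0)
    mask = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])  # 1.0 marks masked pixels
    out = torch.lerp(src, src * (1.0 - mask), 0.5)
    # Masked pixels are halved toward zero; unmasked pixels are untouched.
    assert torch.equal(out, torch.tensor([[[[1.0, 2.0], [2.0, 1.0]]]]))
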
@@ -653,7 +644,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if opts.use_scale_latent_for_hires_fix:
             samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
 
-            image_conditioning = self.txt2img_image_conditioning(samples)
+
+            # Avoid making the inpainting conditioning unless necessary as
+            # this does need some extra compute to decode / encode the image again.
+            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+            else:
+                image_conditioning = self.txt2img_image_conditioning(samples)
 
         else:
             decoded_samples = decode_first_stage(self.sd_model, samples)
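
A note on the getattr fallback above: it lets a processing instance override the mask weight per job while defaulting to the global shared.opts.inpainting_mask_weight setting, and only a weight below 1.0 pays for the extra decode/encode round trip. The pattern in isolation (class and values are illustrative):

    GLOBAL_DEFAULT = 1.0  # stands in for shared.opts.inpainting_mask_weight

    class Job:
        pass

    job = Job()
    # Attribute absent: fall back to the global default, skip the round trip.
    assert getattr(job, "inpainting_mask_weight", GLOBAL_DEFAULT) == 1.0
    # Per-job override below 1.0 would trigger the image conditioning path.
    job.inpainting_mask_weight = 0.5
    assert getattr(job, "inpainting_mask_weight", GLOBAL_DEFAULT) == 0.5
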
@@ -675,11 +672,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
-            image_conditioning = self.img2img_image_conditioning(
-                decoded_samples,
-                samples,
-                decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
-            )
+            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
 
             shared.state.nextjob()
 
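
A closing note on the simplified call: dropping the explicit all-ones mask is safe because img2img_image_conditioning now builds one itself when image_mask is None, and a (1, 1, H, W) mask broadcasts over batch and channels inside torch.lerp just as the old (B, 1, H, W) argument did. A quick shape check (sizes are illustrative):

    import torch

    decoded = torch.randn(2, 3, 8, 8)
    mask = decoded.new_ones(1, 1, *decoded.shape[-2:])
    blended = torch.lerp(decoded, decoded * (1.0 - mask), 1.0)
    assert blended.shape == decoded.shape      # broadcast to (2, 3, 8, 8)
    assert torch.count_nonzero(blended) == 0   # fully masked at weight 1.0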