From 525cea924562afd676f55470095268a0f6fca59e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 17:58:07 +0300 Subject: use shared function from processing for creating dummy mask when training inpainting model --- modules/textual_inversion/textual_inversion.py | 33 +++++++------------------- 1 file changed, 9 insertions(+), 24 deletions(-) (limited to 'modules/textual_inversion/textual_inversion.py') diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8731ea5d..2250e41b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -252,26 +252,6 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def create_dummy_mask(x, width=None, height=None): - if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}: - - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning)) - - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) - - else: - # Dummy zero conditioning if we're not using inpainting model. - # Still takes up a bit of memory, but no encoder call. - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. - image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) - - return image_conditioning - - def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 @@ -346,7 +326,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: print("No saved optimizer exists in checkpoint") - scaler = torch.cuda.amp.GradScaler() batch_size = ds.batch_size @@ -362,7 +341,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ forced_filename = "" embedding_yet_to_be_embedded = False + is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} img_c = None + pbar = tqdm.tqdm(total=steps - initial_step) try: for i in range((steps-initial_step) * gradient_step): @@ -384,10 +365,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) c = shared.sd_model.cond_stage_model(batch.cond_text) - if img_c is None: - img_c = create_dummy_mask(c, training_width, training_height) + if is_training_inpainting_model: + if img_c is None: + img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height) + + cond = {"c_concat": [img_c], "c_crossattn": [c]} + else: + cond = c - cond = {"c_concat": [img_c], "c_crossattn": [c]} loss = shared.sd_model(x, cond)[0] / gradient_step del x -- cgit v1.2.1