From edb10092de516dda5271130ed53628387780a859 Mon Sep 17 00:00:00 2001 From: Shondoit Date: Thu, 12 Jan 2023 16:29:00 +0100 Subject: Add ability to choose using weighted loss or not --- modules/textual_inversion/dataset.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index f4ce4552..1568b2b8 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -31,7 +31,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -64,7 +64,7 @@ class PersonalizedBase(Dataset): image = Image.open(path) #Currently does not work for single color transparency #We would need to read image.info['transparency'] for that - if 'A' in image.getbands(): + if use_weight and 'A' in image.getbands(): alpha_channel = image.getchannel('A') image = image.convert('RGB') if not varsize: @@ -104,7 +104,7 @@ class PersonalizedBase(Dataset): latent_sampling_method = "once" latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - if alpha_channel is not None: + if use_weight and alpha_channel is not None: channels, *latent_size = latent_sample.shape weight_img = alpha_channel.resize(latent_size) npweight = np.array(weight_img).astype(np.float32) @@ -113,9 +113,11 @@ class PersonalizedBase(Dataset): #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default. weight -= weight.min() weight /= weight.mean() - else: + elif use_weight: #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later weight = torch.ones([channels] + latent_size) + else: + weight = None if latent_sampling_method == "random": entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) @@ -219,7 +221,10 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) + if all(entry.weight is not None for entry in data): + self.weight = torch.stack([entry.weight for entry in data]).squeeze(1) + else: + self.weight = None #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) -- cgit v1.2.1