From 43bb5190fc9e7ae479a5dc6640be202c9a71e464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 22:52:23 +0300 Subject: remove/simplify some changes from #6481 --- modules/textual_inversion/dataset.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcad6848..fa48708e 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,7 +25,6 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values - self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -46,12 +45,10 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - if varsize: - assert batch_size == 1, 'variable img size must have batch size 1' + assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] - self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out @@ -91,14 +88,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -154,7 +151,6 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) -- cgit v1.2.1