diff options
-rw-r--r-- | modules/generation_parameters_copypaste.py | 3 | ||||
-rw-r--r-- | modules/script_callbacks.py | 20 | ||||
-rw-r--r-- | modules/textual_inversion/dataset.py | 52 | ||||
-rw-r--r-- | modules/textual_inversion/image_embedding.py | 4 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 2 |
5 files changed, 72 insertions, 9 deletions
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 620aa606..593d99ef 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr
from modules.shared import script_path
-from modules import shared, ui_tempdir
+from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image
@@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): prompt = file.read()
params = parse_generation_parameters(prompt)
+ script_callbacks.infotext_pasted_callback(prompt, params)
res = []
for output, key in paste_fields:
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 608c5300..a9e19236 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import sys import traceback
from collections import namedtuple
import inspect
-from typing import Optional
+from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
@@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
+ callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
)
@@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid')
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
    """invoke every callback registered under 'callbacks_infotext_pasted'.

    Each registered entry's ``callback`` is called with the raw ``infotext``
    string and the ``params`` dict, in registration order. An exception raised
    by one callback is handed to ``report_exception`` and does not prevent the
    remaining callbacks from running.
    """
    registered = callback_map['callbacks_infotext_pasted']
    for entry in registered:
        try:
            entry.callback(infotext, params)
        except Exception:
            # Best-effort dispatch: report the failing callback and carry on.
            report_exception(entry, 'infotext_pasted')
+
+
def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']):
try:
@@ -290,6 +299,15 @@ def on_image_grid(callback): add_callback(callback_map['callbacks_image_grid'], callback)
def on_infotext_pasted(callback):
    """register a function to be called before applying an infotext.
    The callback is called with two arguments:
      - infotext: str - raw infotext.
      - params: Dict[str, Any] - parsed infotext parameters.
    The same ``params`` dict object is passed to every registered callback.
    """
    add_callback(callback_map['callbacks_infotext_pasted'], callback)
+
+
def on_script_unloaded(callback):
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
the script did should be reverted here"""
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index fa48708e..d31963d4 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,8 +3,10 @@ import numpy as np import PIL
import torch
from PIL import Image
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
+from collections import defaultdict
+from random import shuffle, choices
import random
import tqdm
@@ -45,12 +47,12 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
+ groups = defaultdict(list)
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
@@ -103,18 +105,25 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
-
+ groups[image.size].append(len(self.dataset))
self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
self.length = len(self.dataset)
+ self.groups = list(groups.values())
assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.latent_sampling_method = latent_sampling_method
+ if len(groups) > 1:
+ print("Buckets:")
+ for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+ print(f" {w}x{h}: {len(ids)}")
+ print()
+
def create_text(self, filename_text):
text = random.choice(self.lines)
tags = filename_text.split(',')
@@ -137,9 +146,44 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
+
class GroupedBatchSampler(Sampler):
    """Batch sampler whose batches never mix indices from different groups.

    ``data_source.groups`` is a list of index lists (the dataset builds one
    list per distinct image size). Every yielded batch is a list of
    ``batch_size`` indices taken from a single group.

    An epoch consists of ``len(data_source) // batch_size`` batches:

    * each group contributes ``base[i]`` full batches sliced from its shuffled
      indices (no repetition inside those batches);
    * the remaining ``n_rand_batches`` are drawn from a group picked with
      probability ``probs[i]`` (proportional to the group's leftover share),
      sampling indices *with replacement* from that group.
    """

    def __init__(self, data_source: "PersonalizedBase", batch_size: int):
        super().__init__(data_source)

        n = len(data_source)
        self.groups = data_source.groups
        self.len = n_batch = n // batch_size
        # Number of samples each group is expected to contribute per epoch,
        # scaled so that the total equals n_batch * batch_size.
        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]
        # Whole batches each group can fill from its expected share.
        self.base = [int(e) // batch_size for e in expected]
        # Batches still missing to reach n_batch; filled by random draws.
        self.n_rand_batches = nrb = n_batch - sum(self.base)
        # Chance of each group receiving a random batch, proportional to the
        # fractional part of its expected share (all zeros when unused).
        self.probs = [e % batch_size / nrb / batch_size if nrb > 0 else 0 for e in expected]
        self.batch_size = batch_size

    def __len__(self):
        return self.len

    def __iter__(self):
        b = self.batch_size

        # Shuffles the group lists in place (they are shared with data_source).
        for g in self.groups:
            shuffle(g)

        batches = []
        # Emit exactly base[i] full batches per group so that, together with
        # the random batches below, the epoch has exactly len(self) batches.
        for g, n_full in zip(self.groups, self.base):
            batches.extend(g[i * b:(i + 1) * b] for i in range(n_full))
        for _ in range(self.n_rand_batches):
            rand_group = choices(self.groups, self.probs)[0]
            batches.append(choices(rand_group, k=b))

        shuffle(batches)

        yield from batches
+
+
class PersonalizedDataLoader(DataLoader):
def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
- super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ea653806..5593f88c 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -76,10 +76,10 @@ def insert_image_data_embed(image, data): next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))
- data_np_low.resize(next_size)
+ data_np_low = np.resize(data_np_low, next_size)
data_np_low = data_np_low.reshape((h, -1, d))
- data_np_high.resize(next_size)
+ data_np_high = np.resize(data_np_high, next_size)
data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 853246a6..e23906ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
pbar.set_description(description)
shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
|