diff options
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5420903f..853246a6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,6 +9,7 @@ import tqdm import html
import datetime
import csv
+import safetensors.torch
from PIL import Image, PngImagePlugin
@@ -150,6 +151,8 @@ class EmbeddingDatabase: name = data.get('name', name)
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
else:
return
@@ -245,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -473,7 +479,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch
epoch_step = embedding.step % steps_per_epoch
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
|