From 6f0abbb71a3f29d6df63fed82d5d5e196ca0d4de Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 29 Jul 2023 15:15:06 +0300
Subject: textual inversion support for SDXL

---
 modules/sd_hijack.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c8fdd4f1..cfa5f0eb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
                     conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                     text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenOpenCLIPEmbedder2':
-                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                     conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                     text_cond_models.append(conditioner.embedders[i])
 
@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
 
 
 class EmbeddingsWithFixes(torch.nn.Module):
-    def __init__(self, wrapped, embeddings):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
         super().__init__()
         self.wrapped = wrapped
         self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key
 
     def forward(self, input_ids):
         batch_fixes = self.embeddings.fixes
@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
         vecs = []
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
-                emb = devices.cond_cast_unet(embedding.vec)
+                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = devices.cond_cast_unet(vec)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
 
-- 
cgit v1.2.1
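
Note on the data shape this patch enables: an SDXL textual inversion embedding can carry one vector per text encoder. embedding.vec may be a dict holding a 'clip_l' tensor for the CLIP ViT-L encoder and a 'clip_g' tensor for the OpenCLIP ViT-bigG encoder (which is why the FrozenOpenCLIPEmbedder2 hijack passes textual_inversion_key='clip_g'), and each EmbeddingsWithFixes instance selects the tensor matching the encoder it wraps, falling back to the plain tensor for older SD1/SD2 embeddings. The snippet below is a minimal, self-contained sketch of that dispatch only; the Embedding stand-in class, the resolve_vec helper name, and the 768/1280 embedding widths are illustrative assumptions, not part of the patch.

import torch


class Embedding:
    # Hypothetical stand-in for the webui's embedding object: vec is a single
    # tensor for SD1/SD2 embeddings, or a dict of tensors keyed per text
    # encoder for SDXL embeddings.
    def __init__(self, vec):
        self.vec = vec


def resolve_vec(embedding, textual_inversion_key='clip_l'):
    # Mirrors the dispatch added in the patch: SDXL embeddings store one
    # tensor per text encoder, so pick the one this wrapper is responsible for.
    if isinstance(embedding.vec, dict):
        return embedding.vec[textual_inversion_key]
    return embedding.vec


# SD1-style embedding: a single tensor (shape assumed for illustration).
sd1_embedding = Embedding(torch.randn(2, 768))

# SDXL-style embedding: separate tensors per encoder (768 for clip_l and
# 1280 for clip_g are assumed widths).
sdxl_embedding = Embedding({
    'clip_l': torch.randn(2, 768),
    'clip_g': torch.randn(2, 1280),
})

print(resolve_vec(sd1_embedding).shape)                                   # torch.Size([2, 768])
print(resolve_vec(sdxl_embedding, textual_inversion_key='clip_g').shape)  # torch.Size([2, 1280])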