From 75c4511e6b81ae8fb0dbd932043e8eb35cd09f72 Mon Sep 17 00:00:00 2001
From: zhaohu xing <920232796@qq.com>
Date: Tue, 29 Nov 2022 10:28:41 +0800
Subject: add AltDiffusion to webui

Signed-off-by: zhaohu xing <920232796@qq.com>
---
 modules/sd_hijack.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index eaedac13..26280fe4 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -70,14 +70,19 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
 
     def hijack(self, m):
-        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+
+        if shared.text_model_name == "XLMR-Large":
+            model_embeddings = m.cond_stage_model.roberta.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+        else:
+            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
 
-        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
 
         self.clip = m.cond_stage_model
 
-        apply_optimizations()
+        # apply_optimizations()
 
     def flatten(el):
         flattened = [flatten(children) for children in el.children()]
@@ -125,8 +130,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         self.tokenizer = wrapped.tokenizer
         self.token_mults = {}
 
-        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
-
+        try:
+            self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0]
+        except Exception:
+            self.comma_token = None
+
         tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
@@ -298,6 +306,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
 
     def forward(self, text):
+        if shared.text_model_name == "XLMR-Large":
+            return self.wrapped.encode(text)
+
         use_old = opts.use_old_emphasis_implementation
         if use_old:
             batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
@@ -359,7 +370,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             z = self.wrapped.transformer.text_model.final_layer_norm(z)
         else:
             z = outputs.last_hidden_state
-
+
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
         batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
-- 
cgit v1.2.1
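
Editor's note: the heart of this patch is a dispatch on shared.text_model_name. AltDiffusion's XLM-Roberta encoder keeps its token lookup table at roberta.embeddings.word_embeddings and applies prompt handling inside its own encode(), while CLIP exposes transformer.text_model.embeddings.token_embedding; so hijack() wraps a different module per encoder and forward() short-circuits for "XLMR-Large". Below is a minimal, self-contained sketch of that dispatch. The stub objects, the hijack_embeddings helper, and the vocabulary/width numbers are illustrative stand-ins, not code or values from this repository.

    # Illustrative sketch only -- mirrors the branch added to
    # StableDiffusionModelHijack.hijack(), using duck-typed stand-ins
    # instead of the webui's real cond_stage_model.
    from types import SimpleNamespace

    import torch


    class EmbeddingsWithFixes(torch.nn.Module):
        """Stand-in for the webui wrapper that patches embedding lookups."""

        def __init__(self, wrapped, hijack):
            super().__init__()
            self.wrapped = wrapped
            self.hijack = hijack

        def forward(self, input_ids):
            # The real wrapper also splices textual-inversion vectors in here.
            return self.wrapped(input_ids)


    def hijack_embeddings(model, text_model_name, hijack=None):
        """Pick the embeddings module by encoder type and wrap its lookup table."""
        if text_model_name == "XLMR-Large":
            # XLM-Roberta: table lives at roberta.embeddings.word_embeddings
            emb = model.cond_stage_model.roberta.embeddings
            emb.token_embedding = EmbeddingsWithFixes(emb.word_embeddings, hijack)
        else:
            # CLIP: table lives at transformer.text_model.embeddings.token_embedding
            emb = model.cond_stage_model.transformer.text_model.embeddings
            emb.token_embedding = EmbeddingsWithFixes(emb.token_embedding, hijack)
        return emb


    # Stand-ins shaped like the two encoder families (sizes are assumptions).
    clip_model = SimpleNamespace(cond_stage_model=SimpleNamespace(
        transformer=SimpleNamespace(text_model=SimpleNamespace(
            embeddings=SimpleNamespace(token_embedding=torch.nn.Embedding(49408, 768))))))

    xlmr_model = SimpleNamespace(cond_stage_model=SimpleNamespace(
        roberta=SimpleNamespace(
            embeddings=SimpleNamespace(word_embeddings=torch.nn.Embedding(250002, 1024)))))

    for m, name in ((clip_model, "CLIP"), (xlmr_model, "XLMR-Large")):
        wrapped = hijack_embeddings(m, name)
        ids = torch.tensor([[1, 2, 3]])
        print(name, wrapped.token_embedding(ids).shape)

Either way the wrapped lookup table ends up at a uniform token_embedding attribute, which is why the rest of the hijack code can stay encoder-agnostic.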