Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--  modules/sd_hijack.py  50
1 file changed, 16 insertions, 34 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 227e7670..36198a3c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -14,10 +14,6 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
-from tqdm import trange
-from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer
-import torch.optim as optim
-import copy
 
 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
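
The four deleted imports are dead code in this module; nothing in the remaining file references them. The module-level assignments kept above exist so the hijack can be undone later: the original forward functions are captured before anything is patched. A minimal sketch of that save-and-restore pattern (apply_patch/undo_patch are illustrative names, not this module's API):

    import ldm.modules.attention

    # save the library's original implementation before monkey-patching,
    # mirroring attention_CrossAttention_forward above
    original_forward = ldm.modules.attention.CrossAttention.forward

    def apply_patch(optimized_forward):
        # swap in an optimized forward for every CrossAttention instance
        ldm.modules.attention.CrossAttention.forward = optimized_forward

    def undo_patch():
        # only possible because the original was captured at import time
        ldm.modules.attention.CrossAttention.forward = original_forward
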
@@ -29,19 +25,16 @@ def apply_optimizations():
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
-    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
-            6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
+    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
     elif cmd_opts.opt_split_attention_v1:
         print("Applying v1 cross attention optimization.")
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-    elif not cmd_opts.disable_opt_split_attention and (
-            cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
+    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
         if not invokeAI_mps_available and shared.device.type == 'mps':
-            print(
-                "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
+            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
             print("Applying v1 cross attention optimization.")
             ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
         else:
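
The re-joined condition gates the xformers path on three independent checks: the user requested it (or forced it), torch is a CUDA build, and the GPU's compute capability lies in [6.0, 9.0]. A sketch of the same logic factored out for readability (parameter names are illustrative):

    import torch

    def can_use_xformers(device, force=False, requested=False, available=False):
        if force:
            return True
        if not (requested and available and torch.version.cuda):
            return False
        # get_device_capability returns a (major, minor) tuple, so tuple
        # comparison expresses the 6.0 <= capability <= 9.0 range directly
        return (6, 0) <= torch.cuda.get_device_capability(device) <= (9, 0)
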
@@ -122,14 +115,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
         super().__init__()
         self.wrapped = wrapped
-
-        self.token_mults = {}
         self.hijack: StableDiffusionModelHijack = hijack
         self.tokenizer = wrapped.tokenizer
+        self.token_mults = {}
+
         self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
 
-        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
-                              '(' in k or ')' in k or '[' in k or ']' in k]
+        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
             for c in text:
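
The hunk cuts off inside this loop, but its role follows from tokens_with_parens: every vocab entry containing a bracket gets a precomputed emphasis multiplier keyed by token id in self.token_mults. A sketch of how such a multiplier can be accumulated per character (the 1.1 step is the webui's usual emphasis factor; the exact mapping sits outside this hunk, so treat it as an assumption):

    def bracket_multiplier(text, step=1.1):
        # '(' and ']' strengthen emphasis, ')' and '[' weaken it
        mult = 1.0
        for c in text:
            if c in '(]':
                mult *= step
            elif c in ')[':
                mult /= step
        return mult

    # e.g. a vocab token like '((' would map to 1.1 * 1.1 = 1.21
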
@@ -153,8 +145,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         else:
             parsed = [[line, 1.0]]
 
-        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[
-            "input_ids"]
+        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
 
         fixes = []
         remade_tokens = []
@@ -170,9 +161,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 if token == self.comma_token:
                     last_comma = len(remade_tokens)
-                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens),
-                                                               1) % 75 == 0 and last_comma != -1 and len(
-                    remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+                elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
                     last_comma += 1
                     reloc_tokens = remade_tokens[last_comma:]
                     reloc_mults = multipliers[last_comma:]
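
The re-joined elif implements comma backtracking: when a 75-token chunk fills up and the last comma sits within opts.comma_padding_backtrack tokens of the boundary, the tail after that comma is relocated into the next chunk so a clause is not split mid-phrase. Roughly, as a standalone sketch:

    def backtrack_split(remade_tokens, multipliers, last_comma, backtrack):
        # fires only at a chunk boundary with a close-enough comma
        at_boundary = max(len(remade_tokens), 1) % 75 == 0
        if backtrack != 0 and at_boundary and last_comma != -1 \
                and len(remade_tokens) - last_comma <= backtrack:
            # keep everything up to and including the comma in this chunk,
            # carry the tail over to start the next one
            carry = last_comma + 1
            return (remade_tokens[:carry], multipliers[:carry],
                    remade_tokens[carry:], multipliers[carry:])
        return remade_tokens, multipliers, [], []
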
@@ -224,8 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if line in cache:
                 remade_tokens, fixes, multipliers = cache[line]
             else:
-                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms,
-                                                                                            hijack_comments)
+                remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
                 token_count = max(current_token_count, token_count)
 
                 cache[line] = (remade_tokens, fixes, multipliers)
@@ -265,8 +253,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             while i < len(tokens):
                 token = tokens[i]
 
-                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens,
-                                                                                                            i)
+                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
                 mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
                 if mult_change is not None:
@@ -289,8 +276,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 ovf = remade_tokens[maxlen - 2:]
                 overflowing_words = [vocab.get(int(x), "") for x in ovf]
                 overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-                hijack_comments.append(
-                    f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
 
             token_count = len(remade_tokens)
             remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
@@ -308,17 +294,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def forward(self, text):
         use_old = opts.use_old_emphasis_implementation
         if use_old:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(
-                text)
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
         else:
-            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(
-                text)
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
 
         self.hijack.comments += hijack_comments
 
         if len(used_custom_terms) > 0:
-            self.hijack.comments.append(
-                "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
 
         if use_old:
             self.hijack.fixes = hijack_fixes
@@ -351,6 +334,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             z1 = self.process_tokens(tokens, multipliers)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)
             z = shared.aesthetic_clip(z, remade_batch_tokens)
+
             remade_batch_tokens = rem_tokens
             batch_multipliers = rem_multipliers
             i += 1
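
For context, the added blank line sits inside the loop that handles prompts longer than 75 tokens: each chunk is encoded separately and the results are concatenated along the token axis, with shared.aesthetic_clip applied to the running tensor. The shape of that loop, as a sketch with a hypothetical encode_chunk helper:

    import torch

    def encode_long_prompt(chunks, encode_chunk):
        # chunks: list of (tokens, multipliers) pairs, one per 75-token window
        z = None
        for tokens, multipliers in chunks:
            z1 = encode_chunk(tokens, multipliers)  # -> (batch, 77, dim)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)  # grow token axis
        return z
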
@@ -359,9 +343,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         if not opts.use_old_emphasis_implementation:
-            remade_batch_tokens = [
-                [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
-                remade_batch_tokens]
+            remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
             batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
 
         tokens = torch.asarray(remade_batch_tokens).to(device)
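
process_tokens frames every 75-token chunk with the tokenizer's BOS/EOS ids so the tensor fed to CLIP is always 77 tokens wide, with a neutral 1.0 weight on both frame positions. A sketch of that framing (49406/49407 are CLIP's standard begin/end ids):

    BOS, EOS = 49406, 49407  # CLIP tokenizer's begin/end token ids

    def frame_chunk(tokens, mults):
        # 75 content tokens + BOS + EOS = CLIP's fixed context length of 77
        framed_tokens = [BOS] + tokens[:75] + [EOS]
        framed_mults = [1.0] + mults[:75] + [1.0]  # neutral weight for the frame
        return framed_tokens, framed_mults
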