Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r--  extensions-builtin/Lora/lora.py  92
 1 file changed, 76 insertions(+), 16 deletions(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index b7e775ae..d4345ada 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -3,7 +3,9 @@ import os
import re
import torch
-from modules import shared, devices, sd_models
+from modules import shared, devices, sd_models, errors
+
+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
@@ -51,6 +53,23 @@ class LoraOnDisk:
def __init__(self, name, filename):
self.name = name
self.filename = filename
+ self.metadata = {}
+
+ _, ext = os.path.splitext(filename)
+ if ext.lower() == ".safetensors":
+ try:
+ self.metadata = sd_models.read_metadata_from_safetensors(filename)
+ except Exception as e:
+ errors.display(e, f"reading lora {filename}")
+
+ if self.metadata:
+ m = {}
+ for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
+ m[k] = v
+
+ self.metadata = m
+
+ self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
class LoraModule:
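
Note on the metadata read above: sd_models.read_metadata_from_safetensors is a webui helper. As a rough standalone sketch of the same idea (assuming only the documented safetensors layout: an 8-byte little-endian header length, then a JSON header whose "__metadata__" entry holds string-to-string metadata), a hypothetical reader could look like this:

import json
import struct

def read_safetensors_metadata(filename):
    # Hypothetical helper, not the webui function: safetensors files begin
    # with an 8-byte little-endian unsigned length, followed by that many
    # bytes of JSON header; training metadata, when present, is stored as
    # string-to-string pairs under "__metadata__".
    with open(filename, "rb") as file:
        header_size = struct.unpack("<Q", file.read(8))[0]
        header = json.loads(file.read(header_size))
    return header.get("__metadata__", {})

LoraOnDisk then reorders those keys with metadata_tags_order so the more useful tags (model name, resolution, clip skip) sort first.
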
@@ -125,7 +144,7 @@ def load_lora(name, filename):
with torch.no_grad():
module.weight.copy_(weight)
- module.to(device=devices.device, dtype=devices.dtype)
+ module.to(device=devices.cpu, dtype=devices.dtype)
if lora_key == "lora_up.weight":
lora_module.up = module
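
The next hunk drops the per-call lora_forward wrapper in favour of merging Lora deltas directly into the layer weights. As an illustration of the arithmetic only (not the patch's exact code, and with hypothetical names), a minimal sketch for a plain Linear layer might be:

import torch

def merge_lora_into_linear(layer, up, down, multiplier=1.0, alpha=None):
    # Illustrative only: up has shape (out_features, rank) and down has shape
    # (rank, in_features), so up @ down matches layer.weight and can be added
    # in place, scaled the same way the hunk below scales it (alpha / rank).
    scale = multiplier * (alpha / up.shape[1] if alpha else 1.0)
    with torch.no_grad():
        layer.weight += (up @ down) * scale

The patch applies the same product to 1x1 Conv2d weights by squeezing the trailing spatial dimensions before the matrix multiply and unsqueezing them afterwards.
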
@@ -171,28 +190,69 @@ def load_loras(names, multipliers=None):
loaded_loras.append(lora)
-def lora_forward(module, input, res):
- if len(loaded_loras) == 0:
- return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+ """
+ Applies the currently selected set of Loras to the weight of torch layer self.
+ If weights already have this particular set of loras applied, does nothing.
+ If not, restores original weights from backup and alters weights according to loras.
+ """
- lora_layer_name = getattr(module, 'lora_layer_name', None)
- for lora in loaded_loras:
- module = lora.modules.get(lora_layer_name, None)
- if module is not None:
- if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
- res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
- else:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ current_names = getattr(self, "lora_current_names", ())
+ wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+ weights_backup = getattr(self, "lora_weights_backup", None)
+ if weights_backup is None:
+ weights_backup = self.weight.to(devices.cpu, copy=True)
+ self.lora_weights_backup = weights_backup
+
+ if current_names != wanted_names:
+ if weights_backup is not None:
+ self.weight.copy_(weights_backup)
+
+ lora_layer_name = getattr(self, 'lora_layer_name', None)
+ for lora in loaded_loras:
+ module = lora.modules.get(lora_layer_name, None)
+ if module is None:
+ continue
+
+ with torch.no_grad():
+ up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+ down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
- return res
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ else:
+ updown = up @ down
+
+ self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ setattr(self, "lora_current_names", wanted_names)
def lora_Linear_forward(self, input):
- return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+ lora_apply_weights(self)
+
+ return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+ setattr(self, "lora_current_names", ())
+ setattr(self, "lora_weights_backup", None)
+
+ return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
def lora_Conv2d_forward(self, input):
- return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+ lora_apply_weights(self)
+
+ return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+ setattr(self, "lora_current_names", ())
+ setattr(self, "lora_weights_backup", None)
+
+ return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
def lora_NonDynamicallyQuantizableLinear_forward(self, input):
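
The new forward and load_state_dict wrappers only take effect once the stock torch.nn methods are rerouted to them; that wiring lives outside this file, in the extension's companion script, and may differ from the following. A hedged sketch of the monkeypatch pattern, assuming this module is importable as lora:

import torch
import lora  # assumption: this file imported as a module named `lora`

# Save the originals once so the wrappers can delegate to them.
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict

# Route the stock methods through the Lora-aware wrappers defined above.
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict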