From 80b26d2a69617b75d2d01c1e6b7d11445815ed4d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 25 Mar 2023 23:06:33 +0300 Subject: apply Lora by altering layer's weights instead of adding more calculations in forward() --- extensions-builtin/Lora/lora.py | 72 ++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 16 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 7c371deb..a737fec3 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -131,7 +131,7 @@ def load_lora(name, filename): with torch.no_grad(): module.weight.copy_(weight) - module.to(device=devices.device, dtype=devices.dtype) + module.to(device=devices.cpu, dtype=devices.dtype) if lora_key == "lora_up.weight": lora_module.up = module @@ -177,29 +177,69 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_forward(module, input, res): - input = devices.cond_cast_unet(input) - if len(loaded_loras) == 0: - return res +def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear): + """ + Applies the currently selected set of Loras to the weight of torch layer self. + If weights already have this particular set of loras applied, does nothing. + If not, restores orginal weights from backup and alters weights according to loras. + """ - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) + + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + self.weight.copy_(weights_backup) + + lora_layer_name = getattr(self, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is None: + continue - return res + with torch.no_grad(): + up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) + down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + setattr(self, "lora_current_names", wanted_names) def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): -- cgit v1.2.1