aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/networks.py39
-rw-r--r--modules/scripts.py35
-rw-r--r--modules/sd_samplers_cfg_denoiser.py6
-rw-r--r--modules/sd_samplers_timesteps.py6
-rw-r--r--modules/sd_samplers_timesteps_impl.py18
5 files changed, 73 insertions, 31 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index c252ed9e..22fdff4a 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -277,7 +277,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
self.weight.copy_(weights_backup)
if bias_backup is not None:
- self.bias.copy_(bias_backup)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias.copy_(bias_backup)
+ else:
+ self.bias.copy_(bias_backup)
+ else:
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.out_proj.bias = None
+ else:
+ self.bias = None
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -304,8 +312,13 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
- if bias_backup is None and getattr(self, 'bias', None) is not None:
- bias_backup = self.bias.to(devices.cpu, copy=True)
+ if bias_backup is None:
+ if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
+ bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+ elif getattr(self, 'bias', None) is not None:
+ bias_backup = self.bias.to(devices.cpu, copy=True)
+ else:
+ bias_backup = None
self.network_bias_backup = bias_backup
if current_names != wanted_names:
@@ -323,8 +336,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown
- if ex_bias is not None and getattr(self, 'bias', None) is not None:
- self.bias += ex_bias
+ if ex_bias is not None and hasattr(self, 'bias'):
+ if self.bias is None:
+ self.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -339,14 +355,19 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
try:
with torch.no_grad():
- updown_q = module_q.calc_updown(self.in_proj_weight)
- updown_k = module_k.calc_updown(self.in_proj_weight)
- updown_v = module_v.calc_updown(self.in_proj_weight)
+ updown_q, _ = module_q.calc_updown(self.in_proj_weight)
+ updown_k, _ = module_k.calc_updown(self.in_proj_weight)
+ updown_v, _ = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- updown_out = module_out.calc_updown(self.out_proj.weight)
+ updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out
+ if ex_bias is not None:
+ if self.out_proj.bias is None:
+ self.out_proj.bias = torch.nn.Parameter(ex_bias)
+ else:
+ self.out_proj.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
diff --git a/modules/scripts.py b/modules/scripts.py
index d4a9da94..cbdac2b5 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -239,6 +239,8 @@ class Script:
"""
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
+ May be called in show() or ui() - but it may be too late in latter as some components may already be created.
+
This function is an alternative to before_component in that it also cllows to run before a component is created, but
it doesn't require to be called for every created component - just for the one you need.
"""
@@ -445,6 +447,28 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
+ self.apply_on_before_component_callbacks()
+
+ def apply_on_before_component_callbacks(self):
+ for script in self.scripts:
+ on_before = script.on_before_component_elem_id or []
+ on_after = script.on_after_component_elem_id or []
+
+ for elem_id, callback in on_before:
+ if elem_id not in self.on_before_component_elem_id:
+ self.on_before_component_elem_id[elem_id] = []
+
+ self.on_before_component_elem_id[elem_id].append((callback, script))
+
+ for elem_id, callback in on_after:
+ if elem_id not in self.on_after_component_elem_id:
+ self.on_after_component_elem_id[elem_id] = []
+
+ self.on_after_component_elem_id[elem_id].append((callback, script))
+
+ on_before.clear()
+ on_after.clear()
+
def create_script_ui(self, script):
import modules.api.models as api_models
@@ -555,16 +579,7 @@ class ScriptRunner:
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
- for script in self.scripts:
- for elem_id, callback in script.on_before_component_elem_id or []:
- items = self.on_before_component_elem_id.get(elem_id, [])
- items.append((callback, script))
- self.on_before_component_elem_id[elem_id] = items
-
- for elem_id, callback in script.on_after_component_elem_id or []:
- items = self.on_after_component_elem_id.get(elem_id, [])
- items.append((callback, script))
- self.on_after_component_elem_id[elem_id] = items
+ self.apply_on_before_component_callbacks()
return self.inputs
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index 113425b2..bc9b97e4 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -56,6 +56,7 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+ self.mask_before_denoising = False
@property
def inner_model(self):
@@ -104,7 +105,7 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
- if self.mask is not None:
+ if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x
batch_size = len(conds_list)
@@ -206,6 +207,9 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ if not self.mask_before_denoising and self.mask is not None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
if opts.live_preview_content == "Prompt":
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index 16572c7e..c1f534ed 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -49,12 +49,12 @@ class CFGDenoiserTimesteps(CFGDenoiser):
super().__init__(sampler)
self.alphas = shared.sd_model.alphas_cumprod
+ self.mask_before_denoising = True
def get_pred_x0(self, x_in, x_out, sigma):
- ts = int(sigma.item())
+ ts = sigma.to(dtype=int)
- s_in = x_in.new_ones([x_in.shape[0]])
- a_t = self.alphas[ts].item() * s_in
+ a_t = self.alphas[ts][:, None, None, None]
sqrt_one_minus_at = (1 - a_t).sqrt()
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py
index d32e3521..a72daafd 100644
--- a/modules/sd_samplers_timesteps_impl.py
+++ b/modules/sd_samplers_timesteps_impl.py
@@ -16,16 +16,17 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
extra_args = {} if extra_args is None else extra_args
- s_in = x.new_ones([x.shape[0]])
+ s_in = x.new_ones((x.shape[0]))
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
- a_t = alphas[index].item() * s_in
- a_prev = alphas_prev[index].item() * s_in
- sigma_t = sigmas[index].item() * s_in
- sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+ a_t = alphas[index].item() * s_x
+ a_prev = alphas_prev[index].item() * s_x
+ sigma_t = sigmas[index].item() * s_x
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
@@ -47,13 +48,14 @@ def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
+ s_x = x.new_ones((x.shape[0], 1, 1, 1))
old_eps = []
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
- a_t = alphas[index].item() * s_in
- a_prev = alphas_prev[index].item() * s_in
- sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
+ a_t = alphas[index].item() * s_x
+ a_prev = alphas_prev[index].item() * s_x
+ sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()