aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordrhead <1313496+drhead@users.noreply.github.com>2023-12-09 14:09:28 -0500
committerGitHub <noreply@github.com>2023-12-09 14:09:28 -0500
commit5381405eaa1e809e5cfb97522bd4c19d3c946079 (patch)
tree2d7630111fcdec4dd0c51f8a2cdfb4a3a4aace94
parent78acdcf677a96894651ff0d7d8287f2a994f3781 (diff)
re-derive sqrt alpha bar and sqrt one minus alphabar
This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent.
-rw-r--r--modules/sd_samplers_timesteps.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index b17a8f93..c4bd5c12 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)