aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_samplers_timesteps.py
diff options
context:
space:
mode:
authorSj-Si <sjw.jetty@gmail.com>2024-01-11 16:37:35 -0500
committerSj-Si <sjw.jetty@gmail.com>2024-01-11 16:37:35 -0500
commit036500223de0a3caaa86360a8ad3ed301e4367b0 (patch)
treef05f0d5fc503d9c35d57bad077a5dab1dfd6569e /modules/sd_samplers_timesteps.py
parent0726a6e12e85a37d1e514f5603acf9f058c11783 (diff)
parentcb5b335acddd126d4f6c990982816c06beb0d6ae (diff)
Merge changes from dev
Diffstat (limited to 'modules/sd_samplers_timesteps.py')
-rw-r--r--modules/sd_samplers_timesteps.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py
index b17a8f93..777dd8d0 100644
--- a/modules/sd_samplers_timesteps.py
+++ b/modules/sd_samplers_timesteps.py
@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v):
- return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+ return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
@@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
self.eta_default = 0.0
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+ self.model_wrap = self.model_wrap_cfg.inner_model
def get_timesteps(self, p, steps):
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)