diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/sd_samplers.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index fd63e47f..6b7979e2 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -87,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
self.mask = None
self.nmask = None
self.init_latent = None
@@ -113,7 +113,9 @@ class VanillaStableDiffusionSampler: return samples
def sample(self, p, x, conditioning, unconditional_conditioning):
- self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+ if hasattr(self.sampler, fieldname):
+ setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
self.mask = None
self.nmask = None
self.init_latent = None
|