From f40617d6c4e366773677baa8d7f4114ba2893282 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 3 Sep 2022 17:21:15 +0300 Subject: support for scripts --- modules/sd_samplers.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) (limited to 'modules/sd_samplers.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 54c5fd7c..6f028f5f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -9,18 +9,28 @@ from ldm.models.diffusion.plms import PLMSSampler from modules.shared import opts, cmd_opts, state import modules.shared as shared -SamplerData = namedtuple('SamplerData', ['name', 'constructor']) + +SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases']) + +samplers_k_diffusion = [ + ('Euler a', 'sample_euler_ancestral', ['k_euler_a']), + ('Euler', 'sample_euler', ['k_euler']), + ('LMS', 'sample_lms', ['k_lms']), + ('Heun', 'sample_heun', ['k_heun']), + ('DPM2', 'sample_dpm_2', ['k_dpm_2']), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), +] + +samplers_data_k_diffusion = [ + SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) + for label, funcname, aliases in samplers_k_diffusion + if hasattr(k_diffusion.sampling, funcname) +] + samplers = [ - *[SamplerData(x[0], lambda model, funcname=x[1]: KDiffusionSampler(funcname, model)) for x in [ - ('Euler a', 'sample_euler_ancestral'), - ('Euler', 'sample_euler'), - ('LMS', 'sample_lms'), - ('Heun', 'sample_heun'), - ('DPM2', 'sample_dpm_2'), - ('DPM2 a', 'sample_dpm_2_ancestral'), - ] if hasattr(k_diffusion.sampling, x[1])], - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSample, model)), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model)), + *samplers_data_k_diffusion, + SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSampler, model), []), + SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model), []), ] samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] -- cgit v1.2.1