diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-06 17:53:33 +0300 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-06 17:53:33 +0300 |
commit | 5a0db84b6c7322082c7532df11a29a95a59a612b (patch) | |
tree | d28c4ed2cc29a105bcc1895b3600fc24f9c0aaec /modules/sd_samplers_common.py | |
parent | 956e69bf3a9387b6a6cda6823d729e3d3f13c3e1 (diff) |
add infotext
add proper support for recalculating conds in k-diffusion samplers
remove support for compvis samplers
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 3f3e83e3..92bf0ca1 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -131,16 +131,27 @@ replace_torchsde_browinan() def apply_refiner(sampler):
completed_ratio = sampler.step / sampler.steps
- if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint:
- refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
- if refiner_checkpoint_info is None:
- raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
- with sd_models.SkipWritingToConfig():
- sd_models.reload_model_weights(info=refiner_checkpoint_info)
+ if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ return False
+
+ if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
+ return False
+
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
+
+ sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+
+ with sd_models.SkipWritingToConfig():
+ sd_models.reload_model_weights(info=refiner_checkpoint_info)
+
+ devices.torch_gc()
+ sampler.p.setup_conds()
+ sampler.update_inner_model()
- devices.torch_gc()
+ return True
- sampler.update_inner_model()
- sampler.p.setup_conds()
|