diff options
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r-- | modules/sd_vae.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 1db01992..31306d8b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -31,7 +31,9 @@ def get_loaded_vae_hash(): if loaded_vae_file is None: return None - return hashes.sha256(loaded_vae_file, 'vae')[0:10] + sha256 = hashes.sha256(loaded_vae_file, 'vae') + + return sha256[0:10] if sha256 else None def get_base_vae(model): @@ -68,7 +70,6 @@ def get_filename(filepath): def refresh_vae_list(): - global vae_dict vae_dict.clear() paths = [ @@ -102,7 +103,7 @@ def refresh_vae_list(): name = get_filename(filepath) vae_dict[name] = filepath - vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))) + vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))) def find_vae_near_checkpoint(checkpoint_file): @@ -158,7 +159,7 @@ def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) - if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): + if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic()): return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') return VaeResolution(resolved=False) @@ -191,7 +192,7 @@ def load_vae_dict(filename, map_location): def load_vae(model, vae_file=None, vae_source="from unknown source"): - global vae_dict, loaded_vae_file + global vae_dict, base_vae, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -229,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): restore_base_vae(model) loaded_vae_file = vae_file + model.base_vae = base_vae + model.loaded_vae_file = loaded_vae_file # don't call this from outside @@ -260,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): if loaded_vae_file == vae_file: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if sd_model.lowvram: lowvram.send_everything_to_cpu() else: sd_model.to(devices.cpu) @@ -272,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) print("VAE weights loaded.") |