From 7f1d087cba681ddd12cd54152090b176f26bd25c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 4 Aug 2023 12:01:28 +0900 Subject: sort VAE --- modules/sd_vae.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e4ff2994..84271db0 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -50,6 +50,7 @@ def get_filename(filepath): def refresh_vae_list(): + global vae_dict vae_dict.clear() paths = [ @@ -83,6 +84,8 @@ def refresh_vae_list(): name = get_filename(filepath) vae_dict[name] = filepath + vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))) + def find_vae_near_checkpoint(checkpoint_file): checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] -- cgit v1.2.1 From 45601766409e531d2b4ee512bf1433600f140183 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 22:05:40 +0300 Subject: added VAE selection to checkpoint user metadata --- modules/sd_vae.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 84271db0..0bd5e19b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,6 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -16,6 +16,7 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: return base_vae @@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return None, None + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return vae_from_metadata, "from user metadata" + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) -- cgit v1.2.1 From c96e4750d895a47290dc7f96e030197069c75fa4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 7 Aug 2023 08:07:09 +0300 Subject: SD VAE rework 2 - the setting for preferring opts.sd_vae has been inverted and reworded - resolve_vae function made easier to read and now returns an object rather than a tuple - if the checkbox for overriding per-model preferences is checked, opts.sd_vae overrides checkpoint user metadata - changing VAE in user metadata for currently loaded model immediately applies the selection --- modules/sd_vae.py | 71 ++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 16 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0bd5e19b..38bcb840 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,7 @@ import os import collections +from dataclasses import dataclass + from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file): return None -def resolve_vae(checkpoint_file): - if shared.cmd_opts.vae_path is not None: - return shared.cmd_opts.vae_path, 'from commandline argument' +@dataclass +class VaeResolution: + vae: str = None + source: str = None + resolved: bool = True + + def tuple(self): + return self.vae, self.source + + +def is_automatic(): + return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + +def resolve_vae_from_setting() -> VaeResolution: + if shared.opts.sd_vae == "None": + return VaeResolution() + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return VaeResolution(vae_from_options, 'specified in settings') + + if not is_automatic(): + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + return VaeResolution(resolved=False) + + +def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: metadata = extra_networks.get_user_metadata(checkpoint_file) vae_metadata = metadata.get("vae", None) if vae_metadata is not None and vae_metadata != "Automatic": if vae_metadata == "None": - return None, None + return VaeResolution() vae_from_metadata = vae_dict.get(vae_metadata, None) if vae_from_metadata is not None: - return vae_from_metadata, "from user metadata" + return VaeResolution(vae_from_metadata, "from user metadata") + + return VaeResolution(resolved=False) - is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config +def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): - return vae_near_checkpoint, 'found near the checkpoint' + return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') - if shared.opts.sd_vae == "None": - return None, None + return VaeResolution(resolved=False) - vae_from_options = vae_dict.get(shared.opts.sd_vae, None) - if vae_from_options is not None: - return vae_from_options, 'specified in settings' - if not is_automatic: - print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") +def resolve_vae(checkpoint_file) -> VaeResolution: + if shared.cmd_opts.vae_path is not None: + return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument') + + if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic(): + return resolve_vae_from_setting() + + res = resolve_vae_from_user_metadata(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_near_checkpoint(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_from_setting() - return None, None + return res def load_vae_dict(filename, map_location): @@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): checkpoint_file = checkpoint_info.filename if vae_file == unspecified: - vae_file, vae_source = resolve_vae(checkpoint_file) + vae_file, vae_source = resolve_vae(checkpoint_file).tuple() else: vae_source = "from function argument" -- cgit v1.2.1