From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- modules/sd_models.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..91ad4b5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks +from modules import shared, modelloader, devices, script_callbacks, sd_vae from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - -def load_model_weights(model, checkpoint_info): +def load_model_weights(model, checkpoint_info, force=False): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if checkpoint_info not in checkpoints_loaded: + if force or checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path - - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict) - + sd_vae.load_vae(model, checkpoint_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: @@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None): +def load_model(checkpoint_info=None, force=False): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -234,7 +223,7 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -252,16 +241,16 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model, info=None, force=False): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info) + load_model(checkpoint_info, force=force) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) -- cgit v1.2.1 From 726769da35970f4c100fa7edf11850f9dc059c41 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:19:34 +0700 Subject: Checkpoint cache by combination key of checkpoint and vae --- modules/sd_models.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 91ad4b5e..850f7b7b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -def load_model_weights(model, checkpoint_info, force=False): +def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if force or checkpoint_info not in checkpoints_loaded: + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + + checkpoint_key = (checkpoint_info, vae_file) + + if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, checkpoint_file) + sd_vae.load_vae(model, vae_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU else: - print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) - model.load_state_dict(checkpoints_loaded[checkpoint_info]) + vae_name = sd_vae.get_filename(vae_file) + print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + checkpoints_loaded.move_to_end(checkpoint_key) + model.load_state_dict(checkpoints_loaded[checkpoint_key]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None, force=False): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info, force=force) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) -- cgit v1.2.1 From 36966e3200943dbf890b5338cfa939df552d3c47 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:38:58 +0700 Subject: Fix #4035 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..a29c8c1a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -201,7 +201,7 @@ def load_model_weights(model, checkpoint_info): if shared.opts.sd_checkpoint_cache > 0: checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") -- cgit v1.2.1 From bf7a699845675eefdabb9cfa40c55398976274ae Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 16:27:27 +0700 Subject: Fix #4035 for real now --- modules/sd_models.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a29c8c1a..b2dd005a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash + if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) - - if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: - checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) model.load_state_dict(checkpoints_loaded[checkpoint_info]) + if shared.opts.sd_checkpoint_cache > 0: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -- cgit v1.2.1 From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 12:51:46 +0700 Subject: Reload VAE without reloading sd checkpoint --- modules/sd_models.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ab85b65..883639d1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - checkpoint_key = (checkpoint_info, vae_file) + checkpoint_key = checkpoint_info if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, vae_file) - model.first_stage_model.to(devices.dtype_vae) - if shared.opts.sd_checkpoint_cache > 0: + # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU + else: vae_name = sd_vae.get_filename(vae_file) print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") @@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.load_vae(model, vae_file) + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack @@ -254,14 +253,14 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model=None, info=None, force=False): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if not sd_model: sd_model = shared.sd_model - if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): -- cgit v1.2.1 From f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 14:41:29 +0300 Subject: fix #3986 breaking --no-half-vae --- modules/sd_models.py | 9 +++++++++ 1 file changed, 9 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 883639d1..5075fadb 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.to(memory_format=torch.channels_last) if not shared.cmd_opts.no_half: + vae = model.first_stage_model + + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + model.half() + model.first_stage_model = vae devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + model.first_stage_model.to(devices.dtype_vae) + if shared.opts.sd_checkpoint_cache > 0: # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() -- cgit v1.2.1 From 3780ad3ad837dd406da39eebd5d91009b5a58445 Mon Sep 17 00:00:00 2001 From: digburn Date: Fri, 4 Nov 2022 00:40:21 +0000 Subject: fix: loading models without vae from cache --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5075fadb..ae427a5c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,8 +204,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU else: - vae_name = sd_vae.get_filename(vae_file) - print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + vae_name = sd_vae.get_filename(vae_file) if vae_file else None + vae_message = f" with {vae_name} VAE" if vae_name else "" + print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") checkpoints_loaded.move_to_end(checkpoint_key) model.load_state_dict(checkpoints_loaded[checkpoint_key]) -- cgit v1.2.1 From 99043f33606d3057f83ea52a403e10cd29d1f7e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 11:20:42 +0300 Subject: fix one of previous merges breaking the program --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 63e07a12..34c57bfa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -167,6 +167,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): sd_vae.restore_base_vae(model) checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") -- cgit v1.2.1 From 3b51d239ac9201228c6032fc109111e347e8e6b0 Mon Sep 17 00:00:00 2001 From: cluder <1590330+cluder@users.noreply.github.com> Date: Wed, 9 Nov 2022 04:54:21 +0100 Subject: - do not use ckpt cache, if disabled - cache model after is has been loaded from file --- modules/sd_models.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 34c57bfa..720c2a96 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + cache_enabled = shared.opts.sd_checkpoint_cache > 0 + + if cache_enabled: sd_vae.restore_base_vae(model) - checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if checkpoint_info not in checkpoints_loaded: + if cache_enabled and checkpoint_info in checkpoints_loaded: + # use checkpoint cache + vae_name = sd_vae.get_filename(vae_file) if vae_file else None + vae_message = f" with {vae_name} VAE" if vae_name else "" + print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + model.load_state_dict(checkpoints_loaded[checkpoint_info]) + else: + # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): del pl_sd model.load_state_dict(sd, strict=False) del sd + + if cache_enabled: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) @@ -199,13 +211,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.first_stage_model.to(devices.dtype_vae) - else: - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - - if shared.opts.sd_checkpoint_cache > 0: + # clean up cache if limit is reached + if cache_enabled: while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU -- cgit v1.2.1 From eebf49592ad2c0933e58b06a098b92e48d47e4fe Mon Sep 17 00:00:00 2001 From: cluder <1590330+cluder@users.noreply.github.com> Date: Wed, 9 Nov 2022 07:17:09 +0100 Subject: restore #4035 behavior - if checkpoint cache is set to 1, keep 2 models in cache (current +1 more) --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 720c2a96..80addf03 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -213,7 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # clean up cache if limit is reached if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash -- cgit v1.2.1 From abc1e79a5da24a1ea0f4bceedcdf225f32010aa8 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 11:10:53 +0700 Subject: Fix base VAE caching was done after loading VAE, also add safeguard --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..e4dba62c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) -- cgit v1.2.1 From c7be83bf0240498d9382e2afeaa3f0677d26c7f6 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 11:11:14 +0700 Subject: Misc Misc --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e4dba62c..cd7fe37a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) -- cgit v1.2.1 From 2c5ca706a7e624d268545ba3318ba230b7b33477 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 10:55:47 +0700 Subject: Remove no longer necessary parts and add vae_file safeguard --- modules/sd_models.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..c59151e0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): cache_enabled = shared.opts.sd_checkpoint_cache > 0 - if cache_enabled: - sd_vae.restore_base_vae(model) - - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if cache_enabled and checkpoint_info in checkpoints_loaded: # use checkpoint cache - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + print(f"Loading weights [{sd_model_hash}] from cache") model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file @@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) sd_vae.load_vae(model, vae_file) -- cgit v1.2.1 From 0efffbb407a9d07eae6850374099775385ce176c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 21 Nov 2022 14:04:25 +0100 Subject: Supporting `*.safetensors` format. If a model file exists with extension `.safetensors` then we can load it more safely than with PyTorch weights. --- modules/sd_models.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..0164cc1b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -45,7 +45,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -180,7 +180,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if checkpoint_file.endswith(".safetensors"): + try: + from safetensors.torch import load_file + except ImportError as e: + raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") + pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") -- cgit v1.2.1 From 1e506657e1cb732a5f0e567ba2585fba2bbb1327 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 26 Nov 2022 13:28:44 -0500 Subject: no-half support for SD 2.0 --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index c59151e0..0e0bd79e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -244,6 +244,9 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.1 From 6074175faa751dde933aa8e15cd687ca4e4b4a23 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 14:46:40 +0300 Subject: add safetensors to requirements --- modules/sd_models.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ae36841a..77236480 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -5,6 +5,7 @@ import gc from collections import namedtuple import torch import re +import safetensors.torch from omegaconf import OmegaConf from ldm.util import instantiate_from_config @@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - if checkpoint_file.endswith(".safetensors"): - try: - from safetensors.torch import load_file - except ImportError as e: - raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}") - pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location) else: pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") -- cgit v1.2.1 From dac9b6f15de5e675053d9490a20e0457dcd1a23e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 15:51:29 +0300 Subject: add safetensors support for model merging #4869 --- modules/sd_models.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 77236480..a1ea5611 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -160,6 +160,20 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) + + if print_global_state and "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + + sd = get_state_dict_from_checkpoint(pl_sd) + return sd + + def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -174,17 +188,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - _, extension = os.path.splitext(checkpoint_file) - if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location) - else: - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) - - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - - sd = get_state_dict_from_checkpoint(pl_sd) - del pl_sd + sd = read_state_dict(checkpoint_file) model.load_state_dict(sd, strict=False) del sd -- cgit v1.2.1 From 0376da180c81a11880a2587903d69d85541051e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 28 Nov 2022 08:39:59 +0300 Subject: make it possible to save nai model using safetensors --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a1ea5611..283cf1cd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -144,8 +144,8 @@ def transform_checkpoint_dict_key(k): def get_state_dict_from_checkpoint(pl_sd): - if "state_dict" in pl_sd: - pl_sd = pl_sd["state_dict"] + pl_sd = pl_sd.pop("state_dict", pl_sd) + pl_sd.pop("state_dict", None) sd = {} for k, v in pl_sd.items(): -- cgit v1.2.1 From 1ed4f0e22807f3afef925210182cbbee51f0cb2c Mon Sep 17 00:00:00 2001 From: Jay Smith Date: Thu, 8 Dec 2022 18:14:35 -0600 Subject: Depth2img model support --- modules/sd_models.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 283cf1cd..139952ba 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,6 +7,9 @@ import torch import re import safetensors.torch from omegaconf import OmegaConf +from os import mkdir +from urllib import request +import ldm.modules.midas as midas from ldm.util import instantiate_from_config @@ -36,6 +39,7 @@ def setup_model(): os.makedirs(model_path) list_models() + enable_midas_autodownload() def checkpoint_tiles(): @@ -227,6 +231,48 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): sd_vae.load_vae(model, vae_file) +def enable_midas_autodownload(): + """ + Gives the ldm.modules.midas.api.load_model function automatic downloading. + + When the 512-depth-ema model, and other future models like it, is loaded, + it calls midas.api.load_model to load the associated midas depth model. + This function applies a wrapper to download the model to the correct + location automatically. + """ + + midas_path = os.path.join(models_path, 'midas') + + # stable-diffusion-stability-ai hard-codes the midas model path to + # a location that differs from where other scripts using this model look. + # HACK: Overriding the path here. + for k, v in midas.api.ISL_PATHS.items(): + file_name = os.path.basename(v) + midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name) + + midas_urls = { + "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt", + "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt", + } + + midas.api.load_model_inner = midas.api.load_model + + def load_model_wrapper(model_type): + path = midas.api.ISL_PATHS[model_type] + if not os.path.exists(path): + if not os.path.exists(midas_path): + mkdir(midas_path) + + print(f"Downloading midas model weights for {model_type} to {path}") + request.urlretrieve(midas_urls[model_type], path) + print(f"{model_type} downloaded") + + return midas.api.load_model_inner(model_type) + + midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() -- cgit v1.2.1 From bd81a09eacf02dad095b98094ab936f276d0343f Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sat, 10 Dec 2022 11:29:26 -0500 Subject: fix support for 2.0 inpainting model while maintaining support for 1.5 inpainting model --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5b37f3fe..b64f573f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -296,6 +296,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.use_ema = False sd_config.model.params.conditioning_key = "hybrid" sd_config.model.params.unet_config.params.in_channels = 9 + sd_config.model.params.finetune_keys = None # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) -- cgit v1.2.1 From 59c6511494c55a578eecdf71fb4590b6bd5d04a7 Mon Sep 17 00:00:00 2001 From: Dean van Dugteren <31391056+deanpress@users.noreply.github.com> Date: Sun, 11 Dec 2022 17:08:51 +0100 Subject: fix: fallback model_checkpoint if it's empty This fixes the following error when SD attempts to start with a deleted checkpoint: ``` Traceback (most recent call last): File "D:\Web\stable-diffusion-webui\launch.py", line 295, in start() File "D:\Web\stable-diffusion-webui\launch.py", line 290, in start webui.webui() File "D:\Web\stable-diffusion-webui\webui.py", line 132, in webui initialize() File "D:\Web\stable-diffusion-webui\webui.py", line 62, in initialize modules.sd_models.load_model() File "D:\Web\stable-diffusion-webui\modules\sd_models.py", line 283, in load_model checkpoint_info = checkpoint_info or select_checkpoint() File "D:\Web\stable-diffusion-webui\modules\sd_models.py", line 117, in select_checkpoint checkpoint_info = checkpoints_list.get(model_checkpoint, None) TypeError: unhashable type: 'list' ``` --- modules/sd_models.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 5b37f3fe..b6d75db7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -111,6 +111,10 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint + + if len(model_checkpoint) == 0: + model_checkpoint = shared.default_sd_model_file + checkpoint_info = checkpoints_list.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info -- cgit v1.2.1 From ec0a48826fb41c1b1baab45a9030f7eb55568fd0 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Sun, 11 Dec 2022 10:19:46 -0500 Subject: unconditionally set use_ema=False if value not specified (True never worked, and all configs except v1-inpainting-inference.yaml already correctly set it to False) --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b64f573f..f36b299f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -293,7 +293,6 @@ def load_model(checkpoint_info=None): if should_hijack_inpainting(checkpoint_info): # Hardcoded config for now... sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.use_ema = False sd_config.model.params.conditioning_key = "hybrid" sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None @@ -301,6 +300,9 @@ def load_model(checkpoint_info=None): # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + do_inpainting_hijack() if shared.cmd_opts.no_half: -- cgit v1.2.1 From 5a650055de3792223a91925aba8130ebdee29e35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?linuxmobile=20=28=20=E3=83=AA=E3=83=8A=E3=83=83=E3=82=AF?= =?UTF-8?q?=E3=82=B9=20=29?= Date: Sat, 24 Dec 2022 09:25:35 -0300 Subject: Removed lenght in sd_model at line 115 Commit eba60a4 is what is causing this error, delete the length check in sd_model starting at line 115 and it's fine. https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5971#issuecomment-1364507379 --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1254e5ae..6ca06211 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -111,9 +111,6 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - - if len(model_checkpoint) == 0: - model_checkpoint = shared.default_sd_model_file checkpoint_info = checkpoints_list.get(model_checkpoint, None) if checkpoint_info is not None: -- cgit v1.2.1 From 3bf5591efe9a9f219c6088be322a87adc4f48f95 Mon Sep 17 00:00:00 2001 From: Yuval Aboulafia Date: Sat, 24 Dec 2022 21:35:29 +0200 Subject: fix F541 f-string without any placeholders --- modules/sd_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ca06211..ecdd91c5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -117,13 +117,13 @@ def select_checkpoint(): return checkpoint_info if len(checkpoints_list) == 0: - print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) + print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) if shared.cmd_opts.ckpt is not None: print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) print(f" - directory {model_path}", file=sys.stderr) if shared.cmd_opts.ckpt_dir is not None: print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) - print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) + print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) checkpoint_info = next(iter(checkpoints_list.values())) @@ -324,7 +324,7 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) - print(f"Model loaded.") + print("Model loaded.") return sd_model @@ -359,5 +359,5 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print(f"Weights loaded.") + print("Weights loaded.") return sd_model -- cgit v1.2.1 From 5ba04f9ec050a66e918571f07e8863f157f05b44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Dec 2022 13:45:58 +0100 Subject: Attempting to solve slow loads for `safetensors`. Fixes #5893 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..cd938656 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -168,7 +168,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + device = map_location or shared.weight_load_location + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.1 From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 31 Dec 2022 11:27:02 -0500 Subject: validate textual inversion embeddings --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..ebd4dff7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -325,6 +325,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model + return sd_model -- cgit v1.2.1 From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 00:38:09 +0300 Subject: fix the issue with training on SD2.0 --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ebd4dff7..bff8d6c9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + model.logvar = model.logvar.to(devices.device) # fix for training + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) -- cgit v1.2.1 From 8f96f9289981a66741ba770d14f3d27ce335a0fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:39:14 +0300 Subject: call script callbacks for reloaded model after loading embeddings --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index bff8d6c9..b98b05fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -324,12 +324,12 @@ def load_model(checkpoint_info=None): sd_model.eval() shared.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model - return sd_model -- cgit v1.2.1 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/sd_models.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model -- cgit v1.2.1 From 8d8a05a3bbb50fdfeab51679a919d2487bd97976 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:47:42 +0300 Subject: find configs for models at runtime rather than when starting --- modules/sd_models.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6846b74a..6dca4ddf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() @@ -48,6 +48,14 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def find_checkpoint_config(info): + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return shared.cmd_opts.config + + def list_models(): checkpoints_list.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) @@ -73,7 +81,7 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -81,12 +89,7 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) - config = basename + ".yaml" - if not os.path.exists(config): - config = shared.cmd_opts.config - - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) def get_closet_checkpoint_match(searchString): @@ -282,9 +285,10 @@ def enable_midas_autodownload(): def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + checkpoint_config = find_checkpoint_config(checkpoint_info) - if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_info.config}") + if checkpoint_config != shared.cmd_opts.config: + print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -292,7 +296,7 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_info.config) + sd_config = OmegaConf.load(checkpoint_config) if should_hijack_inpainting(checkpoint_info): # Hardcoded config for now... @@ -302,7 +306,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.finetune_keys = None # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None): sd_model = shared.sd_model current_checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_config = find_checkpoint_config(current_checkpoint_info) if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.1 From 0cd6399b8b1699b8b7acad6f0ad2988111fe618e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 14:29:13 +0300 Subject: fix broken inpainting model --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6dca4ddf..a568823d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -305,9 +305,6 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None - # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) - if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False -- cgit v1.2.1 From 642142556d8ecdea9beb86d7618b628b1803ab98 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 15:09:53 +0300 Subject: use commandline-supplied cuda device name instead of cuda:0 for safetensors PR that doesn't fix anything --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ee918f24..76a89e88 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None if extension.lower() == ".safetensors": device = map_location or shared.weight_load_location if device is None: - device = "cuda:0" if torch.cuda.is_available() else "cpu" + device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.1