From c77c89cc83c618472ad352cf8a28fde28c3a1377 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 10:23:31 +0300 Subject: make main model loading and model merger use the same code --- modules/sd_models.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b1..18fb8c2e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,6 +122,13 @@ def select_checkpoint(): return checkpoint_info +def get_state_dict_from_checkpoint(pl_sd): + if "state_dict" in pl_sd: + return pl_sd["state_dict"] + + return pl_sd + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -131,11 +138,8 @@ def load_model_weights(model, checkpoint_info): pl_sd = torch.load(checkpoint_file, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd + + sd = get_state_dict_from_checkpoint(pl_sd) model.load_state_dict(sd, strict=False) -- cgit v1.2.1 From 4e569fd888f8e3c5632a072d51abbb6e4d17abd6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 10:31:47 +0300 Subject: fixed incorrect message about loading config; thanks anon! --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 18fb8c2e..2101b18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -169,7 +169,7 @@ def load_model(): checkpoint_info = select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {shared.cmd_opts.config}") + print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.1 From f4578b343ded3b8ccd1879ea0c0b3cdadfcc3a5f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 13:23:30 +0300 Subject: fix model switching not working properly if there is a different yaml config --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 2101b18d..d0c74dd8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -196,7 +196,8 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: - return load_model() + shared.sd_model = load_model() + return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() -- cgit v1.2.1 From d6d10a37bfd21568e74efb46137f906da96d5fdb Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 04:58:40 -0400 Subject: Added extended model details to infotext --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index d0c74dd8..3fa42329 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,7 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf - +from pathlib import Path from ldm.util import instantiate_from_config @@ -158,6 +158,7 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) + model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file -- cgit v1.2.1 From e6e8cabe0c9c335e0d72345602c069b198558b53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:57:48 +0300 Subject: change up #2056 to make it work how i want it to plus make xy plot write correct values to images --- modules/sd_models.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3fa42329..e63d3c29 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf -from pathlib import Path from ldm.util import instantiate_from_config @@ -158,7 +157,6 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) - model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file -- cgit v1.2.1