aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-04 12:35:07 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-04 12:35:07 +0300
commit02d7abf5141431b9a3a8a189bb3136c71abd5e79 (patch)
tree6b19b67fab476214ffd5d19a316d4daf8baf1a70 /modules/sd_models.py
parent7e549468b3d01e6cdf6d07d00c2719c1a5787517 (diff)
helpful error message when trying to load 2.0 without config
failing to load model weights from settings won't break generation for currently loaded model anymore
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py26
1 files changed, 18 insertions, 8 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b98b05fc..6846b74a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -278,6 +278,7 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
-
+
if not sd_model:
sd_model = shared.sd_model
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
@@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
-
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
print("Weights loaded.")
+
return sd_model