aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-04 12:35:07 +0300
committerAUTOMATIC <16777216c@gmail.com>2023-01-04 12:35:07 +0300
commit02d7abf5141431b9a3a8a189bb3136c71abd5e79 (patch)
tree6b19b67fab476214ffd5d19a316d4daf8baf1a70
parent7e549468b3d01e6cdf6d07d00c2719c1a5787517 (diff)
helpful error message when trying to load 2.0 without config
failing to load model weights from settings won't break generation for currently loaded model anymore
-rw-r--r--modules/errors.py25
-rw-r--r--modules/sd_models.py26
-rw-r--r--modules/shared.py9
-rw-r--r--webui.py12
4 files changed, 58 insertions, 14 deletions
diff --git a/modules/errors.py b/modules/errors.py
index 372dc51a..a668c014 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -2,9 +2,30 @@ import sys
import traceback
+def print_error_explanation(message):
+ lines = message.strip().split("\n")
+ max_len = max([len(x) for x in lines])
+
+ print('=' * max_len, file=sys.stderr)
+ for line in lines:
+ print(line, file=sys.stderr)
+ print('=' * max_len, file=sys.stderr)
+
+
+def display(e: Exception, task):
+ print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ message = str(e)
+ if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
+ print_error_explanation("""
+The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
+See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
+ """)
+
+
def run(code, task):
try:
code()
except Exception as e:
- print(f"{task}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ display(task, e)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b98b05fc..6846b74a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -278,6 +278,7 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
sd_config.model.params.unet_config.params.use_fp16 = False
sd_model = instantiate_from_config(sd_config.model)
+
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -336,10 +338,12 @@ def load_model(checkpoint_info=None):
def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
-
+
if not sd_model:
sd_model = shared.sd_model
+ current_checkpoint_info = sd_model.sd_checkpoint_info
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
@@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
- load_model_weights(sd_model, checkpoint_info)
-
- sd_hijack.model_hijack.hijack(sd_model)
- script_callbacks.model_loaded_callback(sd_model)
-
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
- sd_model.to(devices.device)
+ try:
+ load_model_weights(sd_model, checkpoint_info)
+ except Exception as e:
+ print("Failed to load checkpoint, restoring previous")
+ load_model_weights(sd_model, current_checkpoint_info)
+ raise
+ finally:
+ sd_hijack.model_hijack.hijack(sd_model)
+ script_callbacks.model_loaded_callback(sd_model)
+
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ sd_model.to(devices.device)
print("Weights loaded.")
+
return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index 23657a93..7588c47b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
-from modules import localization, sd_vae, extensions, script_loading
+from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path
@@ -494,7 +494,12 @@ class Options:
return False
if self.data_labels[key].onchange is not None:
- self.data_labels[key].onchange()
+ try:
+ self.data_labels[key].onchange()
+ except Exception as e:
+ errors.display(e, f"changing setting {key} to {value}")
+ setattr(self, key, oldval)
+ return False
return True
diff --git a/webui.py b/webui.py
index c7d55a97..13375e71 100644
--- a/webui.py
+++ b/webui.py
@@ -9,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from modules import import_hook
+from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
@@ -61,7 +61,15 @@ def initialize():
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list()
- modules.sd_models.load_model()
+
+ try:
+ modules.sd_models.load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+ exit(1)
+
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)