aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py151
1 files changed, 117 insertions, 34 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7a866a07..841402e8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,22 +1,22 @@
import collections
import os.path
import sys
-import gc
import threading
import torch
import re
import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
+import numpy as np
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -27,15 +27,34 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()
+def replace_key(d, key, new_key, value):
+ keys = list(d.keys())
+
+ d[new_key] = value
+
+ if key not in keys:
+ return d
+
+ index = keys.index(key)
+ keys[index] = new_key
+
+ new_d = {k: d[k] for k in keys}
+
+ d.clear()
+ d.update(new_d)
+ return d
+
+
class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
abspath = os.path.abspath(filename)
+ abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+ if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
+ name = abspath.replace(abs_ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
@@ -91,9 +110,11 @@ class CheckpointInfo:
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
- checkpoints_list.pop(self.title, None)
+ old_title = self.title
self.title = f'{self.name} [{self.shorthash}]'
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
+
+ replace_key(checkpoints_list, old_title, self.title, self)
self.register()
return self.shorthash
@@ -109,9 +130,12 @@ except Exception:
def setup_model():
+ """called once at startup to do various one-time tasks related to SD models"""
+
os.makedirs(model_path, exist_ok=True)
enable_midas_autodownload()
+ patch_given_betas()
def checkpoint_tiles(use_short=False):
@@ -147,6 +171,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
def get_closet_checkpoint_match(search_string):
+ if not search_string:
+ return None
+
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -286,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
print(f"Loading weights [{sd_model_hash}] from cache")
+ # move to end as latest
+ checkpoints_loaded.move_to_end(checkpoint_info)
return checkpoints_loaded[checkpoint_info]
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -295,11 +324,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return res
+class SkipWritingToConfig:
+ """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
+
+ skip = False
+ previous = None
+
+ def __enter__(self):
+ self.previous = SkipWritingToConfig.skip
+ SkipWritingToConfig.skip = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ SkipWritingToConfig.skip = self.previous
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -307,16 +352,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
-
+ model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
+ if model.is_ssd:
+ sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
@@ -324,7 +372,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
- if not shared.cmd_opts.no_half:
+ if shared.cmd_opts.no_half:
+ model.float()
+ devices.dtype_unet = torch.float32
+ timer.record("apply float()")
+ else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -340,9 +392,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if depth_model:
model.depth_model = depth_model
+ devices.dtype_unet = torch.float16
timer.record("apply half()")
- devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -410,6 +462,20 @@ def enable_midas_autodownload():
midas.api.load_model = load_model_wrapper
+def patch_given_betas():
+ import ldm.models.diffusion.ddpm
+
+ def patched_register_schedule(*args, **kwargs):
+ """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+ if isinstance(args[1], ListConfig):
+ args = (args[0], np.array(args[1]), *args[2:])
+
+ original_register_schedule(*args, **kwargs)
+
+ original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
+
+
def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
@@ -463,8 +529,12 @@ class SdModelData:
return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -491,7 +561,7 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
@@ -499,10 +569,17 @@ def send_model_to_cpu(m):
devices.torch_gc()
-def send_model_to_device(m):
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+def model_target_device(m):
+ if lowvram.is_needed(m):
+ return devices.cpu
else:
+ return devices.device
+
+
+def send_model_to_device(m):
+ lowvram.apply(m)
+
+ if not m.lowvram:
m.to(shared.device)
@@ -560,7 +637,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
@@ -623,10 +708,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
- shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
- shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
+ model_data.set_sd_model(already_loaded, already_loaded=True)
+
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
+ shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
+
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -638,6 +727,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
@@ -694,7 +787,7 @@ def reload_model_weights(sd_model=None, info=None):
script_callbacks.model_loaded_callback(sd_model)
timer.record("script callbacks")
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+ if not sd_model.lowvram:
sd_model.to(devices.device)
timer.record("move model to device")
@@ -707,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- timer = Timer()
-
- if model_data.sd_model:
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- model_data.sd_model = None
- sd_model = None
- gc.collect()
- devices.torch_gc()
-
- print(f"Unloaded weights {timer.summary()}.")
+ send_model_to_cpu(sd_model or shared.sd_model)
return sd_model