From ac7ecd2d847bf4e3a9503db0f2a291e32b82302c Mon Sep 17 00:00:00 2001 From: Tim Patton <38817597+pattontim@users.noreply.github.com> Date: Sat, 19 Nov 2022 14:49:22 -0500 Subject: Label and load SD .safetensors model files --- modules/sd_models.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index c59151e0..4ccdf30b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,6 +4,7 @@ import sys import gc from collections import namedtuple import torch +from safetensors.torch import load_file import re from omegaconf import OmegaConf @@ -16,9 +17,10 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config', 'exttype']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() +checkpoint_types = {'.ckpt':'pickle','.safetensors':'safetensors'} try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -45,7 +47,7 @@ def checkpoint_tiles(): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt",".safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -60,15 +62,15 @@ def list_models(): if name.startswith("\\") or name.startswith("/"): name = name[1:] - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + shortname, ext = os.path.splitext(name.replace("/", "_").replace("\\", "_")) - return f'{name} [{shorthash}]', shortname + return f'{name} [{checkpoint_types[ext]}] [{shorthash}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config, '') shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -76,12 +78,12 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) + basename, ext = os.path.splitext(filename) config = basename + ".yaml" if not os.path.exists(config): config = shared.cmd_opts.config - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config, ext) def get_closet_checkpoint_match(searchString): @@ -173,7 +175,13 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'): + # safely load weights + # TODO: safetensors supports zero copy fast load to gpu, see issue #684 + pl_sd = load_file(checkpoint_file, device=shared.weight_load_location) + else: + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") -- cgit v1.2.1