From 11875f586323cea7c5b8398976449788a83dee76 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 27 Sep 2022 11:01:13 -0500 Subject: Use model loader with stable-diffusion too. Hook the model loader into the SD_models file. Add default url/download if checkpoint is not found. Add matching stablediffusion-models-path argument. Add message that --ckpt-dir will be removed in the future, but have it pipe to stablediffusion-models-path for now. Update help strings for models-path args so they're more or less uniform. Move sd_model "setup" call to webUI with the others. Ensure "cleanup_models" method moves existing models to the new locations, including SD, and that we aren't deleting folders that still have stuff in them. --- modules/sd_models.py | 48 +++++++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 19 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index dc81b0dc..89b7d276 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,13 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared +from modules import shared, modelloader +from modules.paths import models_path + +model_dir = "Stable-diffusion" +model_path = os.path.join(models_path, model_dir) +model_name = "sd-v1-4.ckpt" +model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash']) checkpoints_list = {} @@ -23,23 +29,28 @@ except Exception: pass -def list_models(): - checkpoints_list.clear() +def modeltitle(path, h): + abspath = os.path.abspath(path) - model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) + if abspath.startswith(model_dir): + name = abspath.replace(model_dir, '') + else: + name = os.path.basename(path) - def modeltitle(path, h): - abspath = os.path.abspath(path) + if name.startswith("\\") or name.startswith("/"): + name = name[1:] - if abspath.startswith(model_dir): - name = abspath.replace(model_dir, '') - else: - name = os.path.basename(path) + return f'{name} [{h}]' - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - return f'{name} [{h}]' +def setup_model(dirname): + global model_path + global model_name + global model_url + if not os.path.exists(model_path): + os.makedirs(model_path) + checkpoints_list.clear() + model_list = modelloader.load_models(model_path, model_url, dirname, model_name, ext_filter=".ckpt") cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): @@ -47,13 +58,12 @@ def list_models(): title = modeltitle(cmd_ckpt, h) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: - print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr) + print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - if os.path.exists(model_dir): - for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True): - h = model_hash(filename) - title = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h) + for filename in model_list: + h = model_hash(filename) + title = modeltitle(filename, h) + checkpoints_list[title] = CheckpointInfo(filename, title, h) def model_hash(filename): -- cgit v1.2.1