From d1f098540ad1dbc2abb8d04322634efba650b631 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 30 Sep 2022 11:42:40 +0300 Subject: remove unwanted formatting/functionality from the PR --- modules/sd_models.py | 56 ++++++++++++++++------------------------------------ 1 file changed, 17 insertions(+), 39 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b9000a4..caa85d5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,10 +12,10 @@ from modules import shared, modelloader from modules.paths import models_path model_dir = "Stable-diffusion" -model_path = os.path.join(models_path, model_dir) +model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_name = "sd-v1-4.ckpt" model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" -user_dir = None +user_dir: (str | None) = None CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} @@ -30,26 +30,8 @@ except Exception: pass -def modeltitle(path, h): - abspath = os.path.abspath(path) - - if abspath.startswith(model_dir): - name = abspath.replace(model_dir, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - return f'{name} [{h}]' - - def setup_model(dirname): - global model_path - global model_name - global model_url global user_dir - global model_list user_dir = dirname if not os.path.exists(model_path): os.makedirs(model_path) @@ -62,21 +44,16 @@ def checkpoint_tiles(): def list_models(): - global model_path - global model_url - global model_name - global user_dir checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir, - ext_filter=[".ckpt"], download_name=model_name) - print(f"Model list: {model_list}") - model_dir = os.path.abspath(model_path) + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) - def modeltitle(path, h): + def modeltitle(path, shorthash): abspath = os.path.abspath(path) - if abspath.startswith(model_dir): - name = abspath.replace(model_dir, '') + if user_dir is not None and abspath.startswith(user_dir): + name = abspath.replace(user_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') else: name = os.path.basename(path) @@ -85,29 +62,30 @@ def list_models(): shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - return f'{name} [{h}]', shortname + return f'{name} [{shorthash}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) - title, model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) + title, short_model_name = modeltitle(cmd_ckpt, h) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: h = model_hash(filename) - title, model_name = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) + title, short_model_name = modeltitle(filename, h) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) + def get_closet_checkpoint_match(searchString): applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable)>0: + if len(applicable) > 0: return applicable[0] return None + def model_hash(filename): try: - print(f"Opening: {filename}") with open(filename, "rb") as file: import hashlib m = hashlib.sha256() @@ -128,7 +106,7 @@ def select_checkpoint(): if len(checkpoints_list) == 0: print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) - print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr) + print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) -- cgit v1.2.1