From 3bca90b249d749ed5429f76e380d2ffa52fc0d41 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 30 Jul 2023 13:48:27 +0300 Subject: hires fix checkpoint selection --- modules/sd_models.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..cb67e425 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -52,6 +52,7 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) @@ -81,6 +82,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title) self.title = f'{self.name} [{self.shorthash}]' + self.short_title = f'{self.name_for_extra} [{self.shorthash}]' self.register() return self.shorthash @@ -101,14 +103,8 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() - - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) +def checkpoint_tiles(use_short=False): + return [x.short_title if use_short else x.title for x in checkpoints_list.values()] def list_models(): @@ -131,11 +127,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in sorted(model_list, key=str.lower): + for filename in model_list: checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: @@ -145,6 +144,11 @@ def get_closet_checkpoint_match(search_string): if found: return found[0] + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None -- cgit v1.2.1