From 6d805b669e86233432f56ee1892d062103abe501 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 21 Jan 2023 09:14:27 +0300
Subject: make CLIP interrogator download the original text files if the
 directory does not exist; remove the random artist built-in extension (to be
 re-added as a normal extension on demand); remove artists.csv (its original
 purpose is unclear); make interrogate buttons show Loading... when clicked

---
 modules/api/api.py     |  8 --------
 modules/artists.py     | 25 -----------------------
 modules/interrogate.py | 55 +++++++++++++++++++++++++++++++++++++-------------
 modules/shared.py      |  5 -----
 modules/ui.py          | 11 +++++-----
 5 files changed, 46 insertions(+), 58 deletions(-)
 delete mode 100644 modules/artists.py

diff --git a/modules/api/api.py b/modules/api/api.py
index 2c371e6e..f2e9e884 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -126,8 +126,6 @@ class Api:
         self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
         self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
-        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
-        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
@@ -390,12 +388,6 @@ class Api:
 
         return styleList
 
-    def get_artists_categories(self):
-        return shared.artist_db.cats
-
-    def get_artists(self):
-        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
-
     def get_embeddings(self):
         db = sd_hijack.model_hijack.embedding_db
 
diff --git a/modules/artists.py b/modules/artists.py
deleted file mode 100644
index 3612758b..00000000
--- a/modules/artists.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os.path
-import csv
-from collections import namedtuple
-
-Artist = namedtuple("Artist", ['name', 'weight', 'category'])
-
-
-class ArtistsDatabase:
-    def __init__(self, filename):
-        self.cats = set()
-        self.artists = []
-
-        if not os.path.exists(filename):
-            return
-
-        with open(filename, "r", newline='', encoding="utf8") as file:
-            reader = csv.DictReader(file)
-
-            for row in reader:
-                artist = Artist(row["artist"], float(row["score"]), row["category"])
-                self.artists.append(artist)
-                self.cats.add(artist.category)
-
-    def categories(self):
-        return sorted(self.cats)
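Note for API consumers: the /sdapi/v1/artist-categories and /sdapi/v1/artists routes above are removed outright rather than deprecated, so existing clients will start receiving 404 responses. A minimal check against a running instance, assuming the webui's default host and port (this snippet is illustrative and not part of the patch):

    import requests

    # Both routes are deleted by this commit; FastAPI answers 404 for unknown paths.
    base = "http://127.0.0.1:7860"
    for route in ("/sdapi/v1/artist-categories", "/sdapi/v1/artists"):
        r = requests.get(base + route)
        print(route, r.status_code)  # expect 404 once the patch is applied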
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 738d8ff7..19938cbb 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -5,12 +5,13 @@ from collections import namedtuple
 import re
 
 import torch
+import torch.hub
 
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, lowvram, modelloader, errors
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
 
 re_topn = re.compile(r"\.top(\d+)\.")
 
+
+def download_default_clip_interrogate_categories(content_dir):
+    print("Downloading CLIP categories...")
+
+    tmpdir = content_dir + "_tmp"
+    try:
+        os.makedirs(tmpdir)
+
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
+
+        os.rename(tmpdir, content_dir)
+
+    except Exception as e:
+        errors.display(e, "downloading default CLIP interrogate categories")
+    finally:
+        if os.path.exists(tmpdir):
+            os.remove(tmpdir)
+
+
 class InterrogateModels:
     blip_model = None
     clip_model = None
     clip_preprocess = None
-    categories = None
     dtype = None
     running_on_cpu = None
 
     def __init__(self, content_dir):
-        self.categories = []
+        self.loaded_categories = None
+        self.content_dir = content_dir
         self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
-        if os.path.exists(content_dir):
-            for filename in os.listdir(content_dir):
+    def categories(self):
+        if self.loaded_categories is not None:
+            return self.loaded_categories
+
+        self.loaded_categories = []
+
+        if not os.path.exists(self.content_dir):
+            download_default_clip_interrogate_categories(self.content_dir)
+
+        if os.path.exists(self.content_dir):
+            for filename in os.listdir(self.content_dir):
                 m = re_topn.search(filename)
                 topn = 1 if m is None else int(m.group(1))
 
-                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+                with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
                     lines = [x.strip() for x in file.readlines()]
 
-                self.categories.append(Category(name=filename, topn=topn, items=lines))
+                self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+
+        return self.loaded_categories
 
     def load_blip_model(self):
         import models.blip
@@ -139,7 +172,6 @@ class InterrogateModels:
         shared.state.begin()
         shared.state.job = 'interrogate'
         try:
-
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                 lowvram.send_everything_to_cpu()
                 devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels:
 
                 image_features /= image_features.norm(dim=-1, keepdim=True)
 
-                if shared.opts.interrogate_use_builtin_artists:
-                    artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
-                    res += ", " + artist[0]
-
-                for name, topn, items in self.categories:
+                for name, topn, items in self.categories():
                     matches = self.rank(image_features, items, top_count=topn)
                     for match, score in matches:
                         if shared.opts.interrogate_return_ranks:
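One caveat in download_default_clip_interrogate_categories above: the finally block calls os.remove(tmpdir), but os.remove deletes files, not directories (it raises IsADirectoryError on POSIX and PermissionError on Windows). On the success path os.rename has already moved the directory away, so the check is skipped; the cleanup only runs, and then fails, when a download aborts partway. A sketch of a more robust cleanup, substituting shutil.rmtree (an assumption of this note, not part of the patch):

    import os
    import shutil

    def cleanup_tmpdir(tmpdir):
        # tmpdir is the content_dir + "_tmp" staging directory from the patch.
        if os.path.exists(tmpdir):
            # rmtree removes the directory together with any partially
            # downloaded files, which os.remove() cannot do.
            shutil.rmtree(tmpdir)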
diff --git a/modules/shared.py b/modules/shared.py
index c0e11f18..72fb1934 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -9,7 +9,6 @@ from PIL import Image
 import gradio as gr
 import tqdm
 
-import modules.artists
 import modules.interrogate
 import modules.memmon
 import modules.styles
@@ -254,8 +253,6 @@ class State:
 state = State()
 state.server_start = time.time()
 
-artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
-
 styles_filename = cmd_opts.styles_file
 prompt_styles = modules.styles.StyleDatabase(styles_filename)
 
@@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
     'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
-    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
 
 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
-    "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
     "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
     "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
     "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
diff --git a/modules/ui.py b/modules/ui.py
index d23b2b8e..164e0e93 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
             left, _ = os.path.splitext(filename)
             print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
 
-    return [gr_show(True), None]
+    return [gr.update(), None]
 
 
 def interrogate(image):
     prompt = shared.interrogator.interrogate(image.convert("RGB"))
-    return gr_show(True) if prompt is None else prompt
+    return gr.update() if prompt is None else prompt
 
 
 def interrogate_deepbooru(image):
     prompt = deepbooru.model.tag(image)
-    return gr_show(True) if prompt is None else prompt
+    return gr.update() if prompt is None else prompt
 
 
 def create_seed_inputs(target_interface):
@@ -1039,19 +1039,18 @@ def create_ui():
                     init_img_inpaint,
                 ],
                 outputs=[img2img_prompt, dummy_component],
-                show_progress=False,
             )
 
             img2img_prompt.submit(**img2img_args)
             submit.click(**img2img_args)
 
             img2img_interrogate.click(
-                fn=lambda *args : process_interrogate(interrogate, *args),
+                fn=lambda *args: process_interrogate(interrogate, *args),
                 **interrogate_args,
             )
 
             img2img_deepbooru.click(
-                fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
+                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                 **interrogate_args,
             )
-- 
cgit v1.2.1
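Taken together, the interrogate.py changes make category loading lazy: the constructor only records the content directory, and the first categories() call downloads the four text files if the directory is missing, then caches the parsed result. Note that flavors.txt is saved as flavors.top3.txt so the existing re_topn filename convention assigns it topn=3. A sketch of the resulting flow (the "interrogate" path is illustrative; the webui derives the real one from its data directories):

    from modules.interrogate import InterrogateModels

    interrogator = InterrogateModels("interrogate")  # no I/O or network access happens here

    cats = interrogator.categories()  # first call: downloads artists/flavors/mediums/movements if needed
    cats = interrogator.categories()  # later calls: return the cached self.loaded_categories

    for category in cats:
        print(category.name, category.topn, len(category.items))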