From f194457229e4537912467bc60ac3a873f473a63c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 11 Sep 2022 18:48:36 +0300 Subject: CLIP interrogator --- modules/devices.py | 16 +++--- modules/interrogate.py | 142 +++++++++++++++++++++++++++++++++++++++++++++++++ modules/paths.py | 1 + modules/shared.py | 8 +++ modules/ui.py | 18 ++++++- 5 files changed, 177 insertions(+), 8 deletions(-) create mode 100644 modules/interrogate.py (limited to 'modules') diff --git a/modules/devices.py b/modules/devices.py index 25008a04..30d30b99 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,12 +1,16 @@ import torch - # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility has_mps = getattr(torch, 'has_mps', False) +cpu = torch.device("cpu") + + def get_optimal_device(): - if torch.cuda.is_available(): - return torch.device("cuda") - if has_mps: - return torch.device("mps") - return torch.device("cpu") + if torch.cuda.is_available(): + return torch.device("cuda") + + if has_mps: + return torch.device("mps") + + return cpu diff --git a/modules/interrogate.py b/modules/interrogate.py new file mode 100644 index 00000000..ed97a58b --- /dev/null +++ b/modules/interrogate.py @@ -0,0 +1,142 @@ +import os +import sys +import traceback +from collections import namedtuple +import re + +import torch + +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + +import modules.shared as shared +from modules import devices, paths + +blip_image_eval_size = 384 +blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' +clip_model_name = 'ViT-L/14' + +Category = namedtuple("Category", ["name", "topn", "items"]) + +re_topn = re.compile(r"\.top(\d+)\.") + +class InterrogateModels: + blip_model = None + clip_model = None + clip_preprocess = None + categories = None + + def __init__(self, content_dir): + self.categories = [] + + if os.path.exists(content_dir): + for filename in os.listdir(content_dir): + m = re_topn.search(filename) + topn = 1 if m is None else int(m.group(1)) + + with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file: + lines = [x.strip() for x in file.readlines()] + + self.categories.append(Category(name=filename, topn=topn, items=lines)) + + def load_blip_model(self): + import models.blip + + blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) + blip_model.eval() + + return blip_model + + def load_clip_model(self): + import clip + + model, preprocess = clip.load(clip_model_name) + model.eval() + model = model.to(shared.device) + + return model, preprocess + + def load(self): + if self.blip_model is None: + self.blip_model = self.load_blip_model() + + self.blip_model = self.blip_model.to(shared.device) + + if self.clip_model is None: + self.clip_model, self.clip_preprocess = self.load_clip_model() + + self.clip_model = self.clip_model.to(shared.device) + + def unload(self): + if not shared.opts.interrogate_keep_models_in_memory: + if self.clip_model is not None: + self.clip_model = self.clip_model.to(devices.cpu) + + if self.blip_model is not None: + self.blip_model = self.blip_model.to(devices.cpu) + + + def rank(self, image_features, text_array, top_count=1): + import clip + + top_count = min(top_count, len(text_array)) + text_tokens = clip.tokenize([text for text in text_array]).cuda() + with torch.no_grad(): + text_features = self.clip_model.encode_text(text_tokens).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + + similarity = torch.zeros((1, len(text_array))).to(shared.device) + for i in range(image_features.shape[0]): + similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) + similarity /= image_features.shape[0] + + top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) + return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] + + + def generate_caption(self, pil_image): + gpu_image = transforms.Compose([ + transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + ])(pil_image).unsqueeze(0).to(shared.device) + + with torch.no_grad(): + caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) + + return caption[0] + + def interrogate(self, pil_image): + res = None + + try: + self.load() + + caption = self.generate_caption(pil_image) + res = caption + + images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device) + + with torch.no_grad(): + image_features = self.clip_model.encode_image(images).float() + + image_features /= image_features.norm(dim=-1, keepdim=True) + + if shared.opts.interrogate_use_builtin_artists: + artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] + + res += ", " + artist[0] + + for name, topn, items in self.categories: + matches = self.rank(image_features, items, top_count=topn) + for match, score in matches: + res += ", " + match + + except Exception: + print(f"Error interrogating", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + self.unload() + + return res diff --git a/modules/paths.py b/modules/paths.py index 130aecb9..97c17a98 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -18,6 +18,7 @@ path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion'), (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'), + (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'), ] paths = {} diff --git a/modules/shared.py b/modules/shared.py index 74b0ad89..9eeb64e3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -11,6 +11,7 @@ import modules.artists from modules.paths import script_path, sd_path from modules.devices import get_optimal_device import modules.styles +import modules.interrogate config_filename = "config.json" @@ -77,6 +78,8 @@ artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.c styles_filename = os.path.join(script_path, 'styles.csv') prompt_styles = modules.styles.load_styles(styles_filename) +interrogator = modules.interrogate.InterrogateModels("interrogate") + face_restorers = [] class Options: @@ -123,6 +126,11 @@ class Options: "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "interrogate_keep_models_in_memory": OptionInfo(True, "Interrogate: keep models in VRAM"), + "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), + "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), + "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), + "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), } def __init__(self): diff --git a/modules/ui.py b/modules/ui.py index 032c20ff..ebc1ae63 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -242,9 +242,14 @@ def add_style(style_name, text): return [update, update] +def interrogate(image): + prompt = shared.interrogator.interrogate(image) + + return gr_show(True) if prompt is None else prompt + def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Blocks(analytics_enabled=False) as txt2img_interface: - with gr.Row(): + with gr.Row(elem_id="toprow"): txt2img_prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1) negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1) txt2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1) @@ -365,10 +370,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ) with gr.Blocks(analytics_enabled=False) as img2img_interface: - with gr.Row(): + with gr.Row(elem_id="toprow"): img2img_prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1) negative_prompt = gr.Textbox(label="Negative prompt", elem_id="img2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1) img2img_prompt_style = gr.Dropdown(label="Style", show_label=False, elem_id="style_index", choices=[k for k, v in shared.prompt_styles.items()], value=next(iter(shared.prompt_styles.keys())), visible=len(shared.prompt_styles) > 1) + img2img_interrogate = gr.Button('Interrogate', elem_id="img2img_interrogate", variant='primary') submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary') check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False) @@ -461,6 +467,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): inpaint_full_res: gr_show(is_inpaint), inpainting_mask_invert: gr_show(is_inpaint), denoising_strength_change_factor: gr_show(is_loopback), + img2img_interrogate: gr_show(not is_inpaint), } switch_mode.change( @@ -480,6 +487,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): inpaint_full_res, inpainting_mask_invert, denoising_strength_change_factor, + img2img_interrogate, ] ) @@ -540,6 +548,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + check_progress.click( fn=check_progress_call, show_progress=False, -- cgit v1.2.1