From 345028099d893f8a66726cfd13627d8cc1bcc724 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 3 Sep 2022 12:08:45 +0300 Subject: split codebase into multiple files; to anyone this affects negatively: sorry --- modules/ui.py | 539 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 539 insertions(+) create mode 100644 modules/ui.py (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py new file mode 100644 index 00000000..5223179f --- /dev/null +++ b/modules/ui.py @@ -0,0 +1,539 @@ +import base64 +import html +import io +import json +import mimetypes +import os +import sys +import time +import traceback + +from PIL import Image + +import gradio as gr +import gradio.utils + +from modules.paths import script_path +from modules.shared import opts, cmd_opts +import modules.shared as shared +from modules.sd_samplers import samplers, samplers_for_img2img +import modules.gfpgan_model as gfpgan +import modules.realesrgan_model as realesrgan + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + + +if not cmd_opts.share: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +""" + + +def plaintext_to_html(text): + text = "".join([f"
<p>{html.escape(x)}</p>
\n" for x in text.split('\n')]) + return text + + +def image_from_url_text(filedata): + if type(filedata) == list: + if len(filedata) == 0: + return None + + filedata = filedata[0] + + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + filedata = base64.decodebytes(filedata.encode('utf-8')) + image = Image.open(io.BytesIO(filedata)) + return image + + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + + return image_from_url_text(x[0]) + + +def save_files(js_data, images): + import csv + + os.makedirs(opts.outdir_save, exist_ok=True) + + filenames = [] + + data = json.loads(js_data) + + with open("log/log.csv", "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename"]) + + filename_base = str(int(time.time() * 1000)) + for i, filedata in enumerate(images): + filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png" + filepath = os.path.join(opts.outdir_save, filename) + + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + with open(filepath, "wb") as imgfile: + imgfile.write(base64.decodebytes(filedata.encode('utf-8'))) + + filenames.append(filename) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0]]) + + return '', '', plaintext_to_html(f"Saved: {filenames[0]}") + + +def wrap_gradio_call(func): + def f(*args, **kwargs): + t = time.perf_counter() + + try: + res = list(func(*args, **kwargs)) + except Exception as e: + print("Error completing request", file=sys.stderr) + print("Arguments:", args, kwargs, file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + res = [None, '', f"
<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>
"] + + elapsed = time.perf_counter() - t + + # last item is always HTML + res[-1] = res[-1] + f"
<p class='performance'>Time taken: {elapsed:.2f}s</p>
" + + shared.state.interrupted = False + + return tuple(res) + + return f + + +def create_ui(opts, cmd_opts, txt2img, img2img, run_extras, run_pnginfo): + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id="txt2img_prompt", show_label=False, placeholder="Prompt", lines=1) + negative_prompt = gr.Textbox(label="Negative prompt", elem_id="txt2img_negative_prompt", show_label=False, placeholder="Negative prompt", lines=1, visible=False) + submit = gr.Button('Generate', elem_id="txt2img_generate", variant='primary') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") + + with gr.Row(): + use_GFPGAN = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan) + prompt_matrix = gr.Checkbox(label='Prompt matrix', value=False) + + with gr.Row(): + batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + + cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0) + + with gr.Group(): + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + + seed = gr.Number(label='Seed', value=-1) + + code = gr.Textbox(label="Python script", visible=cmd_opts.allow_code, lines=1) + + with gr.Column(variant='panel'): + with gr.Group(): + txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery') + + with gr.Group(): + with gr.Row(): + save = gr.Button('Save') + send_to_img2img = gr.Button('Send to img2img') + send_to_inpaint = gr.Button('Send to inpaint') + send_to_extras = gr.Button('Send to extras') + interrupt = gr.Button('Interrupt') + + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) + + txt2img_args = dict( + fn=txt2img, + inputs=[ + prompt, + negative_prompt, + steps, + sampler_index, + use_GFPGAN, + prompt_matrix, + batch_count, + batch_size, + cfg_scale, + seed, + height, + width, + code + ], + outputs=[ + txt2img_gallery, + generation_info, + html_info + ] + ) + + prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + save.click( + fn=wrap_gradio_call(save_files), + inputs=[ + generation_info, + txt2img_gallery, + ], + outputs=[ + html_info, + html_info, + html_info, + ] + ) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id="img2img_prompt", show_label=False, placeholder="Prompt", lines=1) + submit = gr.Button('Generate', elem_id="img2img_generate", variant='primary') + + with gr.Row().style(equal_height=False): + + with gr.Column(variant='panel'): + with gr.Group(): + switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False) + init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil") + init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", 
source="upload", interactive=True, type="pil", tool="sketch", visible=False) + resize_mode = gr.Radio(label="Resize mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") + + steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False) + inpainting_fill = gr.Radio(label='Msked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) + + with gr.Row(): + use_GFPGAN = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan) + prompt_matrix = gr.Checkbox(label='Prompt matrix', value=False) + inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False) + + with gr.Row(): + sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(shared.sd_upscalers.keys()), value=list(shared.sd_upscalers.keys())[0], visible=False) + sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) + + with gr.Row(): + batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + + with gr.Group(): + cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0) + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75) + + with gr.Group(): + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) + + seed = gr.Number(label='Seed', value=-1) + + with gr.Column(variant='panel'): + with gr.Group(): + img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery') + + with gr.Group(): + with gr.Row(): + interrupt = gr.Button('Interrupt') + save = gr.Button('Save') + img2img_send_to_extras = gr.Button('Send to extras') + + with gr.Group(): + html_info = gr.HTML() + generation_info = gr.Textbox(visible=False) + + def apply_mode(mode): + is_classic = mode == 0 + is_inpaint = mode == 1 + is_loopback = mode == 2 + is_upscale = mode == 3 + + return { + init_img: gr_show(not is_inpaint), + init_img_with_mask: gr_show(is_inpaint), + mask_blur: gr_show(is_inpaint), + inpainting_fill: gr_show(is_inpaint), + prompt_matrix: gr_show(is_classic), + batch_count: gr_show(not is_upscale), + batch_size: gr_show(not is_loopback), + sd_upscale_upscaler_name: gr_show(is_upscale), + sd_upscale_overlap:gr_show(is_upscale), + inpaint_full_res: gr_show(is_inpaint), + } + + switch_mode.change( + apply_mode, + inputs=[switch_mode], + outputs=[ + init_img, + init_img_with_mask, + mask_blur, + inpainting_fill, + prompt_matrix, + batch_count, + batch_size, + sd_upscale_upscaler_name, + sd_upscale_overlap, + inpaint_full_res, + ] + ) + + img2img_args = dict( + fn=img2img, + inputs=[ + prompt, + init_img, + init_img_with_mask, + steps, + sampler_index, + mask_blur, + inpainting_fill, + use_GFPGAN, + prompt_matrix, + switch_mode, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + height, + width, + resize_mode, + sd_upscale_upscaler_name, + sd_upscale_overlap, + inpaint_full_res, + ], + outputs=[ + img2img_gallery, + 
generation_info, + html_info + ] + ) + + prompt.submit(**img2img_args) + submit.click(**img2img_args) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + save.click( + fn=wrap_gradio_call(save_files), + inputs=[ + generation_info, + img2img_gallery, + ], + outputs=[ + html_info, + html_info, + html_info, + ] + ) + + send_to_img2img.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[txt2img_gallery], + outputs=[init_img], + ) + + send_to_inpaint.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[txt2img_gallery], + outputs=[init_img_with_mask], + ) + + + + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Group(): + image = gr.Image(label="Source", source="upload", interactive=True, type="pil") + gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=gfpgan.have_gfpgan) + realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=realesrgan.have_realesrgan) + realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan.realesrgan_models], value=realesrgan.realesrgan_models[0].name, type="index", interactive=realesrgan.have_realesrgan) + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Column(variant='panel'): + result_image = gr.Image(label="Result") + html_info_x = gr.HTML() + html_info = gr.HTML() + + extras_args = dict( + fn=run_extras, + inputs=[ + image, + gfpgan_strength, + realesrgan_resize, + realesrgan_model, + ], + outputs=[ + result_image, + html_info_x, + html_info, + ] + ) + + submit.click(**extras_args) + + + send_to_extras.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[txt2img_gallery], + outputs=[image], + ) + + img2img_send_to_extras.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[img2img_gallery], + outputs=[image], + ) + + + pnginfo_interface = gr.Interface( + wrap_gradio_call(run_pnginfo), + inputs=[ + gr.Image(label="Source", source="upload", interactive=True, type="pil"), + ], + outputs=[ + gr.HTML(), + gr.HTML(), + gr.HTML(), + ], + allow_flagging="never", + analytics_enabled=False, + ) + + + def create_setting_component(key): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + if info.component is not None: + item = info.component(label=info.label, value=fun, **(info.component_args or {})) + elif t == str: + item = gr.Textbox(label=info.label, value=fun, lines=1) + elif t == int: + item = gr.Number(label=info.label, value=fun) + elif t == bool: + item = gr.Checkbox(label=info.label, value=fun) + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + return item + + def run_settings(*args): + up = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components): + opts.data[key] = value + up.append(comp.update(value=value)) + + opts.save(shared.config_filename) + + return 'Settings saved.', '', '' + + settings_interface = gr.Interface( + run_settings, + inputs=[create_setting_component(key) for key in opts.data_labels.keys()], + outputs=[ + gr.Textbox(label='Result'), + gr.HTML(), + gr.HTML(), + ], + title=None, + 
description=None,
+        allow_flagging="never",
+        analytics_enabled=False,
+    )
+
+    interfaces = [
+        (txt2img_interface, "txt2img"),
+        (img2img_interface, "img2img"),
+        (extras_interface, "Extras"),
+        (pnginfo_interface, "PNG Info"),
+        (settings_interface, "Settings"),
+    ]
+
+    with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
+        css = file.read()
+
+    if not cmd_opts.no_progressbar_hiding:
+        css += css_hide_progressbar
+
+    demo = gr.TabbedInterface(
+        interface_list=[x[0] for x in interfaces],
+        tab_names=[x[1] for x in interfaces],
+        analytics_enabled=False,
+        css=css,
+    )
+
+    return demo
+
+
+with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as file:
+    javascript = file.read()
+
+def inject_gradio_html(javascript):
+    import gradio.routes
+
+    def template_response(*args, **kwargs):
+        res = gradio_routes_templates_response(*args, **kwargs)
+        res.body = res.body.replace(b'</head>', f'<script>{javascript}</script></head>'.encode("utf8"))
+        res.init_headers()
+        return res
+
+    gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
+    gradio.routes.templates.TemplateResponse = template_response
+
+
+inject_gradio_html(javascript)
-- cgit v1.2.1
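
The patch above defines only the UI layer; the processing callbacks are passed into create_ui from elsewhere in the newly split codebase. Below is a minimal sketch of how an entry-point script might wire this module up. Only create_ui's parameter list, the modules.shared imports, and cmd_opts.share appear in the patch itself; the stub callbacks and the launch() call are illustrative assumptions, not the project's actual entry point.

import modules.ui
from modules.shared import opts, cmd_opts

def txt2img(*args):
    # Hypothetical stand-in: the real callback returns (images, generation_info, html_info)
    # to match the outputs wired up in create_ui.
    return [], '{}', ''

def img2img(*args):
    # Hypothetical stand-in with the same output shape as txt2img.
    return [], '{}', ''

def run_extras(image, gfpgan_strength, realesrgan_resize, realesrgan_model):
    # Hypothetical stand-in: returns (result_image, html_info_x, html_info).
    return image, '', ''

def run_pnginfo(image):
    # Hypothetical stand-in: returns three HTML strings for the PNG Info tab.
    return '', '', ''

# create_ui builds the gr.TabbedInterface shown in the patch; launch() then serves it.
demo = modules.ui.create_ui(opts, cmd_opts, txt2img, img2img, run_extras, run_pnginfo)
demo.launch(share=cmd_opts.share)

Because create_ui returns the assembled TabbedInterface, the entry point only needs to supply the four callbacks and call launch(); error reporting and the Save button are handled inside ui.py by wrap_gradio_call and save_files.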