From fd66199769ebe0851d2ff33fdc7b191421822454 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 6 Sep 2022 19:33:51 +0300 Subject: added preview option --- modules/ui.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 8 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 1df74070..8e7a3ee4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -9,6 +9,8 @@ import sys import time import traceback +import numpy as np +import torch from PIL import Image import gradio as gr @@ -119,6 +121,9 @@ def wrap_gradio_call(func): print("Arguments:", args, kwargs, file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + shared.state.job = "" + shared.state.job_count = 0 + res = [None, '', f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] elapsed = time.perf_counter() - t @@ -134,11 +139,9 @@ def wrap_gradio_call(func): def check_progress_call(): - if not opts.show_progressbar: - return "" if shared.state.job_count == 0: - return "" + return "", gr_show(False), gr_show(False) progress = 0 @@ -149,9 +152,29 @@ def check_progress_call(): progress = min(progress, 1) - progressbar = f"""
{str(int(progress*100))+"%" if progress > 0.01 else ""}
""" + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{str(int(progress*100))+"%" if progress > 0.01 else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps > 0: + if (shared.state.sampling_step-1) % opts.show_progress_every_n_steps == 0 and shared.state.current_latent is not None: + x_sample = shared.sd_model.decode_first_stage(shared.state.current_latent[0:1].type(shared.sd_model.dtype))[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + shared.state.current_image = Image.fromarray(x_sample) - return f"{time.time()}

{progressbar}

" + image = shared.state.current_image + + if image is None or progress >= 1: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + return f"{time.time()}

{progressbar}

", preview_visibility, image def roll_artist(prompt): @@ -204,6 +227,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Column(variant='panel'): with gr.Group(): + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery') @@ -251,8 +275,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): check_progress.click( fn=check_progress_call, + show_progress=False, inputs=[], - outputs=[progressbar], + outputs=[progressbar, txt2img_preview, txt2img_preview], ) @@ -337,13 +362,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Column(variant='panel'): with gr.Group(): + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery') with gr.Group(): with gr.Row(): - interrupt = gr.Button('Interrupt') save = gr.Button('Save') + img2img_send_to_img2img = gr.Button('Send to img2img') + img2img_send_to_inpaint = gr.Button('Send to inpaint') img2img_send_to_extras = gr.Button('Send to extras') + interrupt = gr.Button('Interrupt') progressbar = gr.HTML(elem_id="progressbar") @@ -426,8 +454,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): check_progress.click( fn=check_progress_call, + show_progress=False, inputs=[], - outputs=[progressbar], + outputs=[progressbar, img2img_preview, img2img_preview], ) interrupt.click( @@ -463,6 +492,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): outputs=[init_img_with_mask], ) + img2img_send_to_img2img.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[img2img_gallery], + outputs=[init_img], + ) + + img2img_send_to_inpaint.click( + fn=lambda x: image_from_url_text(x), + _js="extract_image_from_gallery", + inputs=[img2img_gallery], + outputs=[init_img_with_mask], + ) + with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): -- cgit v1.2.1