From 98466da4bc312c0fa9c8cea4c825afc64194cb58 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Fri, 20 Jan 2023 00:48:15 -0500 Subject: adds descriptions for merging methods in ui --- modules/ui.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index eb45a128..ee434bde 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1190,10 +1190,19 @@ def create_ui(): outputs=[html, generation_info, html2], ) + def update_interp_description(value): + interp_description_css = "

{}

" + interp_descriptions = { + "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), + "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), + "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") + } + return interp_descriptions[value] + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='compact'): - gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") + interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") with FormRow(elem_id="modelmerger_models"): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") @@ -1208,6 +1217,7 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description]) with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") @@ -1903,6 +1913,9 @@ def create_ui(): with open(ui_config_file, "w", encoding="utf8") as file: json.dump(ui_settings, file, indent=4) + # Required as a workaround for change() event not triggering when loading values from ui-config.json + interp_description.value = update_interp_description(interp_method.value) + return demo -- cgit v1.2.1 From 3262e825cc542ff634e6ba2e3a162eafdc6c1bba Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 21 Jan 2023 17:42:04 +0900 Subject: add --xformers-flash-attention option & impl --- modules/sd_hijack_optimizations.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + 2 files changed, 25 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 4fa54329..9967359b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + # print('xformers_attention_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v) + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): + # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) + # Flash Attention is not availabe for the input arguments. + # Fallback to default xFormers' backend. + op = None + else: + op = None + out = xformers.ops.memory_efficient_attention(q, k, v, op=op) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..23328adf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") -- cgit v1.2.1 From f726df8a2fd2620ba245de5702e2afe620f79b91 Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 12:59:05 -0500 Subject: Compile and serve js from /statica instead of inline in html --- modules/ui.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index fbc3efa0..d19eaf25 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -10,6 +10,7 @@ import sys import tempfile import time import traceback +from collections import OrderedDict from functools import partial, reduce import warnings @@ -1918,27 +1919,51 @@ def create_ui(): def reload_javascript(): + javascript_files = OrderedDict() with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' + contents = jsfile.read() + javascript_files["script.js"] = [contents] + # javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" + contents = jsfile.read() + javascript_files[filename] = [contents] + # javascript += f"\n" if cmd_opts.theme is not None: - javascript += f"\n\n" + javascript_files["theme.js"] = [f"set_theme('{cmd_opts.theme}');"] + # javascript += f"\n\n" - javascript += f"\n" + # javascript += f"\n" + javascript_files["localization.js"] = [f"{localization.localization_js(shared.opts.localization)}"] + + compiled_name = "webui-compiled.js" + head = f""" + + """ def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) + b'', f'{head}'.encode("utf8")) res.init_headers() return res + for k in javascript_files: + javascript_files[k] = "\n".join(javascript_files[k]) + + # make static_path if not exists + statica_path = os.path.join(script_path, 'statica') + if not os.path.exists(statica_path): + os.mkdir(statica_path) + + javascript_out = "\n\n\n".join([f"// \n\n{v}" for k, v in javascript_files.items()]) + with open(os.path.join(script_path, "statica", compiled_name), "w", encoding="utf8") as jsfile: + jsfile.write(javascript_out) + gradio.routes.templates.TemplateResponse = template_response -- cgit v1.2.1 From 17af0fb95574068a1d5032ae96879dab145e173a Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 13:27:05 -0500 Subject: remove commented out lines --- modules/ui.py | 4 ---- 1 file changed, 4 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index d19eaf25..ef85d43c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1923,7 +1923,6 @@ def reload_javascript(): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: contents = jsfile.read() javascript_files["script.js"] = [contents] - # javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") @@ -1931,13 +1930,10 @@ def reload_javascript(): with open(path, "r", encoding="utf8") as jsfile: contents = jsfile.read() javascript_files[filename] = [contents] - # javascript += f"\n" if cmd_opts.theme is not None: javascript_files["theme.js"] = [f"set_theme('{cmd_opts.theme}');"] - # javascript += f"\n\n" - # javascript += f"\n" javascript_files["localization.js"] = [f"{localization.localization_js(shared.opts.localization)}"] compiled_name = "webui-compiled.js" -- cgit v1.2.1 From 50059ea661b63967b217e687819cf7a9081e4a0c Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 14:07:48 -0500 Subject: server individually listed javascript files vs single compiled file --- modules/ui.py | 52 +++++++++++++++++++++++----------------------------- 1 file changed, 23 insertions(+), 29 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ef85d43c..b372d29c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1919,47 +1919,41 @@ def create_ui(): def reload_javascript(): - javascript_files = OrderedDict() - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - contents = jsfile.read() - javascript_files["script.js"] = [contents] - scripts_list = modules.scripts.list_scripts("javascript", ".js") - + js_files = [] for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - contents = jsfile.read() - javascript_files[filename] = [contents] + path = path[len(script_path) + 1:] + js_files.append(path) + inline = [f"{localization.localization_js(shared.opts.localization)};"] if cmd_opts.theme is not None: - javascript_files["theme.js"] = [f"set_theme('{cmd_opts.theme}');"] + inline.append(f"set_theme('{cmd_opts.theme}');", ) - javascript_files["localization.js"] = [f"{localization.localization_js(shared.opts.localization)}"] - - compiled_name = "webui-compiled.js" - head = f""" - - """ + t = int(time.time()) + head = [ + f""" + + """.strip() + ] + inline_code = "\n".join(inline) + head.append(f""" + + """.strip()) + for file in js_files: + head.append(f""" + + """.strip()) def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + head_inject = "\n".join(head) res.body = res.body.replace( - b'', f'{head}'.encode("utf8")) + b'', f'{head_inject}'.encode("utf8")) res.init_headers() return res - for k in javascript_files: - javascript_files[k] = "\n".join(javascript_files[k]) - - # make static_path if not exists - statica_path = os.path.join(script_path, 'statica') - if not os.path.exists(statica_path): - os.mkdir(statica_path) - - javascript_out = "\n\n\n".join([f"// \n\n{v}" for k, v in javascript_files.items()]) - with open(os.path.join(script_path, "statica", compiled_name), "w", encoding="utf8") as jsfile: - jsfile.write(javascript_out) - gradio.routes.templates.TemplateResponse = template_response -- cgit v1.2.1 From 035459c9a22bebcf68ac454a1f178fefe8c82054 Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 14:11:13 -0500 Subject: remove dead import --- modules/ui.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b372d29c..5fde7fc5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -10,7 +10,6 @@ import sys import tempfile import time import traceback -from collections import OrderedDict from functools import partial, reduce import warnings -- cgit v1.2.1 From e4e0918f58d382b5da400e680d743dcf0e66fd7f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 22:57:19 +0300 Subject: remove timestamp for js files, reformat code --- modules/ui.py | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index b5581a06..ef7becc6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1918,38 +1918,20 @@ def create_ui(): def reload_javascript(): - scripts_list = modules.scripts.list_scripts("javascript", ".js") - js_files = [] - for basedir, filename, path in scripts_list: - path = path[len(script_path) + 1:] - js_files.append(path) + head = f'\n' - inline = [f"{localization.localization_js(shared.opts.localization)};"] + inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: - inline.append(f"set_theme('{cmd_opts.theme}');", ) + inline += f"set_theme('{cmd_opts.theme}');" - t = int(time.time()) - head = [ - f""" - - """.strip() - ] - inline_code = "\n".join(inline) - head.append(f""" - - """.strip()) - for file in js_files: - head.append(f""" - - """.strip()) + head += f'\n' + + for script in modules.scripts.list_scripts("javascript", ".js"): + head += f'\n' def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - head_inject = "\n".join(head) - res.body = res.body.replace( - b'', f'{head_inject}'.encode("utf8")) + res.body = res.body.replace(b'', f'{head}'.encode("utf8")) res.init_headers() return res -- cgit v1.2.1 From 4a8fe09652b304034708d967c47901312940e852 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 23:06:18 +0300 Subject: remove the double loading text --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index ef7becc6..aa39a713 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -75,6 +75,7 @@ css_hide_progressbar = """ .wrap .m-12::before { content:"Loading..." } .wrap .z-20 svg { display:none!important; } .wrap .z-20::before { content:"Loading..." } +.wrap.cover-bg .z-20::before { content:"" } .progress-bar { display:none!important; } .meta-text { display:none!important; } .meta-text-center { display:none!important; } -- cgit v1.2.1 From 78f59a4e014d090bce7df3b218bfbcd7f11e0894 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 23:40:13 +0300 Subject: enable compact view for train tab prevent previews from ruining hypernetwork training --- modules/hypernetworks/hypernetwork.py | 2 ++ modules/processing.py | 8 ++++++-- modules/ui.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 80a47c79..503534e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -715,6 +715,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi do_not_save_samples=True, ) + p.disable_extra_networks = True + if preview_from_txt2img: p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt diff --git a/modules/processing.py b/modules/processing.py index 6e6371a1..bc541e2f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -140,6 +140,7 @@ class StableDiffusionProcessing: self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings_restore_afterwards = override_settings_restore_afterwards self.is_using_inpainting_conditioning = False + self.disable_extra_networks = False if not seed_enable_extras: self.subseed = -1 @@ -567,7 +568,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) - extra_networks.activate(p, extra_network_data) + if not p.disable_extra_networks: + extra_networks.activate(p, extra_network_data) with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") @@ -684,7 +686,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks: + extra_networks.deactivate(p, extra_network_data) + devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/ui.py b/modules/ui.py index daebbc9f..af6dfb21 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1259,7 +1259,7 @@ def create_ui(): with gr.Row().style(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.

") - with gr.Row().style(equal_height=False): + with gr.Row(variant="compact").style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding"): -- cgit v1.2.1 From fe7a623e6b7e04bab2cfc96e8fd6cf49b48daee1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 00:02:41 +0300 Subject: add a slider for default value of added extra networks --- modules/shared.py | 5 +++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_hypernets.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 52bbb807..00a1d64c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,7 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), @@ -406,7 +406,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e2e060c8..4c88193f 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -54,7 +54,7 @@ class ExtraNetworksPage: args = { "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', - "prompt": json.dumps(item["prompt"]), + "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 312dbaf0..65d000cf 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,3 +1,4 @@ +import json import os from modules import shared, ui_extra_networks @@ -25,7 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } -- cgit v1.2.1 From f2eae6127d16a80d1516d3f6245b648eeca26330 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 00:16:26 +0300 Subject: fix broken textual inversion extras tab --- modules/ui_extra_networks_textual_inversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index e4a6e3bf..dbd23d2d 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,3 +1,4 @@ +import json import os from modules import ui_extra_networks, sd_hijack @@ -24,7 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": preview, - "prompt": embedding.name, + "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", } -- cgit v1.2.1 From bf457b30fbfedb4b6eb2a198cbaa9f2ba071d31f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 16:21:33 -0500 Subject: compact checkbox and fix copy image btn overflow also fixes type for #tab_extensions in style.css --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index af6dfb21..12fc9e6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -919,7 +919,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): + with FormRow(elem_id="img2img_checkboxes", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") -- cgit v1.2.1 From 2621566153920eb70bfa439f3d7c126ee8d36ec8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 08:05:21 +0300 Subject: attention ctrl+up/down enhancements --- modules/shared.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index 00a1d64c..d68ac296 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -444,9 +444,11 @@ options_templates.update(options_section(('ui', "User interface"), { "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), - 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), + "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('ui', "Live previews"), { -- cgit v1.2.1 From 0792fae078ba362a5119f56d84e3f490a88690ae Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 08:20:48 +0300 Subject: fix missing field for aesthetic embedding extension --- modules/sd_disable_initialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c72d8efc..e90aa9fe 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -41,7 +41,9 @@ class DisableInitialization: return self.create_model_and_transforms(*args, pretrained=None, **kwargs) def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): - return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + res.name_or_path = pretrained_model_name_or_path + return res def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug -- cgit v1.2.1 From 112416d04171e4bee673f0adc9bd3aeba87ec71a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 10:17:12 +0300 Subject: add option to discard weights in checkpoint merger UI --- modules/extras.py | 9 ++++++++- modules/ui.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 1218f88f..385430dc 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,6 +1,7 @@ from __future__ import annotations import math import os +import re import sys import traceback import shutil @@ -285,7 +286,7 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): shared.state.begin() shared.state.job = 'model-merge' @@ -430,6 +431,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in theta_0.keys(): theta_0[key] = to_half(theta_0[key], save_as_half) + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path filename = filename_generator() if custom_name == '' else custom_name diff --git a/modules/ui.py b/modules/ui.py index af6dfb21..eb4b7e6b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1248,6 +1248,9 @@ def create_ui(): bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") + with FormRow(): + discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") + with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') @@ -1838,6 +1841,7 @@ def create_ui(): checkpoint_format, config_source, bake_in_vae, + discard_weights, ], outputs=[ primary_model_name, -- cgit v1.2.1 From 35419b274614984e2b511a6ad34f37e41481c809 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 11:00:05 +0300 Subject: add an option to reorder tabs for extra networks --- modules/shared.py | 1 + modules/ui_extra_networks.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index d68ac296..cd78e50a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -448,6 +448,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4c88193f..285c8ffe 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -79,6 +79,22 @@ class ExtraNetworksUi: self.tabname = None +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] @@ -86,7 +102,7 @@ def create_ui(container, button, tabname): ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: - for page in ui.stored_extra_pages: + for page in pages_in_preferred_order(ui.stored_extra_pages): with gr.Tab(page.title): page_elem = gr.HTML(page.create_html(ui.tabname)) ui.pages.append(page_elem) -- cgit v1.2.1 From c98cb0f8ecc904666f47684e238dd022039ca16f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 11:04:02 +0300 Subject: amend previous commit to work in a proper fashion when saving previews --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 285c8ffe..af2b8071 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -98,11 +98,11 @@ def pages_in_preferred_order(pages): def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] - ui.stored_extra_pages = extra_pages.copy() + ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: - for page in pages_in_preferred_order(ui.stored_extra_pages): + for page in ui.stored_extra_pages: with gr.Tab(page.title): page_elem = gr.HTML(page.create_html(ui.tabname)) ui.pages.append(page_elem) -- cgit v1.2.1 From 43ac9ff205910e8207dfd45a842577344d399a92 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:40 +0300 Subject: Split history extras.py to postprocessing.py --- modules/extras.py | 466 ---------------------------------------------- modules/postprocessing.py | 466 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 466 insertions(+), 466 deletions(-) delete mode 100644 modules/extras.py create mode 100644 modules/postprocessing.py (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py deleted file mode 100644 index 385430dc..00000000 --- a/modules/extras.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations -import math -import os -import re -import sys -import traceback -import shutil - -import numpy as np -from PIL import Image - -import torch -import tqdm - -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model -import gradio as gr -import safetensors.torch - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
-

{plaintext_to_html(str(key))}

-

{plaintext_to_html(str(text))}

-
-""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

{message}

" - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." + checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/postprocessing.py b/modules/postprocessing.py new file mode 100644 index 00000000..385430dc --- /dev/null +++ b/modules/postprocessing.py @@ -0,0 +1,466 @@ +from __future__ import annotations +import math +import os +import re +import sys +import traceback +import shutil + +import numpy as np +from PIL import Image + +import torch +import tqdm + +from typing import Callable, List, OrderedDict, Tuple +from functools import partial +from dataclasses import dataclass + +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model +import gradio as gr +import safetensors.torch + +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int + + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size: int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) + + +cached_images: LruCache = LruCache(max_size=5) + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + imageArr = [] + # Also keep track of original file names + imageNameArr = [] + outputs = [] + + if extras_mode == 1: + #convert file to pillow image + for img in image_folder: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = shared.listfiles(input_dir) + for img in image_list: + try: + image = Image.open(img) + except Exception: + continue + imageArr.append(image) + imageNameArr.append(img) + else: + imageArr.append(image) + imageNameArr.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples + + # Extra operation definitions + + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) + + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) + + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) + for upscaler in params: + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=image_hash, + info_hash=hash(info), + args_hash=hash(upscale_args)) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) + else: + res, info = cached_entry.image, cached_entry.info + + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] + + upscale_ops: List[Callable] = [] + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) + + upscale_ops.append(partial(run_upscalers_blend, step_params)) + + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + + existing_pnginfo = image.info or {} + + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) + + if opts.use_original_name_batch and image_name is not None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + + if save_output: + # Add upscaler name as a suffix. + suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" + # Add second upscaler if applicable. + if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: + suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + + if extras_mode != 2 or show_extras_results : + outputs.append(image) + + devices.torch_gc() + + return outputs, plaintext_to_html(info), '' + +def clear_cache(): + cached_images.clear() + + +def run_pnginfo(image): + if image is None: + return '', '', '' + + geninfo, items = images.read_info_from_image(image) + items = {**{'parameters': geninfo}, **items} + + info = '' + for key, text in items.items(): + info += f""" +
+

{plaintext_to_html(str(key))}

+

{plaintext_to_html(str(text))}

+
+""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"

{message}

" + + return '', geninfo, info + + +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): + shared.state.begin() + shared.state.job = 'model-merge' + + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [*[gr.update() for _ in range(4)], message] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) + + def filename_weighted_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_difference(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + + if not primary_model_name: + return fail("Failed: Merging requires a primary model.") + + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if theta_func2 and not secondary_model_name: + return fail("Failed: Merging requires a secondary model.") + + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + + if theta_func1 and not tertiary_model_name: + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + + result_is_inpainting_model = False + + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None + + if theta_func1: + shared.state.textinfo = f"Loading C" + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + + shared.state.textinfo = 'Merging B and C' + shared.state.sampling_steps = len(theta_1.keys()) + for key in tqdm.tqdm(theta_1.keys()): + if key in checkpoint_dict_skip_on_merge: + continue + + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 + del theta_2 + + shared.state.nextjob() + + shared.state.textinfo = f"Loading {primary_model_info.filename}..." + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") + shared.state.textinfo = 'Merging A and B' + shared.state.sampling_steps = len(theta_0.keys()) + for key in tqdm.tqdm(theta_0.keys()): + if theta_1 and 'model' in key and key in theta_1: + + if key in checkpoint_dict_skip_on_merge: + continue + + a = theta_0[key] + b = theta_1[key] + + # this enables merging an inpainting model (A) with another one (B); + # where normal model would have 4 channels, for latenst space, inpainting model would + # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 + if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: + if a.shape[1] == 4 and b.shape[1] == 9: + raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True + else: + theta_0[key] = theta_func2(a, b, multiplier) + + theta_0[key] = to_half(theta_0[key], save_as_half) + + shared.state.sampling_step += 1 + + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) + + del vae_dict + + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format + + output_modelname = os.path.join(ckpt_dir, filename) + + shared.state.nextjob() + shared.state.textinfo = "Saving" + print(f"Saving to {output_modelname}...") + + _, extension = os.path.splitext(output_modelname) + if extension.lower() == ".safetensors": + safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) + else: + torch.save(theta_0, output_modelname) + + sd_models.list_models() + + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" + shared.state.end() + + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.1 From b238b14ee459486c4734cc2899b83f547813a467 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:40 +0300 Subject: Split history extras.py to postprocessing.py --- modules/extras.py | 466 ------------------------------------------------------ modules/temp | 466 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 466 insertions(+), 466 deletions(-) delete mode 100644 modules/extras.py create mode 100644 modules/temp (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py deleted file mode 100644 index 385430dc..00000000 --- a/modules/extras.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations -import math -import os -import re -import sys -import traceback -import shutil - -import numpy as np -from PIL import Image - -import torch -import tqdm - -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model -import gradio as gr -import safetensors.torch - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
-

{plaintext_to_html(str(key))}

-

{plaintext_to_html(str(text))}

-
-""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

{message}

" - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." + checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/temp b/modules/temp new file mode 100644 index 00000000..385430dc --- /dev/null +++ b/modules/temp @@ -0,0 +1,466 @@ +from __future__ import annotations +import math +import os +import re +import sys +import traceback +import shutil + +import numpy as np +from PIL import Image + +import torch +import tqdm + +from typing import Callable, List, OrderedDict, Tuple +from functools import partial +from dataclasses import dataclass + +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model +import gradio as gr +import safetensors.torch + +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int + + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size: int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) + + +cached_images: LruCache = LruCache(max_size=5) + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + imageArr = [] + # Also keep track of original file names + imageNameArr = [] + outputs = [] + + if extras_mode == 1: + #convert file to pillow image + for img in image_folder: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = shared.listfiles(input_dir) + for img in image_list: + try: + image = Image.open(img) + except Exception: + continue + imageArr.append(image) + imageNameArr.append(img) + else: + imageArr.append(image) + imageNameArr.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples + + # Extra operation definitions + + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) + + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) + + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) + for upscaler in params: + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=image_hash, + info_hash=hash(info), + args_hash=hash(upscale_args)) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) + else: + res, info = cached_entry.image, cached_entry.info + + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] + + upscale_ops: List[Callable] = [] + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) + + upscale_ops.append(partial(run_upscalers_blend, step_params)) + + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + + existing_pnginfo = image.info or {} + + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) + + if opts.use_original_name_batch and image_name is not None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + + if save_output: + # Add upscaler name as a suffix. + suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" + # Add second upscaler if applicable. + if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: + suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + + if extras_mode != 2 or show_extras_results : + outputs.append(image) + + devices.torch_gc() + + return outputs, plaintext_to_html(info), '' + +def clear_cache(): + cached_images.clear() + + +def run_pnginfo(image): + if image is None: + return '', '', '' + + geninfo, items = images.read_info_from_image(image) + items = {**{'parameters': geninfo}, **items} + + info = '' + for key, text in items.items(): + info += f""" +
+

{plaintext_to_html(str(key))}

+

{plaintext_to_html(str(text))}

+
+""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"

{message}

" + + return '', geninfo, info + + +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): + shared.state.begin() + shared.state.job = 'model-merge' + + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [*[gr.update() for _ in range(4)], message] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) + + def filename_weighted_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_difference(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + + if not primary_model_name: + return fail("Failed: Merging requires a primary model.") + + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if theta_func2 and not secondary_model_name: + return fail("Failed: Merging requires a secondary model.") + + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + + if theta_func1 and not tertiary_model_name: + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + + result_is_inpainting_model = False + + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None + + if theta_func1: + shared.state.textinfo = f"Loading C" + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + + shared.state.textinfo = 'Merging B and C' + shared.state.sampling_steps = len(theta_1.keys()) + for key in tqdm.tqdm(theta_1.keys()): + if key in checkpoint_dict_skip_on_merge: + continue + + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 + del theta_2 + + shared.state.nextjob() + + shared.state.textinfo = f"Loading {primary_model_info.filename}..." + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") + shared.state.textinfo = 'Merging A and B' + shared.state.sampling_steps = len(theta_0.keys()) + for key in tqdm.tqdm(theta_0.keys()): + if theta_1 and 'model' in key and key in theta_1: + + if key in checkpoint_dict_skip_on_merge: + continue + + a = theta_0[key] + b = theta_1[key] + + # this enables merging an inpainting model (A) with another one (B); + # where normal model would have 4 channels, for latenst space, inpainting model would + # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 + if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: + if a.shape[1] == 4 and b.shape[1] == 9: + raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True + else: + theta_0[key] = theta_func2(a, b, multiplier) + + theta_0[key] = to_half(theta_0[key], save_as_half) + + shared.state.sampling_step += 1 + + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) + + del vae_dict + + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format + + output_modelname = os.path.join(ckpt_dir, filename) + + shared.state.nextjob() + shared.state.textinfo = "Saving" + print(f"Saving to {output_modelname}...") + + _, extension = os.path.splitext(output_modelname) + if extension.lower() == ".safetensors": + safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) + else: + torch.save(theta_0, output_modelname) + + sd_models.list_models() + + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" + shared.state.end() + + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.1 From c56b36712289020a98f0c77794b9045a251ecd55 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:41 +0300 Subject: Split history extras.py to postprocessing.py --- modules/extras.py | 466 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/temp | 466 ------------------------------------------------------ 2 files changed, 466 insertions(+), 466 deletions(-) create mode 100644 modules/extras.py delete mode 100644 modules/temp (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py new file mode 100644 index 00000000..385430dc --- /dev/null +++ b/modules/extras.py @@ -0,0 +1,466 @@ +from __future__ import annotations +import math +import os +import re +import sys +import traceback +import shutil + +import numpy as np +from PIL import Image + +import torch +import tqdm + +from typing import Callable, List, OrderedDict, Tuple +from functools import partial +from dataclasses import dataclass + +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model +import gradio as gr +import safetensors.torch + +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int + + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size: int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) + + +cached_images: LruCache = LruCache(max_size=5) + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + imageArr = [] + # Also keep track of original file names + imageNameArr = [] + outputs = [] + + if extras_mode == 1: + #convert file to pillow image + for img in image_folder: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = shared.listfiles(input_dir) + for img in image_list: + try: + image = Image.open(img) + except Exception: + continue + imageArr.append(image) + imageNameArr.append(img) + else: + imageArr.append(image) + imageNameArr.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples + + # Extra operation definitions + + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) + + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) + + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) + for upscaler in params: + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=image_hash, + info_hash=hash(info), + args_hash=hash(upscale_args)) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) + else: + res, info = cached_entry.image, cached_entry.info + + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] + + upscale_ops: List[Callable] = [] + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) + + upscale_ops.append(partial(run_upscalers_blend, step_params)) + + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + + existing_pnginfo = image.info or {} + + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) + + if opts.use_original_name_batch and image_name is not None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + + if save_output: + # Add upscaler name as a suffix. + suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" + # Add second upscaler if applicable. + if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: + suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + + if extras_mode != 2 or show_extras_results : + outputs.append(image) + + devices.torch_gc() + + return outputs, plaintext_to_html(info), '' + +def clear_cache(): + cached_images.clear() + + +def run_pnginfo(image): + if image is None: + return '', '', '' + + geninfo, items = images.read_info_from_image(image) + items = {**{'parameters': geninfo}, **items} + + info = '' + for key, text in items.items(): + info += f""" +
+

{plaintext_to_html(str(key))}

+

{plaintext_to_html(str(text))}

+
+""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"

{message}

" + + return '', geninfo, info + + +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): + shared.state.begin() + shared.state.job = 'model-merge' + + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [*[gr.update() for _ in range(4)], message] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) + + def filename_weighted_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_difference(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + + if not primary_model_name: + return fail("Failed: Merging requires a primary model.") + + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if theta_func2 and not secondary_model_name: + return fail("Failed: Merging requires a secondary model.") + + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + + if theta_func1 and not tertiary_model_name: + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + + result_is_inpainting_model = False + + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None + + if theta_func1: + shared.state.textinfo = f"Loading C" + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + + shared.state.textinfo = 'Merging B and C' + shared.state.sampling_steps = len(theta_1.keys()) + for key in tqdm.tqdm(theta_1.keys()): + if key in checkpoint_dict_skip_on_merge: + continue + + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 + del theta_2 + + shared.state.nextjob() + + shared.state.textinfo = f"Loading {primary_model_info.filename}..." + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") + shared.state.textinfo = 'Merging A and B' + shared.state.sampling_steps = len(theta_0.keys()) + for key in tqdm.tqdm(theta_0.keys()): + if theta_1 and 'model' in key and key in theta_1: + + if key in checkpoint_dict_skip_on_merge: + continue + + a = theta_0[key] + b = theta_1[key] + + # this enables merging an inpainting model (A) with another one (B); + # where normal model would have 4 channels, for latenst space, inpainting model would + # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 + if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: + if a.shape[1] == 4 and b.shape[1] == 9: + raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True + else: + theta_0[key] = theta_func2(a, b, multiplier) + + theta_0[key] = to_half(theta_0[key], save_as_half) + + shared.state.sampling_step += 1 + + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) + + del vae_dict + + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format + + output_modelname = os.path.join(ckpt_dir, filename) + + shared.state.nextjob() + shared.state.textinfo = "Saving" + print(f"Saving to {output_modelname}...") + + _, extension = os.path.splitext(output_modelname) + if extension.lower() == ".safetensors": + safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) + else: + torch.save(theta_0, output_modelname) + + sd_models.list_models() + + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" + shared.state.end() + + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/temp b/modules/temp deleted file mode 100644 index 385430dc..00000000 --- a/modules/temp +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations -import math -import os -import re -import sys -import traceback -import shutil - -import numpy as np -from PIL import Image - -import torch -import tqdm - -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model -import gradio as gr -import safetensors.torch - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
-

{plaintext_to_html(str(key))}

-

{plaintext_to_html(str(text))}

-
-""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

{message}

" - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." + checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.1 From 68303c96e5ab31576a8238a24bf5b6191cf16ed1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:38:39 +0300 Subject: split oversize extras.py to postprocessing.py --- modules/extras.py | 217 +------------------------------------- modules/postprocessing.py | 257 +--------------------------------------------- modules/ui.py | 10 +- modules/ui_components.py | 7 ++ 4 files changed, 18 insertions(+), 473 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index 385430dc..f04ddfc2 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,231 +1,16 @@ -from __future__ import annotations -import math import os import re -import sys -import traceback import shutil -import numpy as np -from PIL import Image import torch import tqdm -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model +from modules import shared, images, sd_models, sd_vae from modules.ui import plaintext_to_html -import modules.codeformer_model import gradio as gr import safetensors.torch -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - def run_pnginfo(image): if image is None: diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 385430dc..cb85720b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,28 +1,18 @@ from __future__ import annotations -import math import os -import re -import sys -import traceback -import shutil import numpy as np from PIL import Image -import torch -import tqdm - from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules import shared, images, devices, ui_components from modules.shared import opts import modules.gfpgan_model -from modules.ui import plaintext_to_html import modules.codeformer_model -import gradio as gr -import safetensors.torch + class LruCache(OrderedDict): @dataclass(frozen=True) @@ -55,7 +45,7 @@ class LruCache(OrderedDict): cached_images: LruCache = LruCache(max_size=5) -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): devices.torch_gc() shared.state.begin() @@ -221,246 +211,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ devices.torch_gc() - return outputs, plaintext_to_html(info), '' + return outputs, ui_components.plaintext_to_html(info), '' + def clear_cache(): cached_images.clear() - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
-

{plaintext_to_html(str(key))}

-

{plaintext_to_html(str(text))}

-
-""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

{message}

" - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." + checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/ui.py b/modules/ui.py index eb4b7e6b..4116e167 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -95,8 +95,8 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text + return ui_components.plaintext_to_html(text) + def send_gradio_gallery_to_image(x): if len(x) == 0: @@ -1152,7 +1152,7 @@ def create_ui(): result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -1183,7 +1183,7 @@ def create_ui(): parameters_copypaste.add_paste_fields("extras", extras_image, None) extras_image.change( - fn=modules.extras.clear_cache, + fn=postprocessing.clear_cache, inputs=[], outputs=[] ) diff --git a/modules/ui_components.py b/modules/ui_components.py index 46324425..989cc87b 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,3 +1,5 @@ +import html + import gradio as gr @@ -47,3 +49,8 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text -- cgit v1.2.1 From 985c0b8e9abdd67734d638badefb6ea806b1f28b Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 17:45:36 -0300 Subject: feat(extra-networks): add thumbs view style --- modules/ui_extra_networks.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index af2b8071..ce4801b5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname): + def create_html(self, tabname, view = 'cards'): items_html = '' for item in self.list_items(): @@ -36,7 +36,7 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" -
+
{items_html}
""" @@ -75,6 +75,7 @@ class ExtraNetworksUi: self.button_save_preview = None self.preview_target_filename = None + self.view_dropdown = None self.tabname = None @@ -110,6 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -117,16 +119,17 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(): + def refresh(view='cards'): res = [] for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname)) + res.append(pg.create_html(ui.tabname, view)) return res - button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + ui.view_dropdown.change(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) + button_refresh.click(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) return ui @@ -139,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename): + def save_preview(index, images, filename, view='cards'): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -161,11 +164,11 @@ def setup_ui(ui, gallery): image.save(filename) - return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + return [page.create_html(ui.tabname, view) for page in ui.stored_extra_pages] ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", - inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + _js="function(x, y, z, a){console.log(x, y, z, a); return [selected_gallery_index(), y, z, a]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename, ui.view_dropdown], outputs=[*ui.pages] ) -- cgit v1.2.1 From 66eef11ce7f3db108225668c573cb4a763a43fb3 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 18:27:57 -0300 Subject: feat(extra-networks): add default view setting --- modules/shared.py | 4 ++++ modules/ui_extra_networks.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) (limited to 'modules') diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..e9548864 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,6 +430,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), +})) + options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ce4801b5..179ba47a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view = 'cards'): + def create_html(self, tabname, view=shared.opts.extra_networks_default_view): items_html = '' for item in self.list_items(): @@ -111,7 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,7 +119,7 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view='cards'): + def refresh(view): res = [] for pg in ui.stored_extra_pages: @@ -142,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename, view='cards'): + def save_preview(index, images, filename, view): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] -- cgit v1.2.1 From f80ff3c1e444926879c284be9384a26ca38d4955 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sun, 22 Jan 2023 22:01:24 -0300 Subject: feat(extra-networks): remove view dropdown --- modules/ui_extra_networks.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 179ba47a..2ddac3d8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,8 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view=shared.opts.extra_networks_default_view): + def create_html(self, tabname): + view = shared.opts.extra_networks_default_view items_html = '' for item in self.list_items(): @@ -75,7 +76,6 @@ class ExtraNetworksUi: self.button_save_preview = None self.preview_target_filename = None - self.view_dropdown = None self.tabname = None @@ -111,7 +111,6 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,17 +118,16 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view): + def refresh(): res = [] for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname, view)) + res.append(pg.create_html(ui.tabname)) return res - ui.view_dropdown.change(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) - button_refresh.click(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui @@ -142,7 +140,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename, view): + def save_preview(index, images, filename): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -164,11 +162,11 @@ def setup_ui(ui, gallery): image.save(filename) - return [page.create_html(ui.tabname, view) for page in ui.stored_extra_pages] + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z, a){console.log(x, y, z, a); return [selected_gallery_index(), y, z, a]}", - inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename, ui.view_dropdown], + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], outputs=[*ui.pages] ) -- cgit v1.2.1 From b5230197a69d36a79fdc4919c59a03e00e872dd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 09:24:43 +0300 Subject: rework extras tab to use script system --- modules/api/api.py | 13 +- modules/postprocessing.py | 236 +++++++++------------------------ modules/scripts.py | 28 ++-- modules/scripts_postprocessing.py | 147 +++++++++++++++++++++ modules/shared.py | 5 + modules/ui.py | 265 +------------------------------------- modules/ui_common.py | 202 +++++++++++++++++++++++++++++ modules/ui_components.py | 6 - modules/ui_postprocessing.py | 57 ++++++++ 9 files changed, 500 insertions(+), 459 deletions(-) create mode 100644 modules/scripts_postprocessing.py create mode 100644 modules/ui_common.py create mode 100644 modules/ui_postprocessing.py (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index f2e9e884..5d60fc0a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -45,10 +44,8 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): @@ -244,7 +241,7 @@ class Api: reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -260,7 +257,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index cb85720b..8514fea7 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,219 +1,103 @@ -from __future__ import annotations import os -import numpy as np from PIL import Image -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import shared, images, devices, ui_components +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste from modules.shared import opts -import modules.gfpgan_model -import modules.codeformer_model - - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) -def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin() shared.state.job = 'extras' - imageArr = [] - # Also keep track of original file names - imageNameArr = [] + image_data = [] + image_names = [] outputs = [] if extras_mode == 1: - #convert file to pillow image for img in image_folder: image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + image_data.append(image) + image_names.append(os.path.splitext(img.orig_name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + assert input_dir, 'input directory not selected' - if input_dir == '': - return outputs, "Please select an input directory.", '' image_list = shared.listfiles(input_dir) - for img in image_list: + for filename in image_list: try: - image = Image.open(img) + image = Image.open(filename) except Exception: continue - imageArr.append(image) - imageNameArr.append(img) + image_data.append(image) + image_names.append(filename) else: - imageArr.append(image) - imageNameArr.append(None) + assert image, 'image not selected' + + image_data.append(image) + image_names.append(None) if extras_mode == 2 and output_dir != '': outpath = output_dir else: outpath = opts.outdir_samples or opts.outdir_extras_samples - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - + infotext = '' + + for image, name in zip(image_data, image_names): + shared.state.textinfo = name + existing_pnginfo = image.info or {} - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) + pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] + scripts.scripts_postproc.run(pp, args) + + if opts.use_original_name_batch and name is not None: + basename = os.path.splitext(os.path.basename(name))[0] else: basename = '' - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info + infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + if opts.enable_pnginfo: + pp.image.info = existing_pnginfo + pp.image.info["postprocessing"] = infotext - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + if save_output: + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) - if extras_mode != 2 or show_extras_results : - outputs.append(image) + if extras_mode != 2 or show_extras_results: + outputs.append(pp.image) devices.torch_gc() - return outputs, ui_components.plaintext_to_html(info), '' - - -def clear_cache(): - cached_images.clear() - + return outputs, ui_common.plaintext_to_html(infotext), '' + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + """old handler for API""" + + args = scripts.scripts_postproc.create_args_for_run({ + "Upscale": { + "upscale_mode": resize_mode, + "upscale_by": upscaling_resize, + "upscale_to_width": upscaling_resize_w, + "upscale_to_height": upscaling_resize_h, + "upscale_crop": upscaling_crop, + "upscaler_1_name": extras_upscaler_1, + "upscaler_2_name": extras_upscaler_2, + "upscaler_2_visibility": extras_upscaler_2_visibility, + }, + "GFPGAN": { + "gfpgan_visibility": gfpgan_visibility, + }, + "CodeFormer": { + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + }, + }) + + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/scripts.py b/modules/scripts.py index 4ffc369b..03907a63 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions, script_loading +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() @@ -150,8 +150,10 @@ def basedir(): return current_basedir -scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) + +scripts_data = [] +postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) @@ -190,23 +192,31 @@ def list_files_with_name(filename): def load_scripts(): global current_basedir scripts_data.clear() + postprocessing_scripts_data.clear() script_callbacks.clear_callbacks() scripts_list = list_scripts("scripts", ".py") syspath = sys.path + def register_scripts_from_module(module): + for key, script_class in module.__dict__.items(): + if type(script_class) != type: + continue + + if issubclass(script_class, Script): + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): + postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + for scriptfile in sorted(scripts_list): try: if scriptfile.basedir != paths.script_path: sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - module = script_loading.load_module(scriptfile.path) - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) + script_module = script_loading.load_module(scriptfile.path) + register_scripts_from_module(script_module) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -413,6 +423,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() scripts_current: ScriptRunner = None @@ -423,12 +434,13 @@ def reload_script_body_only(): def reload_scripts(): - global scripts_txt2img, scripts_img2img + global scripts_txt2img, scripts_img2img, scripts_postproc load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() def IOComponent_init(self, *args, **kwargs): diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py new file mode 100644 index 00000000..25de02d0 --- /dev/null +++ b/modules/scripts_postprocessing.py @@ -0,0 +1,147 @@ +import os +import gradio as gr + +from modules import errors, shared + + +class PostprocessedImage: + def __init__(self, image): + self.image = image + self.info = {} + + +class ScriptPostprocessing: + filename = None + controls = None + args_from = None + args_to = None + + order = 1000 + """scripts will be ordred by this value in postprocessing UI""" + + name = None + """this function should return the title of the script.""" + + group = None + """A gr.Group component that has all script's UI inside it""" + + def ui(self): + """ + This function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be a dictionary that maps parameter names to components used in processing. + Values of those components will be passed to process() function. + """ + + pass + + def process(self, pp: PostprocessedImage, **args): + """ + This function is called to postprocess the image. + args contains a dictionary with all values returned by components from ui() + """ + + pass + + def image_changed(self): + pass + + +def wrap_call(func, filename, funcname, *args, default=None, **kwargs): + try: + res = func(*args, **kwargs) + return res + except Exception as e: + errors.display(e, f"calling {filename}/{funcname}") + + return default + + +class ScriptPostprocessingRunner: + def __init__(self): + self.scripts = None + self.ui_created = False + + def initialize_scripts(self, scripts_data): + self.scripts = [] + + for script_class, path, basedir, script_module in scripts_data: + script: ScriptPostprocessing = script_class() + script.filename = path + + self.scripts.append(script) + + def create_script_ui(self, script, inputs): + script.args_from = len(inputs) + script.args_to = len(inputs) + + script.controls = wrap_call(script.ui, script.filename, "ui") + + for control in script.controls.values(): + control.custom_script_source = os.path.basename(script.filename) + + inputs += list(script.controls.values()) + script.args_to = len(inputs) + + def scripts_in_preferred_order(self): + if self.scripts is None: + import modules.scripts + self.initialize_scripts(modules.scripts.postprocessing_scripts_data) + + scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + + def script_score(name): + name = name.lower() + for i, possible_match in enumerate(scripts_order): + if possible_match in name: + return i + + return len(self.scripts) + + script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)} + + return sorted(self.scripts, key=lambda x: script_scores[x.name]) + + def setup_ui(self): + inputs = [] + + for script in self.scripts_in_preferred_order(): + with gr.Box() as group: + self.create_script_ui(script, inputs) + + script.group = group + + self.ui_created = True + return inputs + + def run(self, pp: PostprocessedImage, args): + for script in self.scripts_in_preferred_order(): + shared.state.job = script.name + + script_args = args[script.args_from:script.args_to] + + process_args = {} + for (name, component), value in zip(script.controls.items(), script_args): + process_args[name] = value + + script.process(pp, **process_args) + + def create_args_for_run(self, scripts_args): + if not self.ui_created: + with gr.Blocks(analytics_enabled=False): + self.setup_ui() + + scripts = self.scripts_in_preferred_order() + args = [None] * max([x.args_to for x in scripts]) + + for script in scripts: + script_args_dict = scripts_args.get(script.name, None) + if script_args_dict is not None: + + for i, name in enumerate(script.controls): + args[script.args_from + i] = script_args_dict.get(name, None) + + return args + + def image_changed(self): + for script in self.scripts_in_preferred_order(): + script.image_changed() diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..cb73bf31 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -474,6 +474,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), })) +options_templates.update(options_section(('postprocessing', "Postprocessing"), { + 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), +})) + options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), diff --git a/modules/ui.py b/modules/ui.py index 4116e167..8cb8e613 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -86,7 +86,6 @@ css_hide_progressbar = """ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 @@ -95,7 +94,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - return ui_components.plaintext_to_html(text) + return ui_common.plaintext_to_html(text) def send_gradio_gallery_to_image(x): @@ -103,70 +102,6 @@ def send_gradio_gallery_to_image(x): return None return image_from_url_text(x[0]) -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -444,19 +379,6 @@ def apply_setting(key, value): opts.save(shared.config_filename) return getattr(opts, key) - -def update_generation_info(generation_info, html_info, img_index): - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info, gr.update() - return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info, gr.update() - - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() @@ -477,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel', elem_id=f"{tabname}_results"): - with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ], - show_progress=False, - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return ui_common.create_output_panel(tabname, outdir) def create_sampler_and_steps_selection(choices, tabname): @@ -1106,86 +928,7 @@ def create_ui(): modules.scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='compact'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, - html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=postprocessing.clear_cache, - inputs=[], outputs=[] - ) + ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Row().style(equal_height=False): diff --git a/modules/ui_common.py b/modules/ui_common.py new file mode 100644 index 00000000..8ce75b8c --- /dev/null +++ b/modules/ui_common.py @@ -0,0 +1,202 @@ +import json +import html +import os +import platform +import sys + +import gradio as gr +import scipy as sp + +from modules import call_queue, shared +from modules.generation_parameters_copypaste import image_from_url_text +import modules.images + +folder_symbol = '\U0001f4c2' # 📂 + + +def update_generation_info(generation_info, html_info, img_index): + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info, gr.update() + + +def plaintext_to_html(text): + text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" + return text + + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = shared.opts.outdir_save + save_to_dirs = shared.opts.use_save_to_dirs_for_ui + extension: str = shared.opts.samples_format + start_index = 0 + + if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(shared.opts.outdir_save, exist_ok=True) + + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def create_output_panel(tabname, outdir): + from modules import shared + import modules.generation_parameters_copypaste as parameters_copypaste + + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(shared.opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], + ) + + save.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ], + show_progress=False, + ) + + save_zip.click( + fn=call_queue.wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/modules/ui_components.py b/modules/ui_components.py index 989cc87b..9aec3097 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,5 +1,3 @@ -import html - import gradio as gr @@ -50,7 +48,3 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" - -def plaintext_to_html(text): - text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" - return text diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py new file mode 100644 index 00000000..b418d955 --- /dev/null +++ b/modules/ui_postprocessing.py @@ -0,0 +1,57 @@ +import gradio as gr +from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue +import modules.generation_parameters_copypaste as parameters_copypaste + + +def create_ui(): + tab_index = gr.State(value=0) + + with gr.Row().style(equal_height=False, variant='compact'): + with gr.Column(variant='compact'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single: + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch: + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir: + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + script_inputs = scripts.scripts_postproc.setup_ui() + + with gr.Column(): + result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + + tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) + tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) + tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) + + submit.click( + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + inputs=[ + tab_index, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + *script_inputs + ], + outputs=[ + result_images, + html_info_x, + html_info, + ] + ) + + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=scripts.scripts_postproc.image_changed, + inputs=[], outputs=[] + ) -- cgit v1.2.1 From 669dbd9725b3a285503e093a75c0dfa332073d8a Mon Sep 17 00:00:00 2001 From: Shondoit Date: Mon, 23 Jan 2023 09:54:42 +0100 Subject: Fix dark mode Fixes #7048 Co-Authored-By: J.J. Tolton --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index eb4b7e6b..43bcb7e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1942,11 +1942,11 @@ def reload_javascript(): if cmd_opts.theme is not None: inline += f"set_theme('{cmd_opts.theme}');" - head += f'\n' - for script in modules.scripts.list_scripts("javascript", ".js"): head += f'\n' + head += f'\n' + def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace(b'', f'{head}'.encode("utf8")) -- cgit v1.2.1 From fabdae089e476c66eba3b0562e4e1881891804b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 14:42:49 +0300 Subject: add missing import to previous commit --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 8cb8e613..6b5dfd61 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,6 +41,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +import modules.extras warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) -- cgit v1.2.1 From 41265a026de699cc223ca5b76c69b4e8e74aa7c1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 14:50:20 +0300 Subject: third time's the charm --- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index f04ddfc2..36123aa5 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import torch import tqdm from modules import shared, images, sd_models, sd_vae -from modules.ui import plaintext_to_html +from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch -- cgit v1.2.1 From 194cbd065e4644e986889b78a5a949e075b610e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 15:50:32 +0300 Subject: fix open directory button failing --- modules/ui.py | 1 - modules/ui_common.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui.py b/modules/ui.py index 94d4a80a..85ae62c7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,7 +5,6 @@ import mimetypes import os import platform import random -import subprocess as sp import sys import tempfile import time diff --git a/modules/ui_common.py b/modules/ui_common.py index 8ce75b8c..9405ac1f 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -5,7 +5,7 @@ import platform import sys import gradio as gr -import scipy as sp +import subprocess as sp from modules import call_queue, shared from modules.generation_parameters_copypaste import image_from_url_text -- cgit v1.2.1 From 59146621e256269b85feb536edeb745da20daf68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 16:40:20 +0300 Subject: better support for xformers flash attention on older versions of torch --- modules/errors.py | 12 +++++++++++ modules/sd_hijack_optimizations.py | 42 ++++++++++++++++---------------------- 2 files changed, 30 insertions(+), 24 deletions(-) (limited to 'modules') diff --git a/modules/errors.py b/modules/errors.py index a10e8708..f6b80dbb 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable """) +already_displayed = {} + + +def display_once(e: Exception, task): + if task in already_displayed: + return + + display(e, task) + + already_displayed[task] = 1 + + def run(code, task): try: code() diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 9967359b..74452709 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared +from modules import shared, errors from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ ) +def get_xformers_flash_attention_op(q, k, v): + if not shared.cmd_opts.xformers_flash_attention: + return None + + try: + flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = flash_attention_op + if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + return flash_attention_op + except Exception as e: + errors.display_once(e, "enabling flash attention") + + return None + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - if shared.cmd_opts.xformers_flash_attention: - op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp - fw, bw = op - if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): - # print('xformers_attention_forward', q.shape, k.shape, v.shape) - # Flash Attention is not availabe for the input arguments. - # Fallback to default xFormers' backend. - op = None - else: - op = None - - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - if shared.cmd_opts.xformers_flash_attention: - op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp - fw, bw = op - if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): - # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) - # Flash Attention is not availabe for the input arguments. - # Fallback to default xFormers' backend. - op = None - else: - op = None - out = xformers.ops.memory_efficient_attention(q, k, v, op=op) + out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out -- cgit v1.2.1 From 925dd09c91e7338aef72e4ec99d67b8b57280215 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 09:03:17 -0500 Subject: improve interrogate --- modules/interrogate.py | 29 +++++++++++++++++------------ modules/shared.py | 1 + 2 files changed, 18 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 19938cbb..1d1ac572 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") +category_types = ["artists", "flavors", "mediums", "movements"] def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") @@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir): tmpdir = content_dir + "_tmp" try: os.makedirs(tmpdir) - - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt")) - torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt")) - + for category_type in category_types: + torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt")) os.rename(tmpdir, content_dir) except Exception as e: @@ -51,12 +48,13 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None + self.selected_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None: - return self.loaded_categories + if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + return self.loaded_categories self.loaded_categories = [] @@ -64,14 +62,19 @@ class InterrogateModels: download_default_clip_interrogate_categories(self.content_dir) if os.path.exists(self.content_dir): - for filename in os.listdir(self.content_dir): + self.selected_categories = shared.opts.interrogate_clip_categories + for category_type in category_types: + if 'all' not in self.selected_categories and category_type not in self.selected_categories: + continue + filename = os.path.join(self.content_dir, f"{category_type}.txt") + if not os.path.isfile(filename): + continue m = re_topn.search(filename) topn = 1 if m is None else int(m.group(1)) - - with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file: + with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) return self.loaded_categories @@ -139,6 +142,8 @@ class InterrogateModels: def rank(self, image_features, text_array, top_count=1): import clip + devices.torch_gc() + if shared.opts.interrogate_clip_dict_limit != 0: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] diff --git a/modules/shared.py b/modules/shared.py index a644c0be..63b236c5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), + "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.1 From e8c3d03f7d9966b81458944efb25666b2143153f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 17:59:58 +0300 Subject: a possible fix for broken image upscaling --- modules/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 8514fea7..09d8e605 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -67,7 +67,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, pp.image.info["postprocessing"] = infotext if save_output: - images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) + images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if extras_mode != 2 or show_extras_results: outputs.append(pp.image) -- cgit v1.2.1 From 6e1b296baf7a2cdc0ee747225f1704bd2d45c118 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 10:10:59 -0500 Subject: api-image-format --- modules/api/api.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index 5d60fc0a..b1dd14cc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -22,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List +import piexif +import piexif.helper def upscaler_to_index(name: str): try: @@ -56,18 +58,30 @@ def decode_base64_to_image(encoding): def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: - # Copy any text-only metadata - use_metadata = False - metadata = PngImagePlugin.PngInfo() - for key, value in image.info.items(): - if isinstance(key, str) and isinstance(value, str): - metadata.add_text(key, value) - use_metadata = True + if opts.samples_format.lower() == 'png': + use_metadata = False + metadata = PngImagePlugin.PngInfo() + for key, value in image.info.items(): + if isinstance(key, str) and isinstance(value, str): + metadata.add_text(key, value) + use_metadata = True + image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) + + elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): + parameters = image.info.get('parameters', None) + exif_bytes = piexif.dump({ + "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } + }) + if opts.samples_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) + else: + image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) + + else: + raise HTTPException(status_code=500, detail="Invalid image format") - image.save( - output_bytes, "PNG", pnginfo=(metadata if use_metadata else None) - ) bytes_data = output_bytes.getvalue() + return base64.b64encode(bytes_data) def api_middleware(app: FastAPI): -- cgit v1.2.1 From 04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 12:29:23 -0500 Subject: add option to skip interrogate categories --- modules/interrogate.py | 32 ++++++++++++++++++-------------- modules/shared.py | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index 1d1ac572..c252b148 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys import traceback from collections import namedtuple +from pathlib import Path import re import torch @@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") -category_types = ["artists", "flavors", "mediums", "movements"] +def category_types(): + return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") tmpdir = content_dir + "_tmp" + category_types = ["artists", "flavors", "mediums", "movements"] + try: os.makedirs(tmpdir) for category_type in category_types: @@ -48,33 +53,32 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None - self.selected_categories = [] + self.skip_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories: return self.loaded_categories self.loaded_categories = [] - if not os.path.exists(self.content_dir): - download_default_clip_interrogate_categories(self.content_dir) - if os.path.exists(self.content_dir): - self.selected_categories = shared.opts.interrogate_clip_categories - for category_type in category_types: - if 'all' not in self.selected_categories and category_type not in self.selected_categories: - continue - filename = os.path.join(self.content_dir, f"{category_type}.txt") - if not os.path.isfile(filename): + self.skip_categories = shared.opts.interrogate_clip_skip_categories + category_types = [] + for filename in Path(self.content_dir).glob('*.txt'): + category_types.append(filename.stem) + if filename.stem in self.skip_categories: continue - m = re_topn.search(filename) + m = re_topn.search(filename.stem) topn = 1 if m is None else int(m.group(1)) with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines)) return self.loaded_categories diff --git a/modules/shared.py b/modules/shared.py index d7a18f6a..5f713bee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,7 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), - "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.1 From 7b1c7ba87b14da9960d0347269421233f4cb5838 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 23:11:34 +0300 Subject: add support for apostrophe in extra network names --- modules/ui_extra_networks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 2ddac3d8..8b4f97f8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,6 +3,7 @@ import os.path from modules import shared import gradio as gr import json +import html from modules.generation_parameters_copypaste import image_from_url_text @@ -54,12 +55,13 @@ class ExtraNetworksPage: preview = item.get("preview", None) args = { - "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } return self.card_page.format(**args) -- cgit v1.2.1 From 5c1cb9263f980641007088a37360fcab01761d37 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 00:24:17 +0300 Subject: fix BLIP failing to import depending on configuration --- modules/interrogate.py | 3 ++- modules/paths.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/interrogate.py b/modules/interrogate.py index c252b148..236e6983 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -83,7 +83,8 @@ class InterrogateModels: return self.loaded_categories def load_blip_model(self): - import models.blip + with paths.Prioritize("BLIP"): + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), diff --git a/modules/paths.py b/modules/paths.py index 4dd03a35..20b3e4d8 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d + + +class Prioritize: + def __init__(self, name): + self.name = name + self.path = None + + def __enter__(self): + self.path = sys.path.copy() + sys.path = [paths[self.name]] + sys.path + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.path = self.path + self.path = None -- cgit v1.2.1 From 45e270dfc853216b2c413f915946f0f2842e57a4 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 17:11:22 -0500 Subject: add image decod exception handling --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index b1dd14cc..e6e31e41 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -53,7 +53,11 @@ def setUpscalers(req: dict): def decode_base64_to_image(encoding): if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] - return Image.open(BytesIO(base64.b64decode(encoding))) + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as err: + raise HTTPException(status_code=500, detail="Invalid encoded image") def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: -- cgit v1.2.1 From 42a70d74771e8920f658e741679768ed145dd76a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 10:05:45 +0300 Subject: repair sdapi/v1/upscalers returning bogus results --- modules/api/api.py | 16 +++++++++------- modules/api/models.py | 2 +- 2 files changed, 10 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index e6e31e41..da2a5daf 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -375,13 +375,15 @@ class Api: return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): - upscalers = [] - - for upscaler in shared.sd_upscalers: - u = upscaler.scaler - upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url}) - - return upscalers + return [ + { + "name": upscaler.name, + "model_name": upscaler.scaler.model_name, + "model_path": upscaler.data_path, + "scale": upscaler.scale, + } + for upscaler in shared.sd_upscalers + ] def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] diff --git a/modules/api/models.py b/modules/api/models.py index 1eb1fcf1..e562ab54 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -219,7 +219,7 @@ class UpscalerItem(BaseModel): name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") - model_url: Optional[str] = Field(title="URL") + scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): title: str = Field(title="Title") -- cgit v1.2.1 From 602a1864b05075ca4283986e6f5c7d5bce864e11 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 10:09:30 +0300 Subject: also return the removed field to sdapi/v1/upscalers because someone might have relied on it existing --- modules/api/api.py | 1 + modules/api/models.py | 1 + 2 files changed, 2 insertions(+) (limited to 'modules') diff --git a/modules/api/api.py b/modules/api/api.py index da2a5daf..25c65e57 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -380,6 +380,7 @@ class Api: "name": upscaler.name, "model_name": upscaler.scaler.model_name, "model_path": upscaler.data_path, + "model_url": None, "scale": upscaler.scale, } for upscaler in shared.sd_upscalers diff --git a/modules/api/models.py b/modules/api/models.py index e562ab54..805bd8f7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -219,6 +219,7 @@ class UpscalerItem(BaseModel): name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): -- cgit v1.2.1