diff options
Diffstat (limited to 'modules')
-rwxr-xr-x | modules/processing.py | 22 | ||||
-rw-r--r-- | modules/processing_scripts/refiner.py | 55 | ||||
-rw-r--r-- | modules/scripts.py | 24 | ||||
-rw-r--r-- | modules/sd_models.py | 3 | ||||
-rw-r--r-- | modules/sd_samplers_cfg_denoiser.py | 6 | ||||
-rw-r--r-- | modules/sd_samplers_common.py | 40 | ||||
-rw-r--r-- | modules/sd_samplers_kdiffusion.py | 3 | ||||
-rw-r--r-- | modules/shared_items.py | 4 | ||||
-rw-r--r-- | modules/shared_options.py | 2 | ||||
-rw-r--r-- | modules/ui.py | 58 | ||||
-rw-r--r-- | modules/ui_components.py | 18 |
11 files changed, 168 insertions, 67 deletions
diff --git a/modules/processing.py b/modules/processing.py index 131c4c3c..5996cbac 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -373,9 +373,10 @@ class StableDiffusionProcessing: negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
- self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
- self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
+ self.step_multiplier = total_steps // self.steps
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
def get_conds(self):
return self.c, self.uc
@@ -579,8 +580,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
- "VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None,
- "VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None,
+ "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
+ "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -669,6 +670,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.tiling is None:
p.tiling = opts.tiling
+ p.loaded_vae_name = sd_vae.get_loaded_vae_name()
+ p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
+
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
@@ -1188,8 +1192,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
+ steps = self.hr_second_pass_steps or self.steps
+ total_steps = sampler_config.total_steps(steps) if sampler_config else steps
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self):
super().setup_conds()
diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py new file mode 100644 index 00000000..5a82991a --- /dev/null +++ b/modules/processing_scripts/refiner.py @@ -0,0 +1,55 @@ +import gradio as gr
+
+from modules import scripts, sd_models
+from modules.ui_common import create_refresh_button
+from modules.ui_components import InputAccordion
+
+
+class ScriptRefiner(scripts.Script):
+ section = "accordions"
+ create_group = False
+
+ def __init__(self):
+ pass
+
+ def title(self):
+ return "Refiner"
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
+ with gr.Row():
+ refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
+ create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
+
+ refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
+
+ def lookup_checkpoint(title):
+ info = sd_models.get_closet_checkpoint_match(title)
+ return None if info is None else info.title
+
+ self.infotext_fields = [
+ (enable_refiner, lambda d: 'Refiner' in d),
+ (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
+ (refiner_switch_at, 'Refiner switch at'),
+ ]
+
+ return enable_refiner, refiner_checkpoint, refiner_switch_at
+
+ def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
+ # the actual implementation is in sd_samplers_common.py, apply_refiner
+
+ p.refiner_checkpoint_info = None
+ p.refiner_switch_at = None
+
+ if not enable_refiner or refiner_checkpoint in (None, "", "None"):
+ return
+
+ refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
+ if refiner_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
+
+ p.refiner_checkpoint_info = refiner_checkpoint_info
+ p.refiner_switch_at = refiner_switch_at
diff --git a/modules/scripts.py b/modules/scripts.py index f7d060aa..51da732a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -37,7 +37,10 @@ class Script: is_img2img = False
group = None
- """A gr.Group component that has all script's UI inside it"""
+ """A gr.Group component that has all script's UI inside it."""
+
+ create_group = True
+ """If False, for alwayson scripts, a group component will not be created."""
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -232,6 +235,7 @@ class Script: """
pass
+
current_basedir = paths.script_path
@@ -250,7 +254,7 @@ postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
-def list_scripts(scriptdirname, extension):
+def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = []
basedir = os.path.join(paths.script_path, scriptdirname)
@@ -258,8 +262,9 @@ def list_scripts(scriptdirname, extension): for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
- for ext in extensions.active():
- scripts_list += ext.list_files(scriptdirname, extension)
+ if include_extensions:
+ for ext in extensions.active():
+ scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -288,7 +293,7 @@ def load_scripts(): postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
- scripts_list = list_scripts("scripts", ".py")
+ scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
syspath = sys.path
@@ -429,10 +434,13 @@ class ScriptRunner: if script.alwayson and script.section != section:
continue
- with gr.Group(visible=script.alwayson) as group:
- self.create_script_ui(script)
+ if script.create_group:
+ with gr.Group(visible=script.alwayson) as group:
+ self.create_script_ui(script)
- script.group = group
+ script.group = group
+ else:
+ self.create_script_ui(script)
def prepare_ui(self):
self.inputs = [None]
diff --git a/modules/sd_models.py b/modules/sd_models.py index a178adca..f6fbdcd6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") def get_closet_checkpoint_match(search_string):
+ if not search_string:
+ return None
+
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
return checkpoint_info
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a532e013..113425b2 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -45,6 +45,11 @@ class CFGDenoiser(torch.nn.Module): self.nmask = None
self.init_latent = None
self.steps = None
+ """number of steps as specified by user in UI"""
+
+ self.total_steps = None
+ """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
+
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
@@ -56,7 +61,6 @@ class CFGDenoiser(torch.nn.Module): def inner_model(self):
raise NotImplementedError()
-
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 35c4d657..85f3c7e0 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s from modules.shared import opts, state
import k_diffusion.sampling
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+class SamplerData(SamplerDataTuple):
+ def total_steps(self, steps):
+ if self.options.get("second_order", False):
+ steps = steps * 2
+
+ return steps
def setup_img2img_steps(p, steps=None):
@@ -131,31 +140,26 @@ def replace_torchsde_browinan(): replace_torchsde_browinan()
-def apply_refiner(sampler):
- completed_ratio = sampler.step / sampler.steps
+def apply_refiner(cfg_denoiser):
+ completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
+ refiner_switch_at = cfg_denoiser.p.refiner_switch_at
+ refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
- if completed_ratio <= shared.opts.sd_refiner_switch_at:
+ if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
return False
- if shared.opts.sd_refiner_checkpoint == "None":
+ if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False
- if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
- return False
-
- refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
- if refiner_checkpoint_info is None:
- raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
-
- sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
- sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
+ cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
+ cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
- sampler.p.setup_conds()
- sampler.update_inner_model()
+ cfg_denoiser.p.setup_conds()
+ cfg_denoiser.update_inner_model()
return True
@@ -192,7 +196,7 @@ class Sampler: self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None # set by the function calling the constructor
+ self.config: SamplerData = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
self.s_churn = 0.0
@@ -208,6 +212,7 @@ class Sampler: self.p = None
self.model_wrap_cfg = None
self.sampler_extra_args = None
+ self.options = {}
def callback_state(self, d):
step = d['i']
@@ -220,6 +225,7 @@ class Sampler: def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps
+ self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
state.sampling_steps = steps
state.sampling_step = 0
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index d10fe12e..1f8e9c4b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -64,9 +64,10 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): class KDiffusionSampler(sd_samplers_common.Sampler):
- def __init__(self, funcname, sd_model):
+ def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname)
+ self.options = options or {}
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
diff --git a/modules/shared_items.py b/modules/shared_items.py index e4ec40a8..754166d2 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -69,8 +69,8 @@ def reload_hypernetworks(): ui_reorder_categories_builtin_items = [
"inpaint",
"sampler",
+ "accordions",
"checkboxes",
- "hires_fix",
"dimensions",
"cfg",
"seed",
@@ -86,7 +86,7 @@ def ui_reorder_categories(): sections = {}
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
- if isinstance(script.section, str):
+ if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
sections[script.section] = 1
yield from sections
diff --git a/modules/shared_options.py b/modules/shared_options.py index 1e5b64ea..9ae51f18 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -140,8 +140,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
- "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
- "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
diff --git a/modules/ui.py b/modules/ui.py index 05292734..3321b94d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -438,35 +438,38 @@ def create_ui(): with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass
- elif category == "hires_fix":
- with InputAccordion(False, label="Hires. fix") as enable_hr:
- with enable_hr.extra():
- hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
+ elif category == "accordions":
+ with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
+ with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
+ with enable_hr.extra():
+ hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
- with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
- hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
- hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
- denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
+ with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
+ hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
+ hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
+ denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
- with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
- hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
- hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
- hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+ with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
+ hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
+ hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
+ hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
- with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+ with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
- hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
- create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+ hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+ create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
- hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
+ hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
- with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
- with gr.Column(scale=80):
- with gr.Row():
- hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
- with gr.Column(scale=80):
- with gr.Row():
- hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
+ with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
+ with gr.Column(scale=80):
+ with gr.Row():
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
+ with gr.Column(scale=80):
+ with gr.Row():
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
+
+ scripts.scripts_txt2img.setup_ui_for_section(category)
elif category == "batch":
if not opts.dimensions_and_batch_together:
@@ -482,7 +485,7 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = scripts.scripts_txt2img.setup_ui()
- else:
+ if category not in {"accordions"}:
scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
@@ -794,6 +797,10 @@ def create_ui(): with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass
+ elif category == "accordions":
+ with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
+ scripts.scripts_img2img.setup_ui_for_section(category)
+
elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="img2img_column_batch"):
@@ -836,7 +843,8 @@ def create_ui(): inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
- else:
+
+ if category not in {"accordions"}:
scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
diff --git a/modules/ui_components.py b/modules/ui_components.py index bfe2fbd9..d08b2b99 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -87,13 +87,23 @@ class InputAccordion(gr.Checkbox): self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
InputAccordion.global_index += 1
- kwargs['elem_id'] = self.accordion_id + "-checkbox"
- kwargs['visible'] = False
- super().__init__(value, **kwargs)
+ kwargs_checkbox = {
+ **kwargs,
+ "elem_id": f"{self.accordion_id}-checkbox",
+ "visible": False,
+ }
+ super().__init__(value, **kwargs_checkbox)
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
- self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion'])
+ kwargs_accordion = {
+ **kwargs,
+ "elem_id": self.accordion_id,
+ "label": kwargs.get('label', 'Accordion'),
+ "elem_classes": ['input-accordion'],
+ "open": value,
+ }
+ self.accordion = gr.Accordion(**kwargs_accordion)
def extra(self):
"""Allows you to put something into the label of the accordion.
|