From b1707553cf70d74fad08c62cfca5a2bdfee936b7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 9 Sep 2022 17:54:04 +0300 Subject: added resize seeds and variation seeds features --- modules/images.py | 2 +- modules/img2img.py | 6 +++- modules/processing.py | 93 +++++++++++++++++++++++++++++++++++++++++---------- modules/shared.py | 1 - modules/txt2img.py | 6 +++- modules/ui.py | 40 ++++++++++++++++++++-- 6 files changed, 125 insertions(+), 23 deletions(-) (limited to 'modules') diff --git a/modules/images.py b/modules/images.py index 3399887d..064849d3 100644 --- a/modules/images.py +++ b/modules/images.py @@ -136,7 +136,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): color_active = (0, 0, 0) color_inactive = (153, 153, 153) - pad_left = width * 3 // 4 if len(ver_texts) > 0 else 0 + pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4 cols = im.width // width rows = im.height // height diff --git a/modules/img2img.py b/modules/img2img.py index 00bd626c..54023df5 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -11,7 +11,7 @@ from modules.ui import plaintext_to_html import modules.images as images import modules.scripts -def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): +def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): is_inpaint = mode == 1 is_loopback = mode == 2 is_upscale = mode == 3 @@ -34,6 +34,10 @@ def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, ste prompt=prompt, negative_prompt=negative_prompt, seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, sampler_index=sampler_index, batch_size=batch_size, n_iter=n_iter, diff --git a/modules/processing.py b/modules/processing.py index d4c4cfad..b91ade17 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -29,7 +29,7 @@ def torch_gc(): class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -37,6 +37,10 @@ class StableDiffusionProcessing: self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") self.seed: int = seed + self.subseed: int = subseed + self.subseed_strength: float = subseed_strength + self.seed_resize_from_h: int = seed_resize_from_h + self.seed_resize_from_w: int = seed_resize_from_w self.sampler_index: int = sampler_index self.batch_size: int = batch_size self.n_iter: int = n_iter @@ -84,23 +88,67 @@ class Processed: return json.dumps(obj) +# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 +def slerp(val, low, high): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res -def create_random_tensors(shape, seeds): + +def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0): xs = [] - for seed in seeds: - torch.manual_seed(seed) + for i, seed in enumerate(seeds): + noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) + + subnoise = None + if subseeds is not None: + subseed = 0 if i >= len(subseeds) else subseeds[i] + torch.manual_seed(subseed) + subnoise = torch.randn(noise_shape, device=shared.device) # randn results depend on device; gpu and cpu get different results for same seed; # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this so I do not dare change it for now because + # but the original script had it like this, so I do not dare change it for now because # it will break everyone's seeds. - xs.append(torch.randn(shape, device=shared.device)) - x = torch.stack(xs) + torch.manual_seed(seed) + noise = torch.randn(noise_shape, device=shared.device) + + if subnoise is not None: + #noise = subnoise * subseed_strength + noise * (1 - subseed_strength) + noise = slerp(subseed_strength, noise, subnoise) + + if noise_shape != shape: + #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze() + # noise_shape = (64, 80) + # shape = (64, 72) + + torch.manual_seed(seed) + x = torch.randn(shape, device=shared.device) + dx = (shape[2] - noise_shape[2]) // 2 # -4 + dy = (shape[1] - noise_shape[1]) // 2 + w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx + h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + + x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] + noise = x + + + + xs.append(noise) + x = torch.stack(xs).to(shared.device) return x -def set_seed(seed): - return int(random.randrange(4294967294)) if seed is None or seed == -1 else seed +def fix_seed(p): + p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed + p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed def process_images(p: StableDiffusionProcessing) -> Processed: @@ -111,7 +159,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: assert p.prompt is not None torch_gc() - seed = set_seed(p.seed) + fix_seed(p) os.makedirs(p.outpath_samples, exist_ok=True) os.makedirs(p.outpath_grids, exist_ok=True) @@ -125,20 +173,31 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else: all_prompts = p.batch_size * p.n_iter * [prompt] - if type(seed) == list: - all_seeds = seed + if type(p.seed) == list: + all_seeds = int(p.seed) else: - all_seeds = [int(seed + x) for x in range(len(all_prompts))] + all_seeds = [int(p.seed + x) for x in range(len(all_prompts))] + + if type(p.subseed) == list: + all_subseeds = p.subseed + else: + all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))] def infotext(iteration=0, position_in_batch=0): + index = position_in_batch + iteration * p.batch_size + generation_params = { "Steps": p.steps, "Sampler": samplers[p.sampler_index].name, "CFG scale": p.cfg_scale, - "Seed": all_seeds[position_in_batch + iteration * p.batch_size], + "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), + "Size": f"{p.width}x{p.height}", "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), + "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), + "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), + "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), } if p.extra_generation_params is not None: @@ -174,7 +233,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: comments += model_hijack.comments # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds) + x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=all_subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" @@ -231,10 +290,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: output_images.insert(0, grid) if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", seed, all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) + images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) torch_gc() - return Processed(p, output_images, seed, infotext()) + return Processed(p, output_images, all_seeds[0], infotext()) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/shared.py b/modules/shared.py index 280c07ff..e577332d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -62,7 +62,6 @@ class State: current_image = None current_image_sampling_step = 0 - def interrupt(self): self.interrupted = True diff --git a/modules/txt2img.py b/modules/txt2img.py index 410a7a7b..606421ea 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args): +def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -14,6 +14,10 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, r prompt=prompt, negative_prompt=negative_prompt, seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, sampler_index=sampler_index, batch_size=batch_size, n_iter=n_iter, diff --git a/modules/ui.py b/modules/ui.py index a2ff660a..6784de57 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -192,6 +192,40 @@ def visit(x, func, path=""): func(path + "/" + str(x.label), x) +def create_seed_inputs(): + with gr.Row(): + seed = gr.Number(label='Seed', value=-1) + subseed = gr.Number(label='Variation seed', value=-1, visible=False) + seed_checkbox = gr.Checkbox(label="Extra", elem_id="subseed_show", value=False) + + with gr.Row(): + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, visible=False) + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0, visible=False) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0, visible=False) + + def change_visiblity(show): + + return { + subseed: gr_show(show), + subseed_strength: gr_show(show), + seed_resize_from_h: gr_show(show), + seed_resize_from_w: gr_show(show), + } + + seed_checkbox.change( + change_visiblity, + inputs=[seed_checkbox], + outputs=[ + subseed, + subseed_strength, + seed_resize_from_h, + seed_resize_from_w + ] + ) + + return seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w + + def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Row(): @@ -220,7 +254,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - seed = gr.Number(label='Seed', value=-1) + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs() with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) @@ -260,6 +294,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): batch_size, cfg_scale, seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, height, width, ] + custom_inputs, @@ -357,7 +392,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - seed = gr.Number(label='Seed', value=-1) + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs() with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) @@ -440,6 +475,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): denoising_strength, denoising_strength_change_factor, seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, height, width, resize_mode, -- cgit v1.2.1