diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 5 | ||||
-rw-r--r-- | modules/esrgan_model.py | 17 | ||||
-rw-r--r-- | modules/img2img.py | 5 | ||||
-rw-r--r-- | modules/modelloader.py | 3 | ||||
-rw-r--r-- | modules/processing.py | 33 | ||||
-rw-r--r-- | modules/script_callbacks.py | 41 | ||||
-rw-r--r-- | modules/sd_samplers.py | 11 | ||||
-rw-r--r-- | modules/shared.py | 20 | ||||
-rw-r--r-- | modules/txt2img.py | 2 | ||||
-rw-r--r-- | modules/ui.py | 20 | ||||
-rw-r--r-- | modules/upscaler.py | 17 |
11 files changed, 138 insertions, 36 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index bb87d795..71c9c160 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -5,10 +5,9 @@ import uvicorn from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, HTTPException import modules.shared as shared -from modules import devices from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers +from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid from modules.extras import run_extras, run_pnginfo @@ -179,6 +178,8 @@ class Api: progress = min(progress, 1) + shared.state.set_current_image() + current_image = None if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a13cf6ac..c61669b4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -50,6 +50,7 @@ def mod2normal(state_dict): def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+ re8x = 0
crt_net = {}
items = []
for k, v in state_dict.items():
@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23): crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
- crt_net['model.8.weight'] = state_dict['conv_hr.weight']
- crt_net['model.8.bias'] = state_dict['conv_hr.bias']
- crt_net['model.10.weight'] = state_dict['conv_last.weight']
- crt_net['model.10.bias'] = state_dict['conv_last.bias']
+
+ if 'conv_up3.weight' in state_dict:
+ # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+ re8x = 3
+ crt_net['model.9.weight'] = state_dict['conv_up3.weight']
+ crt_net['model.9.bias'] = state_dict['conv_up3.bias']
+
+ crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
+ crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
+ crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
+ crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
+
state_dict = crt_net
return state_dict
diff --git a/modules/img2img.py b/modules/img2img.py index 35c5df9b..be9f3653 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -81,7 +81,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro mask = None
# Use the EXIF orientation of photos taken by smartphones.
- image = ImageOps.exif_transpose(image)
+ if image is not None:
+ image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
@@ -137,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if processed is None:
processed = process_images(p)
+ p.close()
+
shared.total_tqdm.clear()
generation_info_js = processed.js()
diff --git a/modules/modelloader.py b/modules/modelloader.py index b0f2f33d..e4a6f8ac 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -85,6 +85,9 @@ def cleanup_models(): src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) + src_path = os.path.join(models_path, "BSRGAN") + dest_path = os.path.join(models_path, "ESRGAN") + move_files(src_path, dest_path, ".pth") src_path = os.path.join(root_path, "gfpgan") dest_path = os.path.join(models_path, "GFPGAN") move_files(src_path, dest_path) diff --git a/modules/processing.py b/modules/processing.py index 57d3a523..3a364b5f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -199,9 +199,13 @@ class StableDiffusionProcessing(): def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
+ def close(self):
+ self.sd_model = None
+ self.sampler = None
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
@@ -517,7 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
@@ -597,9 +601,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None:
p.scripts.postprocess(p, res)
- p.sd_model = None
- p.sampler = None
-
return res
@@ -648,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
@@ -661,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
+ """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
+ def save_intermediate(image, index):
+ if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ return
+
+ if not isinstance(image, Image.Image):
+ image = sd_samplers.sample_to_image(image, index)
+
+ images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
+
if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+ for i in range(samples.shape[0]):
+ save_intermediate(samples, i)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@@ -673,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
+
+ save_intermediate(image, i)
+
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -830,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
-
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
@@ -842,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x
devices.torch_gc()
- return samples
\ No newline at end of file + return samples
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index ce264690..c28e220e 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,6 +2,7 @@ import sys import traceback
from collections import namedtuple
import inspect
+from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
@@ -26,6 +27,24 @@ class ImageSaveParams: """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+class CFGDenoiserParams:
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.image_cond = image_cond
+ """Conditioning image"""
+
+ self.sigma = sigma
+ """Current sigma noise step value"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = []
@@ -33,6 +52,7 @@ callbacks_ui_tabs = [] callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
+callbacks_cfg_denoiser = []
def clear_callbacks():
@@ -41,9 +61,9 @@ def clear_callbacks(): callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
+ callbacks_cfg_denoiser.clear()
-
-def app_started_callback(demo: Blocks, app: FastAPI):
+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callbacks_app_started:
try:
c.callback(demo, app)
@@ -95,6 +115,14 @@ def image_saved_callback(params: ImageSaveParams): report_exception(c, 'image_saved_callback')
+def cfg_denoiser_callback(params: CFGDenoiserParams):
+ for c in callbacks_cfg_denoiser:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_denoiser_callback')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -147,3 +175,12 @@ def on_image_saved(callback): - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
+
+
+def on_cfg_denoiser(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callbacks_cfg_denoiser, callback)
+
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8772db56..c7c414ef 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -12,6 +12,7 @@ from modules import prompt_parser, devices, processing, images from modules.shared import opts, cmd_opts, state
import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -92,8 +93,8 @@ def single_sample_to_image(sample): return Image.fromarray(x_sample)
-def sample_to_image(samples):
- return single_sample_to_image(samples[0])
+def sample_to_image(samples, index=0):
+ return single_sample_to_image(samples[index])
def samples_to_image_grid(samples):
@@ -280,6 +281,12 @@ class CFGDenoiser(torch.nn.Module): image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
diff --git a/modules/shared.py b/modules/shared.py index cbef5c43..d8e99f85 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import json import os
import sys
from collections import OrderedDict
+import time
import gradio as gr
import tqdm
@@ -135,6 +136,7 @@ class State: current_image = None
current_image_sampling_step = 0
textinfo = None
+ time_start = None
need_restart = False
def skip(self):
@@ -172,6 +174,7 @@ class State: self.skipped = False
self.interrupted = False
self.textinfo = None
+ self.time_start = time.time()
devices.torch_gc()
@@ -181,6 +184,20 @@ class State: devices.torch_gc()
+ """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
+ def set_current_image(self):
+ if not parallel_processing_allowed:
+ return
+
+ if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None:
+ if opts.show_progress_grid:
+ self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
+ else:
+ self.current_image = sd_samplers.sample_to_image(self.current_latent)
+
+ self.current_image_sampling_step = self.sampling_step
+
+
state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
@@ -238,6 +255,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
+ "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
+ "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
@@ -305,7 +324,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
diff --git a/modules/txt2img.py b/modules/txt2img.py index c9d5a090..8e4e8677 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: if processed is None:
processed = process_images(p)
+ p.close()
+
shared.total_tqdm.clear()
generation_info_js = processed.js()
diff --git a/modules/ui.py b/modules/ui.py index 2c15abb7..2609857e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -277,15 +277,7 @@ def check_progress_call(id_part): preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
- if shared.parallel_processing_allowed:
-
- if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
- if opts.show_progress_grid:
- shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
- else:
- shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
- shared.state.current_image_sampling_step = shared.state.sampling_step
-
+ shared.state.set_current_image()
image = shared.state.current_image
if image is None:
@@ -671,6 +663,8 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img
import modules.txt2img
+ reload_javascript()
+
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
@@ -1060,7 +1054,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'):
- upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
+ upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'):
with gr.Group():
with gr.Row():
@@ -1570,8 +1564,7 @@ def create_ui(wrap_gradio_gpu_call): reload_script_bodies.click(
fn=reload_scripts,
inputs=[],
- outputs=[],
- _js='function(){}'
+ outputs=[]
)
def request_restart():
@@ -1583,7 +1576,7 @@ def create_ui(wrap_gradio_gpu_call): fn=request_restart,
inputs=[],
outputs=[],
- _js='function(){restart_reload()}'
+ _js='restart_reload'
)
if column is not None:
@@ -1782,4 +1775,3 @@ def load_javascript(raw_response): reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
-reload_javascript()
diff --git a/modules/upscaler.py b/modules/upscaler.py index 6ab2fb40..83fde7ca 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -10,6 +10,7 @@ import modules.shared from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) from modules.paths import models_path @@ -57,7 +58,7 @@ class Upscaler: dest_w = img.width * scale dest_h = img.height * scale for i in range(3): - if img.width >= dest_w and img.height >= dest_h: + if img.width > dest_w and img.height > dest_h: break img = self.do_upscale(img, selected_model) if img.width != dest_w or img.height != dest_h: @@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler): self.name = "Lanczos" self.scalers = [UpscalerData("Lanczos", None, self)] + +class UpscalerNearest(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Nearest" + self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file |