about summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/batch.py59
-rw-r--r--scripts/sd_upscale.py93
2 files changed, 93 insertions, 59 deletions
diff --git a/scripts/batch.py b/scripts/batch.py
deleted file mode 100644
index 1af4a7bc..00000000
--- a/scripts/batch.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import math
-import os
-import sys
-import traceback
-
-import modules.scripts as scripts
-import gradio as gr
-
-from modules.processing import Processed, process_images
-from PIL import Image
-from modules.shared import opts, cmd_opts, state
-
-
-class Script(scripts.Script):
- def title(self):
- return "Batch processing"
-
- def show(self, is_img2img):
- return is_img2img
-
- def ui(self, is_img2img):
- input_dir = gr.Textbox(label="Input directory", lines=1)
- output_dir = gr.Textbox(label="Output directory", lines=1)
-
- return [input_dir, output_dir]
-
- def run(self, p, input_dir, output_dir):
- images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
-
- batch_count = math.ceil(len(images) / p.batch_size)
- print(f"Will process {len(images)} images in {batch_count} batches.")
-
- p.batch_count = 1
- p.do_not_save_grid = True
- p.do_not_save_samples = True
-
- state.job_count = batch_count
-
- for batch_no in range(batch_count):
- batch_images = []
- for path in images[batch_no*p.batch_size:(batch_no+1)*p.batch_size]:
- try:
- img = Image.open(path)
- batch_images.append((img, path))
- except:
- print(f"Error processing {path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- if len(batch_images) == 0:
- continue
-
- state.job = f"{batch_no} out of {batch_count}: {batch_images[0][1]}"
- p.init_images = [x[0] for x in batch_images]
- proc = process_images(p)
- for image, (_, path) in zip(proc.images, batch_images):
- filename = os.path.basename(path)
- image.save(os.path.join(output_dir, filename))
-
- return Processed(p, [], p.seed, "")
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
new file mode 100644
index 00000000..b87a145b
--- /dev/null
+++ b/scripts/sd_upscale.py
@@ -0,0 +1,93 @@
+import math
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image
+
+from modules import processing, shared, sd_samplers, images, devices
+from modules.processing import Processed
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "SD upscale"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>")
+ overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
+ upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False)
+
+ return [info, overlap, upscaler_index]
+
+ def run(self, p, _, overlap, upscaler_index):
+ processing.fix_seed(p)
+ upscaler = shared.sd_upscalers[upscaler_index]
+
+ p.extra_generation_params["SD upscale overlap"] = overlap
+ p.extra_generation_params["SD upscale upscaler"] = upscaler.name
+
+ initial_info = None
+ seed = p.seed
+
+ init_img = p.init_images[0]
+ img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
+
+ devices.torch_gc()
+
+ grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)
+
+ batch_size = p.batch_size
+ upscale_count = p.n_iter
+ p.n_iter = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ work.append(tiledata[2])
+
+ batch_count = math.ceil(len(work) / batch_size)
+ state.job_count = batch_count * upscale_count
+
+ print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")
+
+ result_images = []
+ for n in range(upscale_count):
+ start_seed = seed + n
+ p.seed = start_seed
+
+ work_results = []
+ for i in range(batch_count):
+ p.batch_size = batch_size
+ p.init_images = work[i*batch_size:(i+1)*batch_size]
+
+ state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
+ processed = processing.process_images(p)
+
+ if initial_info is None:
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+ result_images.append(combined_image)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
+
+ processed = Processed(p, result_images, seed, initial_info)
+
+ return processed