diff options
author | JashoBell <JoshuaDB@gmail.com> | 2022-09-17 11:16:35 -0700 |
---|---|---|
committer | JashoBell <JoshuaDB@gmail.com> | 2022-09-17 11:16:35 -0700 |
commit | d2c7ad2fec09d89d1348d6d40640259b5a02b8ad (patch) | |
tree | 90f4f6695f318aed33dea72cc340c0f5ff628ae8 /scripts/xy_grid.py | |
parent | 5a797a5612924e50d5b60d2aa1eddfae4c3e157e (diff) | |
parent | 23a0ec04c005957091ab35c26c4c31485e75d146 (diff) |
Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diffusion-webui into Base
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r-- | scripts/xy_grid.py | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index eccfda87..6a157722 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,9 @@ import gradio as gr from modules import images
from modules.processing import process_images, Processed
from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
import modules.sd_samplers
+import modules.sd_models
import re
@@ -41,6 +43,15 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index
+def apply_checkpoint(p, x, xs):
+ applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
+ assert len(applicable) > 0, f'Checkpoint {x} for found'
+
+ info = applicable[0]
+
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
+
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -74,15 +85,16 @@ axis_options = [ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
-def draw_xy_grid(p, xs, ys, x_label, y_label, cell, draw_legend):
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
res = []
- ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
- hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
+ ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
+ hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
first_pocessed = None
@@ -206,8 +218,8 @@ class Script(scripts.Script): p,
xs=xs,
ys=ys,
- x_label=lambda x: x_opt.format_value(p, x_opt, x),
- y_label=lambda y: y_opt.format_value(p, y_opt, y),
+ x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
+ y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell,
draw_legend=draw_legend
)
@@ -215,4 +227,7 @@ class Script(scripts.Script): if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ # restore checkpoint in case it was changed by axes
+ modules.sd_models.reload_model_weights(shared.sd_model)
+
return processed
|