aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/ISSUE_TEMPLATE/feature_request.yml2
-rw-r--r--.github/pull_request_template.md (renamed from .github/PULL_REQUEST_TEMPLATE/pull_request_template.md)4
-rw-r--r--.github/workflows/on_pull_request.yaml13
-rw-r--r--.github/workflows/run_tests.yaml10
-rw-r--r--javascript/hints.js10
-rw-r--r--javascript/hires_fix.js25
-rw-r--r--modules/api/api.py48
-rw-r--r--modules/api/models.py4
-rw-r--r--modules/generation_parameters_copypaste.py17
-rw-r--r--modules/hypernetworks/hypernetwork.py9
-rw-r--r--modules/processing.py30
-rw-r--r--modules/sd_hijack.py7
-rw-r--r--modules/sd_hijack_clip.py4
-rw-r--r--modules/sd_vae.py20
-rw-r--r--modules/shared.py2
-rw-r--r--modules/sub_quadratic_attention.py15
-rw-r--r--modules/textual_inversion/dataset.py10
-rw-r--r--modules/textual_inversion/textual_inversion.py202
-rw-r--r--modules/ui.py37
-rw-r--r--scripts/sd_upscale.py2
-rw-r--r--style.css17
-rw-r--r--test/advanced_features/__init__.py0
-rw-r--r--test/advanced_features/extras_test.py29
-rw-r--r--test/advanced_features/txt2img_test.py47
-rw-r--r--test/basic_features/extras_test.py54
-rw-r--r--test/basic_features/img2img_test.py13
-rw-r--r--test/basic_features/txt2img_test.py11
-rw-r--r--test/basic_features/utils_test.py17
-rw-r--r--test/server_poll.py2
-rw-r--r--webui.py3
30 files changed, 443 insertions, 221 deletions
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index 8ca6e21f..35a88740 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -1,7 +1,7 @@
name: Feature request
description: Suggest an idea for this project
title: "[Feature Request]: "
-labels: ["suggestion"]
+labels: ["enhancement"]
body:
- type: checkboxes
diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/pull_request_template.md
index 86009613..69056331 100644
--- a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -18,8 +18,8 @@ More technical discussion about your changes go here, plus anything that a maint
List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
- OS: [e.g. Windows, Linux]
- - Browser [e.g. chrome, safari]
- - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
+ - Browser: [e.g. chrome, safari]
+ - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
**Screenshots or videos of your changes**
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml
index b097d180..a168be5b 100644
--- a/.github/workflows/on_pull_request.yaml
+++ b/.github/workflows/on_pull_request.yaml
@@ -19,22 +19,19 @@ jobs:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.10
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
python-version: 3.10.6
- - uses: actions/cache@v2
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
- restore-keys: |
- ${{ runner.os }}-pip-
+ cache: pip
+ cache-dependency-path: |
+ **/requirements*txt
- name: Install PyLint
run: |
python -m pip install --upgrade pip
pip install pylint
# This lets PyLint check to see if it can resolve imports
- name: Install dependencies
- run : |
+ run: |
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
python launch.py
- name: Analysing the code with pylint
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 49dc92bd..be7ffa23 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -14,13 +14,11 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.10.6
- - uses: actions/cache@v3
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
- restore-keys: ${{ runner.os }}-pip-
+ cache: pip
+ cache-dependency-path: |
+ **/requirements*txt
- name: Run tests
- run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+ run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()
diff --git a/javascript/hints.js b/javascript/hints.js
index dda66e09..856e1389 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -4,7 +4,7 @@ titles = {
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
"Sampling method": "Which algorithm to use to produce the image",
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
- "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
+ "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
@@ -74,7 +74,7 @@ titles = {
"Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Apply style": "Insert selected styles into prompt fields",
- "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
+ "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
@@ -92,12 +92,12 @@ titles = {
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
- "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
+ "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
- "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
- "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.",
+ "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.",
+ "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.",
"Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition",
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js
new file mode 100644
index 00000000..07fba549
--- /dev/null
+++ b/javascript/hires_fix.js
@@ -0,0 +1,25 @@
+
+function setInactive(elem, inactive){
+ console.log(elem)
+ if(inactive){
+ elem.classList.add('inactive')
+ } else{
+ elem.classList.remove('inactive')
+ }
+}
+
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
+ console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y)
+
+ hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
+ hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
+ hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
+
+ gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
+
+ setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
+ setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
+ setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
+
+ return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
+}
diff --git a/modules/api/api.py b/modules/api/api.py
index 1c121ff0..6c564ad8 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
@@ -28,8 +28,13 @@ def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+def script_name_to_index(name, scripts):
+ try:
+ return [script.title().lower() for script in scripts].index(name.lower())
+ except:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@@ -144,7 +149,21 @@ class Api:
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
+ def get_script(self, script_name, script_runner):
+ if script_name is None:
+ return None, None
+
+ if not script_runner.scripts:
+ script_runner.initialize_scripts(False)
+ ui.create_ui()
+
+ script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
+ script = script_runner.selectable_scripts[script_idx]
+ return script, script_idx
+
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
+
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True,
@@ -154,14 +173,22 @@ class Api:
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
+ args = vars(populate)
+ args.pop('script_name', None)
+
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
+ p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
-
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@@ -171,6 +198,8 @@ class Api:
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
+ script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
+
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
@@ -187,13 +216,20 @@ class Api:
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
+ args.pop('script_name', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [decode_base64_to_image(x) for x in init_images]
shared.state.begin()
- processed = process_images(p)
+ if script is not None:
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+ p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args)
+ else:
+ processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images))
diff --git a/modules/api/models.py b/modules/api/models.py
index 49bf1e7a..880edde6 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -100,13 +100,13 @@ class PydanticModelGenerator:
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
- [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
+ [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
).generate_model()
class TextToImageResponse(BaseModel):
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 12a9de3d..f7f68b67 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res):
firstpass_width = res.get('First pass size-1', None)
firstpass_height = res.get('First pass size-2', None)
+ if shared.opts.use_old_hires_fix_width_height:
+ hires_width = int(res.get("Hires resize-1", None))
+ hires_height = int(res.get("Hires resize-2", None))
+
+ if hires_width is not None and hires_height is not None:
+ res['Size-1'] = hires_width
+ res['Size-2'] = hires_height
+ return
+
if firstpass_width is None or firstpass_height is None:
return
@@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
- # old algorithm for auto-calculating first pass size
- desired_pixel_count = 512 * 512
- actual_pixel_count = width * height
- scale = math.sqrt(desired_pixel_count / actual_pixel_count)
- firstpass_width = math.ceil(scale * width / 64) * 64
- firstpass_height = math.ceil(scale * height / 64) * 64
+ from modules import processing
+ firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
res['Size-2'] = firstpass_height
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index b0cfbe71..ea3f1db9 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -24,6 +24,7 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
+
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
- textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
+ textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
+ template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
@@ -456,7 +459,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
diff --git a/modules/processing.py b/modules/processing.py
index 82157bc9..f04a0e1e 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -98,7 +98,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -149,7 +149,7 @@ class StableDiffusionProcessing():
self.seed_resize_from_w = 0
self.scripts = None
- self.script_args = None
+ self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
@@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
return res
+def old_hires_fix_first_pass_dimensions(width, height):
+ """old algorithm for auto-calculating first pass size"""
+
+ desired_pixel_count = 512 * 512
+ actual_pixel_count = width * height
+ scale = math.sqrt(desired_pixel_count / actual_pixel_count)
+ width = math.ceil(scale * width / 64) * 64
+ height = math.ceil(scale * height / 64) * 64
+
+ return width, height
+
+
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
@@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_upscale_to_y = hr_resize_y
if firstphase_width != 0 or firstphase_height != 0:
- print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr)
- self.hr_scale = self.width / firstphase_width
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
self.width = firstphase_width
self.height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
+ self.applied_old_hires_behavior_to = None
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+ self.hr_resize_x = self.width
+ self.hr_resize_y = self.height
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
+
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+ self.applied_old_hires_behavior_to = (self.width, self.height)
+
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
self.extra_generation_params["Hires upscale"] = self.hr_scale
self.hr_upscale_to_x = int(self.width * self.hr_scale)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfdb09d6..6b0d95af 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
+ embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
- def hijack(self, m):
+ def __init__(self):
+ self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
-
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 5520c9b2..852afc66 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean()
- z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
- z *= original_mean / new_mean
+ z = z * (original_mean / new_mean)
return z
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index ac71d62d..0a49daa1 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,9 @@
import torch
+import safetensors.torch
import os
import collections
from collections import namedtuple
-from modules import shared, devices, script_callbacks
+from modules import shared, devices, script_callbacks, sd_models
from modules.paths import models_path
import glob
from copy import deepcopy
@@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
+ *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
- *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
+ *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
+ *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
@@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print(f"Using VAE found similar to selected model: {vae_file}")
+ # if still not found, try look for ".vae.safetensors" beside model
+ if vae_file == "auto":
+ vae_file_try = model_path + ".vae.safetensors"
+ if os.path.isfile(vae_file_try):
+ vae_file = vae_file_try
+ print(f"Using VAE found similar to selected model: {vae_file}")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
@@ -163,8 +172,9 @@ def load_vae(model, vae_file=None):
assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
print(f"Loading VAE weights from: {vae_file}")
store_base_vae(model)
- vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+
+ vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
+ vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
_load_vae_dict(model, vae_dict_1)
if cache_enabled:
@@ -195,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
+
def clear_loaded_vae():
global loaded_vae_file
loaded_vae_file = None
+
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
diff --git a/modules/shared.py b/modules/shared.py
index a6712dae..aa37c8ce 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
+parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
@@ -398,6 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index fea7aaac..55052815 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,8 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-from typing import Optional, NamedTuple, Protocol, List
+from typing import Optional, NamedTuple, List
+
def narrow_trunc(
input: Tensor,
@@ -25,12 +26,14 @@ def narrow_trunc(
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
-class SummarizeChunk(Protocol):
+
+class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
@@ -38,7 +41,8 @@ class SummarizeChunk(Protocol):
value: Tensor,
) -> AttnChunk: ...
-class ComputeQueryChunkAttn(Protocol):
+
+class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
@@ -46,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor,
) -> Tensor: ...
+
def _summarize_chunk(
query: Tensor,
key: Tensor,
@@ -66,6 +71,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
def _query_chunk_attention(
query: Tensor,
key: Tensor,
@@ -106,6 +112,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
@@ -125,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
+
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
+
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 88d68c76..fa48708e 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -28,13 +28,11 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
- self.width = width
- self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
@@ -47,10 +45,10 @@ class PersonalizedBase(Dataset):
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
+ assert batch_size == 1 or not varsize, 'variable img size must have batch size 1'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
@@ -59,7 +57,9 @@ class PersonalizedBase(Dataset):
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB')
+ if not varsize:
+ image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 45882ed6..5420903f 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
import inspect
+from collections import namedtuple
import torch
import tqdm
@@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
- insert_image_data_embed, extract_image_data_embed,
- caption_image_overlay)
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
from modules.textual_inversion.logging import save_settings_to_file
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+
+ for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+
+ return textual_inversion_templates
+
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
@@ -66,17 +81,41 @@ class Embedding:
return self.cached_checksum
+class DirWithTextualInversionEmbeddings:
+ def __init__(self, path):
+ self.path = path
+ self.mtime = None
+
+ def has_changed(self):
+ if not os.path.isdir(self.path):
+ return False
+
+ mt = os.path.getmtime(self.path)
+ if self.mtime is None or mt > self.mtime:
+ return True
+
+ def update(self):
+ if not os.path.isdir(self.path):
+ return
+
+ self.mtime = os.path.getmtime(self.path)
+
+
class EmbeddingDatabase:
- def __init__(self, embeddings_dir):
+ def __init__(self):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
- self.dir_mtime = None
- self.embeddings_dir = embeddings_dir
self.expected_shape = -1
+ self.embedding_dirs = {}
- def register_embedding(self, embedding, model):
+ def add_embedding_dir(self, path):
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+
+ def clear_embedding_dirs(self):
+ self.embedding_dirs.clear()
+ def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenize([embedding.name])[0]
@@ -93,65 +132,62 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
- def load_textual_inversion_embeddings(self, force_reload = False):
- mt = os.path.getmtime(self.embeddings_dir)
- if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
- return
+ def load_from_file(self, path, filename):
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
- self.skipped_embeddings.clear()
- self.expected_shape = self.get_expected_shape()
-
- def process_file(path, filename):
- name, ext = os.path.splitext(filename)
- ext = ext.upper()
-
- if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name', name)
- else:
- data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
- elif ext in ['.BIN', '.PT']:
- data = torch.load(path, map_location="cpu")
- else:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
return
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
- else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('sd_checkpoint', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
-
- if self.expected_shape == -1 or self.expected_shape == embedding.shape:
- self.register_embedding(embedding, shared.sd_model)
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
else:
- self.skipped_embeddings[name] = embedding
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ elif ext in ['.BIN', '.PT']:
+ data = torch.load(path, map_location="cpu")
+ else:
+ return
- for root, dirs, fns in os.walk(self.embeddings_dir):
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
+
+ def load_from_dir(self, embdir):
+ if not os.path.isdir(embdir.path):
+ return
+
+ for root, dirs, fns in os.walk(embdir.path):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -159,12 +195,32 @@ class EmbeddingDatabase:
if os.stat(fullfn).st_size == 0:
continue
- process_file(fullfn, fn)
+ self.load_from_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
+ def load_textual_inversion_embeddings(self, force_reload=False):
+ if not force_reload:
+ need_reload = False
+ for path, embdir in self.embedding_dirs.items():
+ if embdir.has_changed():
+ need_reload = True
+ break
+
+ if not need_reload:
+ return
+
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
+
+ for path, embdir in self.embedding_dirs.items():
+ self.load_from_dir(embdir)
+ embdir.update()
+
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
@@ -233,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
@@ -243,22 +299,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
- assert template_file, "Prompt template file is empty"
- assert os.path.isfile(template_file), "Prompt template file doesn't exist"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
- assert steps > 0 , "Max steps must be positive"
+ assert steps > 0, "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer"
- assert save_model_every >= 0 , "Save {name} must be positive or 0"
+ assert save_model_every >= 0, "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer"
- assert create_image_every >= 0 , "Create image must be positive or 0"
+ assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
- validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = template_file.path
shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
@@ -309,7 +369,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
diff --git a/modules/ui.py b/modules/ui.py
index 99483130..b6079aec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -37,7 +37,7 @@ from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
-import modules.textual_inversion.ui
+from modules.textual_inversion import textual_inversion
import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
@@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
with devices.autocast():
p.init([""], [0], [0])
- return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
+ return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def apply_styles(prompt, prompt_neg, style1_name, style2_name):
@@ -745,15 +745,20 @@ def create_ui():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- hr_resolution_preview_args = dict(
- fn=calc_resolution_hires,
- inputs=hr_resolution_preview_inputs,
- outputs=[hr_final_resolution],
- show_progress=False
- )
-
for input in hr_resolution_preview_inputs:
- input.change(**hr_resolution_preview_args)
+ input.change(
+ fn=calc_resolution_hires,
+ inputs=hr_resolution_preview_inputs,
+ outputs=[hr_final_resolution],
+ show_progress=False,
+ )
+ input.change(
+ None,
+ _js="onCalcResolutionHires",
+ inputs=hr_resolution_preview_inputs,
+ outputs=[],
+ show_progress=False,
+ )
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -1317,6 +1322,9 @@ def create_ui():
outputs=[process_focal_crop_row],
)
+ def get_textual_inversion_template_names():
+ return sorted([x for x in textual_inversion.textual_inversion_templates])
+
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with FormRow():
@@ -1340,9 +1348,14 @@ def create_ui():
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory")
- template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file")
+
+ with FormRow():
+ template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
+ create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
+
training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
with FormRow():
@@ -1449,6 +1462,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,
@@ -1480,6 +1494,7 @@ def create_ui():
log_directory,
training_width,
training_height,
+ varsize,
steps,
clip_grad_mode,
clip_grad_value,
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 9b8ffd85..332d76d9 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -25,6 +25,8 @@ class Script(scripts.Script):
return [info, overlap, upscaler_index, scale_factor]
def run(self, p, _, overlap, upscaler_index, scale_factor):
+ if isinstance(upscaler_index, str):
+ upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
processing.fix_seed(p)
upscaler = shared.sd_upscalers[upscaler_index]
diff --git a/style.css b/style.css
index 76721756..ec5e4182 100644
--- a/style.css
+++ b/style.css
@@ -512,7 +512,7 @@ input[type="range"]{
border: none;
background: none;
flex: unset;
- gap: 0.5em;
+ gap: 1em;
}
#quicksettings > div > div{
@@ -521,6 +521,17 @@ input[type="range"]{
padding: 0;
}
+#quicksettings > div > div > div > div > label > span {
+ position: relative;
+ margin-right: 9em;
+ margin-bottom: -1em;
+}
+
+#quicksettings > div > div > label > span {
+ position: relative;
+ margin-bottom: -1em;
+}
+
canvas[key="mask"] {
z-index: 12 !important;
filter: invert();
@@ -659,6 +670,10 @@ footer {
min-width: auto;
}
+.inactive{
+ opacity: 0.5;
+}
+
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
diff --git a/test/advanced_features/__init__.py b/test/advanced_features/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/test/advanced_features/__init__.py
+++ /dev/null
diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py
deleted file mode 100644
index 8763f8ed..00000000
--- a/test/advanced_features/extras_test.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import unittest
-
-
-class TestExtrasWorking(unittest.TestCase):
- def setUp(self):
- self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image"
- self.simple_extras = {
- "resize_mode": 0,
- "show_extras_results": True,
- "gfpgan_visibility": 0,
- "codeformer_visibility": 0,
- "codeformer_weight": 0,
- "upscaling_resize": 2,
- "upscaling_resize_w": 128,
- "upscaling_resize_h": 128,
- "upscaling_crop": True,
- "upscaler_1": "None",
- "upscaler_2": "None",
- "extras_upscaler_2_visibility": 0,
- "image": ""
- }
-
-
-class TestExtrasCorrectness(unittest.TestCase):
- pass
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py
deleted file mode 100644
index 36ed7b9a..00000000
--- a/test/advanced_features/txt2img_test.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import unittest
-import requests
-
-
-class TestTxt2ImgWorking(unittest.TestCase):
- def setUp(self):
- self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
- self.simple_txt2img = {
- "enable_hr": False,
- "denoising_strength": 0,
- "firstphase_width": 0,
- "firstphase_height": 0,
- "prompt": "example prompt",
- "styles": [],
- "seed": -1,
- "subseed": -1,
- "subseed_strength": 0,
- "seed_resize_from_h": -1,
- "seed_resize_from_w": -1,
- "batch_size": 1,
- "n_iter": 1,
- "steps": 3,
- "cfg_scale": 7,
- "width": 64,
- "height": 64,
- "restore_faces": False,
- "tiling": False,
- "negative_prompt": "",
- "eta": 0,
- "s_churn": 0,
- "s_tmax": 0,
- "s_tmin": 0,
- "s_noise": 1,
- "sampler_index": "Euler a"
- }
-
- def test_txt2img_with_restore_faces_performed(self):
- self.simple_txt2img["restore_faces"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
-
-class TestTxt2ImgCorrectness(unittest.TestCase):
- pass
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py
new file mode 100644
index 00000000..0170c511
--- /dev/null
+++ b/test/basic_features/extras_test.py
@@ -0,0 +1,54 @@
+import unittest
+import requests
+from gradio.processing_utils import encode_pil_to_base64
+from PIL import Image
+
+class TestExtrasWorking(unittest.TestCase):
+ def setUp(self):
+ self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image"
+ self.extras_single = {
+ "resize_mode": 0,
+ "show_extras_results": True,
+ "gfpgan_visibility": 0,
+ "codeformer_visibility": 0,
+ "codeformer_weight": 0,
+ "upscaling_resize": 2,
+ "upscaling_resize_w": 128,
+ "upscaling_resize_h": 128,
+ "upscaling_crop": True,
+ "upscaler_1": "None",
+ "upscaler_2": "None",
+ "extras_upscaler_2_visibility": 0,
+ "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+ }
+
+ def test_simple_upscaling_performed(self):
+ self.extras_single["upscaler_1"] = "Lanczos"
+ self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200)
+
+
+class TestPngInfoWorking(unittest.TestCase):
+ def setUp(self):
+ self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
+ self.png_info = {
+ "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))
+ }
+
+ def test_png_info_performed(self):
+ self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200)
+
+
+class TestInterrogateWorking(unittest.TestCase):
+ def setUp(self):
+ self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
+ self.interrogate = {
+ "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")),
+ "model": "clip"
+ }
+
+ def test_interrogate_performed(self):
+ self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
index 0a9c1e8a..08c5c903 100644
--- a/test/basic_features/img2img_test.py
+++ b/test/basic_features/img2img_test.py
@@ -16,7 +16,7 @@ class TestImg2ImgWorking(unittest.TestCase):
"inpainting_fill": 0,
"inpaint_full_res": False,
"inpaint_full_res_padding": 0,
- "inpainting_mask_invert": 0,
+ "inpainting_mask_invert": False,
"prompt": "example prompt",
"styles": [],
"seed": -1,
@@ -50,6 +50,17 @@ class TestImg2ImgWorking(unittest.TestCase):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+ def test_inpainting_with_inverted_masked_performed(self):
+ self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
+ self.simple_img2img["inpainting_mask_invert"] = True
+ self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+
+ def test_img2img_sd_upscale_performed(self):
+ self.simple_img2img["script_name"] = "sd upscale"
+ self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]
+
+ self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py
index 1c2674b2..5b27a7ec 100644
--- a/test/basic_features/txt2img_test.py
+++ b/test/basic_features/txt2img_test.py
@@ -41,6 +41,9 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.simple_txt2img["negative_prompt"] = "example negative prompt"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+ def test_txt2img_with_complex_prompt_performed(self):
+ self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
+
def test_txt2img_not_square_image_performed(self):
self.simple_txt2img["height"] = 128
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
@@ -53,6 +56,10 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.simple_txt2img["tiling"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+ def test_txt2img_with_restore_faces_performed(self):
+ self.simple_txt2img["restore_faces"] = True
+ self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
def test_txt2img_with_vanilla_sampler_performed(self):
self.simple_txt2img["sampler_index"] = "PLMS"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
@@ -63,6 +70,10 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.simple_txt2img["n_iter"] = 2
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+ def test_txt2img_batch_performed(self):
+ self.simple_txt2img["batch_size"] = 2
+ self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py
index 765470c9..94e00253 100644
--- a/test/basic_features/utils_test.py
+++ b/test/basic_features/utils_test.py
@@ -14,10 +14,25 @@ class UtilsTests(unittest.TestCase):
self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories"
self.url_artists = "http://localhost:7860/sdapi/v1/artists"
+ self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
def test_options_get(self):
self.assertEqual(requests.get(self.url_options).status_code, 200)
+ def test_options_write(self):
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+
+ pre_value = response.json()["send_seed"]
+
+ self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
+
+ response = requests.get(self.url_options)
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.json()["send_seed"], not pre_value)
+
+ requests.post(self.url_options, json={"send_seed": pre_value})
+
def test_cmd_flags(self):
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
@@ -48,6 +63,8 @@ class UtilsTests(unittest.TestCase):
def test_artists(self):
self.assertEqual(requests.get(self.url_artists).status_code, 200)
+ def test_embeddings(self):
+ self.assertEqual(requests.get(self.url_artists).status_code, 200)
if __name__ == "__main__":
unittest.main()
diff --git a/test/server_poll.py b/test/server_poll.py
index d4df697b..42d56a4c 100644
--- a/test/server_poll.py
+++ b/test/server_poll.py
@@ -15,7 +15,7 @@ def run_tests(proc, test_dir):
break
if proc.poll() is None:
if test_dir is None:
- test_dir = ""
+ test_dir = "test"
suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
result = unittest.TextTestRunner(verbosity=2).run(suite)
return len(result.failures) + len(result.errors)
diff --git a/webui.py b/webui.py
index 8737e593..47d372c7 100644
--- a/webui.py
+++ b/webui.py
@@ -33,6 +33,7 @@ import modules.sd_models
import modules.sd_vae
import modules.txt2img
import modules.script_callbacks
+import modules.textual_inversion.textual_inversion
import modules.ui
from modules import modelloader
@@ -67,6 +68,8 @@ def initialize():
modules.sd_vae.refresh_vae_list()
+ modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
+
try:
modules.sd_models.load_model()
except Exception as e: