aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.yml29
-rw-r--r--README.md5
-rw-r--r--configs/instruct-pix2pix.yaml3
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py8
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py3
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py3
-rw-r--r--html/extra-networks-card.html1
-rw-r--r--javascript/extensions.js20
-rw-r--r--javascript/extraNetworks.js42
-rw-r--r--javascript/hints.js10
-rw-r--r--javascript/ui.js43
-rw-r--r--launch.py42
-rw-r--r--modules/api/api.py2
-rw-r--r--modules/devices.py82
-rw-r--r--modules/esrgan_model_arch.py1
-rw-r--r--modules/extensions.py6
-rw-r--r--modules/extra_networks_hypernet.py8
-rw-r--r--modules/extras.py17
-rw-r--r--modules/generation_parameters_copypaste.py217
-rw-r--r--modules/hashes.py4
-rw-r--r--modules/hypernetworks/hypernetwork.py18
-rw-r--r--modules/images.py43
-rw-r--r--modules/img2img.py15
-rw-r--r--modules/mac_specific.py53
-rw-r--r--modules/modelloader.py3
-rw-r--r--modules/paths.py2
-rw-r--r--modules/processing.py49
-rw-r--r--modules/realesrgan_model.py2
-rw-r--r--modules/script_callbacks.py29
-rw-r--r--modules/scripts.py14
-rw-r--r--modules/sd_disable_initialization.py17
-rw-r--r--modules/sd_hijack.py58
-rw-r--r--modules/sd_hijack_inpainting.py1
-rw-r--r--modules/sd_hijack_unet.py19
-rw-r--r--modules/sd_models.py23
-rw-r--r--modules/sd_models_config.py51
-rw-r--r--modules/sd_samplers.py519
-rw-r--r--modules/sd_samplers_common.py62
-rw-r--r--modules/sd_samplers_compvis.py160
-rw-r--r--modules/sd_samplers_kdiffusion.py357
-rw-r--r--modules/shared.py55
-rw-r--r--modules/shared_items.py2
-rw-r--r--modules/textual_inversion/dataset.py58
-rw-r--r--modules/textual_inversion/textual_inversion.py22
-rw-r--r--modules/txt2img.py6
-rw-r--r--modules/ui.py83
-rw-r--r--modules/ui_common.py6
-rw-r--r--modules/ui_extensions.py36
-rw-r--r--modules/ui_extra_networks.py82
-rw-r--r--modules/ui_extra_networks_checkpoints.py39
-rw-r--r--modules/ui_extra_networks_hypernets.py3
-rw-r--r--modules/ui_extra_networks_textual_inversion.py3
-rw-r--r--requirements_versions.txt1
-rw-r--r--scripts/img2imgalt.py6
-rw-r--r--scripts/prompt_matrix.py47
-rw-r--r--scripts/xyz_grid.py63
-rw-r--r--style.css46
-rw-r--r--webui-macos-env.sh2
-rw-r--r--webui.py18
59 files changed, 1740 insertions, 879 deletions
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index ed372f22..7d435297 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -37,20 +37,20 @@ body:
id: what-should
attributes:
label: What should have happened?
- description: tell what you think the normal behavior should be
+ description: Tell what you think the normal behavior should be
validations:
required: true
- type: input
id: commit
attributes:
label: Commit where the problem happens
- description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+ description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
validations:
required: true
- type: dropdown
id: platforms
attributes:
- label: What platforms do you use to access UI ?
+ label: What platforms do you use to access the UI ?
multiple: true
options:
- Windows
@@ -74,10 +74,27 @@ body:
id: cmdargs
attributes:
label: Command Line Arguments
- description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below
+ description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: extensions
+ attributes:
+ label: List of extensions
+ description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
+ validations:
+ required: true
+ - type: textarea
+ id: logs
+ attributes:
+ label: Console logs
+ description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+ render: Shell
+ validations:
+ required: true
- type: textarea
id: misc
attributes:
- label: Additional information, context and logs
- description: Please provide us with any relevant additional info, context or log output.
+ label: Additional information
+ description: Please provide us with any relevant additional info or context.
diff --git a/README.md b/README.md
index 2149dcc5..2ceb4d2d 100644
--- a/README.md
+++ b/README.md
@@ -104,8 +104,7 @@ Alternatively, use online services (like Google Colab):
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
-4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
-5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
+4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
### Automatic Installation on Linux
1. Install the dependencies:
@@ -121,7 +120,7 @@ sudo pacman -S wget git python3
```bash
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
```
-
+3. Run `webui.sh`.
### Installation on Apple Silicon
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml
index 437ddcef..4e896879 100644
--- a/configs/instruct-pix2pix.yaml
+++ b/configs/instruct-pix2pix.yaml
@@ -20,8 +20,7 @@ model:
conditioning_key: hybrid
monitor: val/loss_simple_ema
scale_factor: 0.18215
- use_ema: true
- load_ema: true
+ use_ema: false
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 8f2e753e..6be6ef73 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,4 +1,4 @@
-from modules import extra_networks
+from modules import extra_networks, shared
import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -6,6 +6,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
super().__init__('lora')
def activate(self, p, params_list):
+ additional = shared.opts.sd_lora
+
+ if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
names = []
multipliers = []
for params in params_list:
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 544b228d..2e860160 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,4 +1,5 @@
import torch
+import gradio as gr
import lora
import extra_networks_lora
@@ -31,5 +32,7 @@ script_callbacks.on_before_ui(before_ui)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
+
}))
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 54a80d36..22cabcb0 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -20,13 +20,14 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
preview = None
for file in previews:
if os.path.isfile(file):
- preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ preview = self.link_preview(file)
break
yield {
"name": name,
"filename": path,
"preview": preview,
+ "search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index aa9fca87..8a5e2fbd 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -4,6 +4,7 @@
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
+ <span style="display:none" class='search_term'>{search_term}</span>
</div>
<span class='name'>{name}</span>
</div>
diff --git a/javascript/extensions.js b/javascript/extensions.js
index ac6e35b9..c593cd2e 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -1,7 +1,8 @@
function extensions_apply(_, _){
- disable = []
- update = []
+ var disable = []
+ var update = []
+
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
@@ -16,11 +17,24 @@ function extensions_apply(_, _){
}
function extensions_check(){
+ var disable = []
+
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
+ if(x.name.startsWith("enable_") && ! x.checked)
+ disable.push(x.name.substr(7))
+ })
+
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
- return []
+
+ var id = randomId()
+ requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
+
+ })
+
+ return [id, JSON.stringify(disable)]
}
function install_extension_from_index(button, url){
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index c5a9adb3..17bf2000 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -16,7 +16,7 @@ function setupExtraNetworksForTab(tabname){
searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
- text = elem.querySelector('.name').textContent.toLowerCase()
+ text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
})
});
@@ -48,10 +48,39 @@ function setupExtraNetworks(){
onUiLoaded(setupExtraNetworks)
+var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
+var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
+
+function tryToRemoveExtraNetworkFromPrompt(textarea, text){
+ var m = text.match(re_extranet)
+ if(! m) return false
+
+ var partToSearch = m[1]
+ var replaced = false
+ var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
+ m = found.match(re_extranet);
+ if(m[1] == partToSearch){
+ replaced = true;
+ return ""
+ }
+ return found;
+ })
+
+ if(replaced){
+ textarea.value = newTextareaText
+ return true;
+ }
+
+ return false
+}
+
function cardClicked(tabname, textToAdd, allowNegativePrompt){
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
- textarea.value = textarea.value + " " + textToAdd
+ if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
+ textarea.value = textarea.value + " " + textToAdd
+ }
+
updateInput(textarea)
}
@@ -67,3 +96,12 @@ function saveCardPreview(event, tabname, filename){
event.stopPropagation()
event.preventDefault()
}
+
+function extraNetworksSearchButton(tabs_id, event){
+ searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
+ button = event.target
+ text = button.classList.contains("search-all") ? "" : button.textContent.trim()
+
+ searchTextarea.value = text
+ updateInput(searchTextarea)
+} \ No newline at end of file
diff --git a/javascript/hints.js b/javascript/hints.js
index 7b60b25e..f1199009 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -8,8 +8,8 @@ titles = {
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
- "Batch count": "How many batches of images to create",
- "Batch size": "How many image to create in a single batch",
+ "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
+ "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
@@ -17,7 +17,7 @@ titles = {
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style",
- "\U0001F5D1": "Clear prompt",
+ "\u{1f5d1}": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field",
"\u{1f3b4}": "Show extra networks",
@@ -66,8 +66,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
- "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
- "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+ "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
+ "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Process an image, use it as an input, repeat.",
diff --git a/javascript/ui.js b/javascript/ui.js
index ba72623c..b7a8268a 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -191,6 +191,28 @@ function confirm_clear_prompt(prompt, negative_prompt) {
return [prompt, negative_prompt]
}
+
+promptTokecountUpdateFuncs = {}
+
+function recalculatePromptTokens(name){
+ if(promptTokecountUpdateFuncs[name]){
+ promptTokecountUpdateFuncs[name]()
+ }
+}
+
+function recalculate_prompts_txt2img(){
+ recalculatePromptTokens('txt2img_prompt')
+ recalculatePromptTokens('txt2img_neg_prompt')
+ return args_to_array(arguments);
+}
+
+function recalculate_prompts_img2img(){
+ recalculatePromptTokens('img2img_prompt')
+ recalculatePromptTokens('img2img_neg_prompt')
+ return args_to_array(arguments);
+}
+
+
opts = {}
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
@@ -232,14 +254,12 @@ onUiUpdate(function(){
return
}
-
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"
- textarea.addEventListener("input", function(){
- update_token_counter(id_button);
- });
+ promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
+ textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
}
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
@@ -273,7 +293,7 @@ onOptionsChanged(function(){
let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
-let token_timeout;
+let token_timeouts = {};
function update_txt2img_tokens(...args) {
update_token_counter("txt2img_token_button")
@@ -290,9 +310,9 @@ function update_img2img_tokens(...args) {
}
function update_token_counter(button_id) {
- if (token_timeout)
- clearTimeout(token_timeout);
- token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+ if (token_timeouts[button_id])
+ clearTimeout(token_timeouts[button_id]);
+ token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
function restart_reload(){
@@ -309,3 +329,10 @@ function updateInput(target){
Object.defineProperty(e, "target", {value: target})
target.dispatchEvent(e);
}
+
+
+var desiredCheckpointName = null;
+function selectCheckpoint(name){
+ desiredCheckpointName = name;
+ gradioApp().getElementById('change_checkpoint').click()
+}
diff --git a/launch.py b/launch.py
index 9d6f4a8c..a68bb3a9 100644
--- a/launch.py
+++ b/launch.py
@@ -17,6 +17,37 @@ stored_commit_hash = None
skip_install = False
+def check_python_version():
+ is_windows = platform.system() == "Windows"
+ major = sys.version_info.major
+ minor = sys.version_info.minor
+ micro = sys.version_info.micro
+
+ if is_windows:
+ supported_minors = [10]
+ else:
+ supported_minors = [7, 8, 9, 10, 11]
+
+ if not (major == 3 and minor in supported_minors):
+ import modules.errors
+
+ modules.errors.print_error_explanation(f"""
+INCOMPATIBLE PYTHON VERSION
+
+This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
+If you encounter an error with "RuntimeError: Couldn't install torch." message,
+or any other error regarding unsuccessful package (library) installation,
+please downgrade (or upgrade) to the latest version of 3.10 Python
+and delete current Python and "venv" folder in WebUI's directory.
+
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
+
+{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
+
+Use --skip-python-version-check to suppress this warning.
+""")
+
+
def commit_hash():
global stored_commit_hash
@@ -192,6 +223,7 @@ def prepare_environment():
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
@@ -210,12 +242,13 @@ def prepare_environment():
sys.argv += shlex.split(commandline_args)
- parser = argparse.ArgumentParser()
+ parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
args, _ = parser.parse_known_args(sys.argv)
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
+ sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
@@ -224,6 +257,9 @@ def prepare_environment():
xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv
+ if not skip_python_version_check:
+ check_python_version()
+
commit = commit_hash()
print(f"Python {sys.version}")
@@ -247,14 +283,14 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
- run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux":
- run_pip("install xformers==0.0.16rc425", "xformers")
+ run_pip(f"install {xformers_package}", "xformers")
if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
diff --git a/modules/api/api.py b/modules/api/api.py
index eb7b1da5..5a9ac5f1 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -498,7 +498,7 @@ class Api:
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
- hypernetwork, filename = train_hypernetwork(*args)
+ hypernetwork, filename = train_hypernetwork(**args)
except Exception as e:
error = e
finally:
diff --git a/modules/devices.py b/modules/devices.py
index 4687944e..52c3e7cd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,21 +1,17 @@
-import sys, os, shlex
+import sys
import contextlib
import torch
from modules import errors
-from packaging import version
+
+if sys.platform == "darwin":
+ from modules import mac_specific
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
def has_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
+ if sys.platform != "darwin":
return False
-
+ else:
+ return mac_specific.has_mps
def extract_device_id(args, name):
for x in range(len(args)):
@@ -87,6 +83,14 @@ dtype_unet = torch.float16
unet_needs_upcast = False
+def cond_cast_unet(input):
+ return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+ return input.float() if unet_needs_upcast else input
+
+
def randn(seed, shape):
torch.manual_seed(seed)
if device.type == 'mps':
@@ -146,59 +150,3 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-orig_tensor_to = torch.Tensor.to
-def tensor_to_fix(self, *args, **kwargs):
- if self.device.type != 'mps' and \
- ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
- (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
- self = self.contiguous()
- return orig_tensor_to(self, *args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-orig_layer_norm = torch.nn.functional.layer_norm
-def layer_norm_fix(*args, **kwargs):
- if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
- args = list(args)
- args[0] = args[0].contiguous()
- return orig_layer_norm(*args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-orig_tensor_numpy = torch.Tensor.numpy
-def numpy_fix(self, *args, **kwargs):
- if self.requires_grad:
- self = self.detach()
- return orig_tensor_numpy(self, *args, **kwargs)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-orig_cumsum = torch.cumsum
-orig_Tensor_cumsum = torch.Tensor.cumsum
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
- if input.device.type == 'mps':
- output_dtype = kwargs.get('dtype', input.dtype)
- if output_dtype == torch.int64:
- return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
- return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
- return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
- if version.parse(torch.__version__) < version.parse("1.13"):
- # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
- torch.Tensor.to = tensor_to_fix
- torch.nn.functional.layer_norm = layer_norm_fix
- torch.Tensor.numpy = numpy_fix
- elif version.parse(torch.__version__) > version.parse("1.13.1"):
- cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
- torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
- torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
- orig_narrow = torch.narrow
- torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
-
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index bc9ceb2a..1b52b0f5 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -1,5 +1,6 @@
# this file is adapted from https://github.com/victorca25/iNNfer
+from collections import OrderedDict
import math
import functools
import torch
diff --git a/modules/extensions.py b/modules/extensions.py
index 5e12b1aa..3eef9eaf 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -2,6 +2,7 @@ import os
import sys
import traceback
+import time
import git
from modules import paths, shared
@@ -25,6 +26,7 @@ class Extension:
self.status = ''
self.can_update = False
self.is_builtin = is_builtin
+ self.version = ''
repo = None
try:
@@ -40,6 +42,10 @@ class Extension:
try:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
+ head = repo.head.commit
+ ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
+ self.version = f'{head.hexsha[:8]} ({ts})'
+
except Exception:
self.remote = None
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index ff279a1f..d3a4d7ad 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -1,4 +1,4 @@
-from modules import extra_networks
+from modules import extra_networks, shared, extra_networks
from modules.hypernetworks import hypernetwork
@@ -7,6 +7,12 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
super().__init__('hypernet')
def activate(self, p, params_list):
+ additional = shared.opts.sd_hypernetwork
+
+ if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+ params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
+
names = []
multipliers = []
for params in params_list:
diff --git a/modules/extras.py b/modules/extras.py
index 4f842be9..d8ece955 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
+ result_is_instruct_pix2pix_model = False
if theta_func2:
shared.state.textinfo = f"Loading B"
@@ -185,14 +186,19 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
+ if a.shape[1] == 4 and b.shape[1] == 8:
+ raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")
- assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
-
- theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
- result_is_inpainting_model = True
+ if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch.
+ result_is_instruct_pix2pix_model = True
+ else:
+ assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
+ theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
+ result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)
-
+
theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1
@@ -226,6 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
filename = filename_generator() if custom_name == '' else custom_name
filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index 773c5c0e..89dc23bf 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,4 +1,5 @@
import base64
+import html
import io
import math
import os
@@ -11,19 +12,28 @@ from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image
-re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
+re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
-re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
+
paste_fields = {}
-bind_list = []
+registered_param_bindings = []
+
+
+class ParamBinding:
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
+ self.paste_button = paste_button
+ self.tabname = tabname
+ self.source_text_component = source_text_component
+ self.source_image_component = source_image_component
+ self.source_tabname = source_tabname
+ self.override_settings_component = override_settings_component
def reset():
paste_fields.clear()
- bind_list.clear()
def quote(text):
@@ -64,8 +74,8 @@ def image_from_url_text(filedata):
return image
-def add_paste_fields(tabname, init_img, fields):
- paste_fields[tabname] = {"init_img": init_img, "fields": fields}
+def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+ paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
import modules.ui
@@ -75,26 +85,6 @@ def add_paste_fields(tabname, init_img, fields):
modules.ui.img2img_paste_fields = fields
-def integrate_settings_paste_fields(component_dict):
- from modules import ui
-
- settings_map = {
- 'CLIP_stop_at_last_layers': 'Clip skip',
- 'inpainting_mask_weight': 'Conditional mask weight',
- 'sd_model_checkpoint': 'Model hash',
- 'eta_noise_seed_delta': 'ENSD',
- 'initial_noise_multiplier': 'Noise multiplier',
- }
- settings_paste_fields = [
- (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
- for k, v in settings_map.items()
- ]
-
- for tabname, info in paste_fields.items():
- if info["fields"] is not None:
- info["fields"] += settings_paste_fields
-
-
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
@@ -102,9 +92,61 @@ def create_buttons(tabs_list):
return buttons
-#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info):
- bind_list.append([buttons, send_image, send_generate_info])
+ """old function for backwards compatibility; do not use this, use register_paste_params_button"""
+ for tabname, button in buttons.items():
+ source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
+ source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
+
+ register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
+
+
+def register_paste_params_button(binding: ParamBinding):
+ registered_param_bindings.append(binding)
+
+
+def connect_paste_params_buttons():
+ binding: ParamBinding
+ for binding in registered_param_bindings:
+ destination_image_component = paste_fields[binding.tabname]["init_img"]
+ fields = paste_fields[binding.tabname]["fields"]
+ override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
+
+ destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
+ destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
+
+ if binding.source_image_component and destination_image_component:
+ if isinstance(binding.source_image_component, gr.Gallery):
+ func = send_image_and_dimensions if destination_width_component else image_from_url_text
+ jsfunc = "extract_image_from_gallery"
+ else:
+ func = send_image_and_dimensions if destination_width_component else lambda x: x
+ jsfunc = None
+
+ binding.paste_button.click(
+ fn=func,
+ _js=jsfunc,
+ inputs=[binding.source_image_component],
+ outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+ )
+
+ if binding.source_text_component is not None and fields is not None:
+ connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
+
+ if binding.source_tabname is not None and fields is not None:
+ paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
+ binding.paste_button.click(
+ fn=lambda *x: x,
+ inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
+ outputs=[field for field, name in fields if name in paste_field_names],
+ )
+
+ binding.paste_button.click(
+ fn=None,
+ _js=f"switch_to_{binding.tabname}",
+ inputs=None,
+ outputs=None,
+ )
def send_image_and_dimensions(x):
@@ -123,49 +165,6 @@ def send_image_and_dimensions(x):
return img, w, h
-def run_bind():
- for buttons, source_image_component, send_generate_info in bind_list:
- for tab in buttons:
- button = buttons[tab]
- destination_image_component = paste_fields[tab]["init_img"]
- fields = paste_fields[tab]["fields"]
-
- destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
- destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
-
- if source_image_component and destination_image_component:
- if isinstance(source_image_component, gr.Gallery):
- func = send_image_and_dimensions if destination_width_component else image_from_url_text
- jsfunc = "extract_image_from_gallery"
- else:
- func = send_image_and_dimensions if destination_width_component else lambda x: x
- jsfunc = None
-
- button.click(
- fn=func,
- _js=jsfunc,
- inputs=[source_image_component],
- outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
- )
-
- if send_generate_info and fields is not None:
- if send_generate_info in paste_fields:
- paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
- button.click(
- fn=lambda *x: x,
- inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
- outputs=[field for field, name in fields if name in paste_field_names],
- )
- else:
- connect_paste(button, fields, send_generate_info)
-
- button.click(
- fn=None,
- _js=f"switch_to_{tab}",
- inputs=None,
- outputs=None,
- )
-
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
@@ -243,7 +242,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
done_with_prompt = False
*lines, lastline = x.strip().split("\n")
- if not re_params.match(lastline):
+ if len(re_param.findall(lastline)) < 3:
lines.append(lastline)
lastline = ''
@@ -262,6 +261,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
+ v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k+"-1"] = m.group(1)
@@ -286,7 +286,50 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
-def connect_paste(button, paste_fields, input_comp, jsfunc=None):
+settings_map = {}
+
+infotext_to_setting_name_mapping = [
+ ('Clip skip', 'CLIP_stop_at_last_layers', ),
+ ('Conditional mask weight', 'inpainting_mask_weight'),
+ ('Model hash', 'sd_model_checkpoint'),
+ ('ENSD', 'eta_noise_seed_delta'),
+ ('Noise multiplier', 'initial_noise_multiplier'),
+ ('Eta', 'eta_ancestral'),
+ ('Eta DDIM', 'eta_ddim'),
+ ('Discard penultimate sigma', 'always_discard_next_to_last_sigma')
+]
+
+
+def create_override_settings_dict(text_pairs):
+ """creates processing's override_settings parameters from gradio's multiselect
+
+ Example input:
+ ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
+
+ Example output:
+ {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
+ """
+
+ res = {}
+
+ params = {}
+ for pair in text_pairs:
+ k, v = pair.split(":", maxsplit=1)
+
+ params[k] = v.strip()
+
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ value = params.get(param_name, None)
+
+ if value is None:
+ continue
+
+ res[setting_name] = shared.opts.cast_value(setting_name, value)
+
+ return res
+
+
+def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(data_path, "params.txt")
@@ -323,9 +366,35 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None):
return res
+ if override_settings_component is not None:
+ def paste_settings(params):
+ vals = {}
+
+ for param_name, setting_name in infotext_to_setting_name_mapping:
+ v = params.get(param_name, None)
+ if v is None:
+ continue
+
+ if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+ continue
+
+ v = shared.opts.cast_value(setting_name, v)
+ current_value = getattr(shared.opts, setting_name, None)
+
+ if v == current_value:
+ continue
+
+ vals[param_name] = v
+
+ vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+
+ paste_fields = paste_fields + [(override_settings_component, paste_settings)]
+
button.click(
fn=paste_func,
- _js=jsfunc,
+ _js=f"recalculate_prompts_{tabname}",
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
diff --git a/modules/hashes.py b/modules/hashes.py
index 819362a3..83272a07 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -4,6 +4,7 @@ import os.path
import filelock
+from modules import shared
from modules.paths import data_path
@@ -68,6 +69,9 @@ def sha256(filename, title):
if sha256_value is not None:
return sha256_value
+ if shared.cmd_opts.no_hashing:
+ return None
+
print(f"Calculating sha256 for {filename}: ", end='')
sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 503534e2..f6ef42d5 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -307,7 +307,7 @@ class Hypernetwork:
def shorthash(self):
sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')
- return sha256[0:10]
+ return sha256[0:10] if sha256 else None
def list_hypernetworks(path):
@@ -380,8 +380,8 @@ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context_k)
- context_v = hypernetwork_layers[1](context_v)
+ context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))
+ context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))
return context_k, context_v
@@ -496,7 +496,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
shared.reload_hypernetworks()
-def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -554,7 +554,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
if shared.opts.save_training_settings_to_txt:
saved_params = dict(
@@ -640,13 +640,19 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
- loss = shared.sd_model(x, c)[0] / gradient_step
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, c)[0] / gradient_step
del x
del c
diff --git a/modules/images.py b/modules/images.py
index 0bc3d524..38404de3 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -16,8 +16,9 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
import json
+import hashlib
-from modules import sd_samplers, shared, script_callbacks
+from modules import sd_samplers, shared, script_callbacks, errors
from modules.shared import opts, cmd_opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -36,6 +37,8 @@ def image_grid(imgs, batch_size=1, rows=None):
else:
rows = math.sqrt(len(imgs))
rows = round(rows)
+ if rows > len(imgs):
+ rows = len(imgs)
cols = math.ceil(len(imgs) / rows)
@@ -128,7 +131,7 @@ class GridAnnotation:
self.size = None
-def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
+def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -192,32 +195,35 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
line.allowed_width = allowed_width
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
- ver_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
- result.paste(im, (pad_left, pad_top))
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+
+ for row in range(rows):
+ for col in range(cols):
+ cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
+ result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))
d = ImageDraw.Draw(result)
for col in range(cols):
- x = pad_left + width * col + width / 2
+ x = pad_left + (width + margin) * col + width / 2
y = pad_top / 2 - hor_text_heights[col] / 2
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
for row in range(rows):
x = pad_left / 2
- y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
+ y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
return result
-def draw_prompt_matrix(im, width, height, all_prompts):
+def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
prompts = all_prompts[1:]
boundary = math.ceil(len(prompts) / 2)
@@ -227,7 +233,7 @@ def draw_prompt_matrix(im, width, height, all_prompts):
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
- return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)
def resize_image(resize_mode, im, width, height, upscaler_name=None):
@@ -338,6 +344,7 @@ class FilenameGenerator:
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
+ 'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
'prompt': lambda self: sanitize_filename_part(self.prompt),
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
@@ -546,6 +553,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
if image_to_save.mode == 'RGBA':
image_to_save = image_to_save.convert("RGB")
+ elif image_to_save.mode == 'I;16':
+ image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
@@ -568,17 +577,19 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image.already_saved_as = fullfn
- target_side_length = 4000
- oversize = image.width > target_side_length or image.height > target_side_length
- if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
+ oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
if oversize and ratio > 1:
- image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
+ image = image.resize((opts.target_side_length, image.height * opts.target_side_length // image.width), LANCZOS)
elif oversize:
- image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
+ image = image.resize((image.width * opts.target_side_length // image.height, opts.target_side_length), LANCZOS)
- _atomically_save_image(image, fullfn_without_extension, ".jpg")
+ try:
+ _atomically_save_image(image, fullfn_without_extension, ".jpg")
+ except Exception as e:
+ errors.display(e, "saving image as downscaled JPG")
if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt"
diff --git a/modules/img2img.py b/modules/img2img.py
index fe9447c7..c973b770 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -7,6 +7,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from modules import devices, sd_samplers
+from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
@@ -21,8 +22,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
images = shared.listfiles(input_dir)
- inpaint_masks = shared.listfiles(inpaint_mask_dir)
- is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0
+ is_inpaint_batch = False
+ if inpaint_mask_dir:
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
+ is_inpaint_batch = len(inpaint_masks) > 0
if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
@@ -70,10 +73,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
if not save_normally:
os.makedirs(output_dir, exist_ok=True)
+ if processed_image.mode == 'RGBA':
+ processed_image = processed_image.convert("RGB")
processed_image.save(os.path.join(output_dir, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+ override_settings = create_override_settings_dict(override_settings_texts)
+
is_batch = mode == 5
if mode == 0: # img2img
@@ -137,9 +144,11 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
+ image_cfg_scale=image_cfg_scale,
inpaint_full_res=inpaint_full_res,
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
+ override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
new file mode 100644
index 00000000..ddcea53b
--- /dev/null
+++ b/modules/mac_specific.py
@@ -0,0 +1,53 @@
+import torch
+from modules import paths
+from modules.sd_hijack_utils import CondFunc
+from packaging import version
+
+
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def check_for_mps() -> bool:
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+has_mps = check_for_mps()
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+ if input.device.type == 'mps':
+ output_dtype = kwargs.get('dtype', input.dtype)
+ if output_dtype == torch.int64:
+ return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+ elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
+ return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps:
+ # MPS fix for randn in torchsde
+ CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
+
+ if version.parse(torch.__version__) < version.parse("1.13"):
+ # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+ CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
+ lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
+ lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+ CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
+ elif version.parse(torch.__version__) > version.parse("1.13.1"):
+ cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
+ cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
+ cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
+ CondFunc('torch.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
+ CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
+
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e9aa514e..fc3f6249 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -45,6 +45,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
full_path = file
if os.path.isdir(full_path):
continue
+ if os.path.islink(full_path) and not os.path.exists(full_path):
+ print(f"Skipping broken symlink: {full_path}")
+ continue
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if len(ext_filter) != 0:
diff --git a/modules/paths.py b/modules/paths.py
index 08e6f9b9..d991cc71 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -6,7 +6,7 @@ import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
-parser = argparse.ArgumentParser()
+parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
diff --git a/modules/processing.py b/modules/processing.py
index 5072fc40..2009d3bf 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -173,8 +173,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
- conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -187,7 +186,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
return conditioning_image
@@ -218,7 +217,7 @@ class StableDiffusionProcessing:
)
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -229,16 +228,18 @@ class StableDiffusionProcessing:
return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+ source_image = devices.cond_cast_float(source_image)
+
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
# identify itself with a field common to all models. The conditioning_key is also hybrid.
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
- return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
+ return self.depth2img_image_conditioning(source_image)
if self.sd_model.cond_stage_key == "edit":
return self.edit_image_conditioning(source_image)
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
- return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
+ return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -267,6 +268,7 @@ class Processed:
self.height = p.height
self.sampler_name = p.sampler_name
self.cfg_scale = p.cfg_scale
+ self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.steps = p.steps
self.batch_size = p.batch_size
self.restore_faces = p.restore_faces
@@ -418,7 +420,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
+ x = model.decode_first_stage(x)
return x
@@ -444,19 +446,17 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Steps": p.steps,
"Sampler": p.sampler_name,
"CFG scale": p.cfg_scale,
+ "Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": all_seeds[index],
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Batch size": (None if p.batch_size < 2 else p.batch_size),
- "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
- "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
}
@@ -543,8 +543,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
- _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
-
if p.scripts is not None:
p.scripts.process(p)
@@ -582,13 +580,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
- if not p.disable_extra_networks:
- extra_networks.activate(p, extra_network_data)
-
- with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
- file.write(processed.infotext(p, 0))
-
if state.job_count == -1:
state.job_count = p.n_iter
@@ -609,11 +600,24 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
- prompts, _ = extra_networks.parse_prompts(prompts)
+ prompts, extra_network_data = extra_networks.parse_prompts(prompts)
+
+ if not p.disable_extra_networks:
+ with devices.autocast():
+ extra_networks.activate(p, extra_network_data)
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ # params.txt should be saved after scripts.process_batch, since the
+ # infotext could be modified by that callback
+ # Example: a wildcard processed by process_batch sets an extra model
+ # strength, which is saved as "Model Strength: 1.0" in the infotext
+ if n == 0:
+ with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
+ processed = Processed(p, [], p.seed, "")
+ file.write(processed.infotext(p, 0))
+
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
@@ -903,12 +907,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength
+ self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
self.init_latent = None
self.image_mask = mask
self.latent_mask = None
@@ -1007,7 +1012,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
+ image = image.to(shared.device)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index 47f70251..aad4a629 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
scale=info.scale,
model_path=info.local_data_path,
model=info.model(),
- half=not cmd_opts.no_half,
+ half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 4bb45ec7..edd0e2a7 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -46,6 +46,18 @@ class CFGDenoiserParams:
"""Total number of sampling steps planned"""
+class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+
class UiTrainTabParams:
def __init__(self, txt2img_preview_params):
self.txt2img_preview_params = txt2img_preview_params
@@ -68,6 +80,7 @@ callback_map = dict(
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
+ callbacks_cfg_denoised=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
report_exception(c, 'cfg_denoiser_callback')
+def cfg_denoised_callback(params: CFGDenoisedParams):
+ for c in callback_map['callbacks_cfg_denoised']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_denoised_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -283,6 +304,14 @@ def on_cfg_denoiser(callback):
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
+def on_cfg_denoised(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callback_map['callbacks_cfg_denoised'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
diff --git a/modules/scripts.py b/modules/scripts.py
index 6e9dc0c0..24056a12 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -345,6 +345,20 @@ class ScriptRunner:
outputs=[script.group for script in self.selectable_scripts]
)
+ self.script_load_ctr = 0
+ def onload_script_visibility(params):
+ title = params.get('Script', None)
+ if title:
+ title_index = self.titles.index(title)
+ visibility = title_index == self.script_load_ctr
+ self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)
+ return gr.update(visible=visibility)
+ else:
+ return gr.update(visible=False)
+
+ self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
+ self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
+
return inputs
def run(self, p, *args):
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index e90aa9fe..c4a09d15 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -20,8 +20,9 @@ class DisableInitialization:
```
"""
- def __init__(self):
+ def __init__(self, disable_clip=True):
self.replaced = []
+ self.disable_clip = disable_clip
def replace(self, obj, field, func):
original = getattr(obj, field, None)
@@ -75,12 +76,14 @@ class DisableInitialization:
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
+
+ if self.disable_clip:
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f9652d21..79476783 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,5 +1,6 @@
import torch
from torch.nn.functional import silu
+from types import MethodType
import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
@@ -76,6 +77,54 @@ def fix_checkpoint():
pass
+def weighted_loss(sd_model, pred, target, mean=True):
+ #Calculate the weight normally, but ignore the mean
+ loss = sd_model._old_get_loss(pred, target, mean=False)
+
+ #Check if we have weights available
+ weight = getattr(sd_model, '_custom_loss_weight', None)
+ if weight is not None:
+ loss *= weight
+
+ #Return the loss, as mean if specified
+ return loss.mean() if mean else loss
+
+def weighted_forward(sd_model, x, c, w, *args, **kwargs):
+ try:
+ #Temporarily append weights to a place accessible during loss calc
+ sd_model._custom_loss_weight = w
+
+ #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
+ #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
+ if not hasattr(sd_model, '_old_get_loss'):
+ sd_model._old_get_loss = sd_model.get_loss
+ sd_model.get_loss = MethodType(weighted_loss, sd_model)
+
+ #Run the standard forward function, but with the patched 'get_loss'
+ return sd_model.forward(x, c, *args, **kwargs)
+ finally:
+ try:
+ #Delete temporary weights if appended
+ del sd_model._custom_loss_weight
+ except AttributeError as e:
+ pass
+
+ #If we have an old loss function, reset the loss function to the original one
+ if hasattr(sd_model, '_old_get_loss'):
+ sd_model.get_loss = sd_model._old_get_loss
+ del sd_model._old_get_loss
+
+def apply_weighted_forward(sd_model):
+ #Add new function 'weighted_forward' that can be called to calc weighted loss
+ sd_model.weighted_forward = MethodType(weighted_forward, sd_model)
+
+def undo_weighted_forward(sd_model):
+ try:
+ del sd_model.weighted_forward
+ except AttributeError as e:
+ pass
+
+
class StableDiffusionModelHijack:
fixes = None
comments = []
@@ -104,6 +153,10 @@ class StableDiffusionModelHijack:
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+ apply_weighted_forward(m)
+ if m.cond_stage_key == "edit":
+ sd_hijack_unet.hijack_ddpm_edit()
+
self.optimization_method = apply_optimizations()
self.clip = m.cond_stage_model
@@ -131,6 +184,9 @@ class StableDiffusionModelHijack:
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
m.cond_stage_model = m.cond_stage_model.wrapped
+ undo_optimizations()
+ undo_weighted_forward(m)
+
self.apply_circular(False)
self.layers = None
self.clip = None
@@ -171,7 +227,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
- emb = embedding.vec
+ emb = devices.cond_cast_unet(embedding.vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 478cd499..55a2ce4d 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -11,6 +11,7 @@ import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
@torch.no_grad()
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index a6ee577c..843ab66c 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
class GELUHijack(torch.nn.GELU, torch.nn.Module):
def __init__(self, *args, **kwargs):
torch.nn.GELU.__init__(self, *args, **kwargs)
@@ -53,10 +54,26 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
else:
return torch.nn.GELU.forward(self, x)
+
+ddpm_edit_hijack = None
+def hijack_ddpm_edit():
+ global ddpm_edit_hijack
+ if not ddpm_edit_hijack:
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+ CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+ ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+
+
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
+
+first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
+first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b2d48a51..127e9663 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -41,6 +41,7 @@ class CheckpointInfo:
name = name[1:]
self.name = name
+ self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
@@ -58,13 +59,17 @@ class CheckpointInfo:
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+ if self.sha256 is None:
+ return
+
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
- self.ids += [self.shorthash, self.sha256]
- self.register()
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
+ checkpoints_list.pop(self.title)
self.title = f'{self.name} [{self.shorthash}]'
+ self.register()
return self.shorthash
@@ -100,7 +105,7 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
checkpoint_alisases.clear()
- model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
+ model_list = modelloader.load_models(model_path=model_path, model_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
@@ -157,7 +162,7 @@ def select_checkpoint():
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+ print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
checkpoint_info = next(iter(checkpoints_list.values()))
@@ -231,12 +236,10 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
- title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
- if checkpoint_info.title != title:
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -351,6 +354,9 @@ def repair_config(sd_config):
sd_config.model.params.unet_config.params.use_fp16 = True
+sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
+sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -371,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
timer.record("find config")
@@ -383,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
sd_model = None
try:
- with sd_disable_initialization.DisableInitialization():
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
pass
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 00217990..91c21700 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -1,7 +1,9 @@
import re
import os
-from modules import shared, paths
+import torch
+
+from modules import shared, paths, sd_disable_initialization
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
@@ -16,12 +18,51 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
-re_parametrization_v = re.compile(r'-v\b')
+def is_using_v_parameterization_for_sd2(state_dict):
+ """
+ Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome.
+ """
-def guess_model_config_from_state_dict(sd, filename):
- fn = os.path.basename(filename)
+ import ldm.modules.diffusionmodules.openaimodel
+ from modules import devices
+
+ device = devices.cpu
+
+ with sd_disable_initialization.DisableInitialization():
+ unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
+ use_checkpoint=True,
+ use_fp16=False,
+ image_size=32,
+ in_channels=4,
+ out_channels=4,
+ model_channels=320,
+ attention_resolutions=[4, 2, 1],
+ num_res_blocks=2,
+ channel_mult=[1, 2, 4, 4],
+ num_head_channels=64,
+ use_spatial_transformer=True,
+ use_linear_in_transformer=True,
+ transformer_depth=1,
+ context_dim=1024,
+ legacy=False
+ )
+ unet.eval()
+
+ with torch.no_grad():
+ unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k}
+ unet.load_state_dict(unet_sd, strict=True)
+ unet.to(device=device, dtype=torch.float)
+ test_cond = torch.ones((1, 2, 1024), device=device) * 0.5
+ x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5
+
+ out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().item()
+
+ return out < -1
+
+
+def guess_model_config_from_state_dict(sd, filename):
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
@@ -31,7 +72,7 @@ def guess_model_config_from_state_dict(sd, filename):
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if diffusion_model_input.shape[1] == 9:
return config_sd2_inpainting
- elif re.search(re_parametrization_v, fn):
+ elif is_using_v_parameterization_for_sd2(sd):
return config_sd2v
else:
return config_sd2
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index a7910b56..28c2136f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,53 +1,11 @@
-from collections import namedtuple, deque
-import numpy as np
-from math import floor
-import torch
-import tqdm
-from PIL import Image
-import inspect
-import k_diffusion.sampling
-import torchsde._brownian.brownian_interval
-import ldm.models.diffusion.ddim
-import ldm.models.diffusion.plms
-from modules import prompt_parser, devices, processing, images, sd_vae_approx
+from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
-from modules.shared import opts, cmd_opts, state
-import modules.shared as shared
-from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
-
-
-SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
-
-samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
- ('Euler', 'sample_euler', ['k_euler'], {}),
- ('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
- ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
- ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
- ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
- ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
-]
-
-samplers_data_k_diffusion = [
- SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
- for label, funcname, aliases, options in samplers_k_diffusion
- if hasattr(k_diffusion.sampling, funcname)
-]
+# imports for functions that previously were here and are used by other modules
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
all_samplers = [
- *samplers_data_k_diffusion,
- SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
- SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+ *sd_samplers_kdiffusion.samplers_data_k_diffusion,
+ *sd_samplers_compvis.samplers_data_compvis,
]
all_samplers_map = {x.name: x for x in all_samplers}
@@ -73,8 +31,8 @@ def create_sampler(name, model):
def set_samplers():
global samplers, samplers_for_img2img
- hidden = set(opts.hide_samplers)
- hidden_img2img = set(opts.hide_samplers + ['PLMS'])
+ hidden = set(shared.opts.hide_samplers)
+ hidden_img2img = set(shared.opts.hide_samplers + ['PLMS'])
samplers = [x for x in all_samplers if x.name not in hidden]
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
@@ -87,466 +45,3 @@ def set_samplers():
set_samplers()
-
-sampler_extra_params = {
- 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
- 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
-}
-
-
-def setup_img2img_steps(p, steps=None):
- if opts.img2img_fix_steps or steps is not None:
- requested_steps = (steps or p.steps)
- steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
- t_enc = requested_steps - 1
- else:
- steps = p.steps
- t_enc = int(min(p.denoising_strength, 0.999) * steps)
-
- return steps, t_enc
-
-
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
-
-
-def single_sample_to_image(sample, approximation=None):
- if approximation is None:
- approximation = approximation_indexes.get(opts.show_progress_type, 0)
-
- if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
- elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
- else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
-
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
- x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
- x_sample = x_sample.astype(np.uint8)
- return Image.fromarray(x_sample)
-
-
-def sample_to_image(samples, index=0, approximation=None):
- return single_sample_to_image(samples[index], approximation)
-
-
-def samples_to_image_grid(samples, approximation=None):
- return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
-
-
-def store_latent(decoded):
- state.current_latent = decoded
-
- if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
- if not shared.parallel_processing_allowed:
- shared.state.assign_current_image(sample_to_image(decoded))
-
-
-class InterruptedException(BaseException):
- pass
-
-
-class VanillaStableDiffusionSampler:
- def __init__(self, constructor, sd_model):
- self.sampler = constructor(sd_model)
- self.is_plms = hasattr(self.sampler, 'p_sample_plms')
- self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.sampler_noises = None
- self.step = 0
- self.stop_at = None
- self.eta = None
- self.default_eta = 0.0
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def number_of_needed_noises(self, p):
- return 0
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except InterruptedException:
- return self.last_latent
-
- def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
- if state.interrupted or state.skipped:
- raise InterruptedException
-
- if self.stop_at is not None and self.step > self.stop_at:
- raise InterruptedException
-
- # Have to unwrap the inpainting conditioning here to perform pre-processing
- image_conditioning = None
- if isinstance(cond, dict):
- image_conditioning = cond["c_concat"][0]
- cond = cond["c_crossattn"][0]
- unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
-
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
- cond = tensor
-
- # for DDIM, shapes must match, we can't just process cond and uncond independently;
- # filling unconditional_conditioning with repeats of the last vector to match length is
- # not 100% correct but should work well enough
- if unconditional_conditioning.shape[1] < cond.shape[1]:
- last_vector = unconditional_conditioning[:, -1:]
- last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
- unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
- elif unconditional_conditioning.shape[1] > cond.shape[1]:
- unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
-
- if self.mask is not None:
- img_orig = self.sampler.model.q_sample(self.init_latent, ts)
- x_dec = img_orig * self.mask + self.nmask * x_dec
-
- # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
- # Note that they need to be lists because it just concatenates them later.
- if image_conditioning is not None:
- cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
-
- if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
- else:
- self.last_latent = res[1]
-
- store_latent(self.last_latent)
-
- self.step += 1
- state.sampling_step = self.step
- shared.total_tqdm.update()
-
- return res
-
- def initialize(self, p):
- self.eta = p.eta if p.eta is not None else opts.eta_ddim
-
- for fieldname in ['p_sample_ddim', 'p_sample_plms']:
- if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
-
- self.mask = p.mask if hasattr(p, 'mask') else None
- self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
- def adjust_steps_if_invalid(self, p, num_steps):
- if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
- valid_step = 999 / (1000 // num_steps)
- if valid_step == floor(valid_step):
- return int(valid_step) + 1
-
- return num_steps
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
- steps = self.adjust_steps_if_invalid(p, steps)
- self.initialize(p)
-
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-
- self.init_latent = x
- self.last_latent = x
- self.step = 0
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- if image_conditioning is not None:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- self.initialize(p)
-
- self.init_latent = None
- self.last_latent = x
- self.step = 0
-
- steps = self.adjust_steps_if_invalid(p, steps or p.steps)
-
- # Wrap the conditioning models with additional image conditioning for inpainting model
- # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
- if image_conditioning is not None:
- conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
- unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
-
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
-
- return samples_ddim
-
-
-class CFGDenoiser(torch.nn.Module):
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
- self.mask = None
- self.nmask = None
- self.init_latent = None
- self.step = 0
-
- def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
- denoised_uncond = x_out[-uncond.shape[0]:]
- denoised = torch.clone(denoised_uncond)
-
- for i, conds in enumerate(conds_list):
- for cond_index, weight in conds:
- denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
-
- return denoised
-
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
- if state.interrupted or state.skipped:
- raise InterruptedException
-
- conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
- uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
-
- batch_size = len(conds_list)
- repeats = [len(conds_list[i]) for i in range(batch_size)]
-
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
-
- denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
- cfg_denoiser_callback(denoiser_params)
- x_in = denoiser_params.x
- image_cond_in = denoiser_params.image_cond
- sigma_in = denoiser_params.sigma
-
- if tensor.shape[1] == uncond.shape[1]:
- cond_in = torch.cat([tensor, uncond])
-
- if shared.batch_cond_uncond:
- x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
- else:
- x_out = torch.zeros_like(x_in)
- for batch_offset in range(0, x_out.shape[0], batch_size):
- a = batch_offset
- b = a + batch_size
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
- else:
- x_out = torch.zeros_like(x_in)
- batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
- for batch_offset in range(0, tensor.shape[0], batch_size):
- a = batch_offset
- b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
-
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
-
- devices.test_for_nans(x_out, "unet")
-
- if opts.live_preview_content == "Prompt":
- store_latent(x_out[0:uncond.shape[0]])
- elif opts.live_preview_content == "Negative prompt":
- store_latent(x_out[-uncond.shape[0]:])
-
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
-
- if self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
-
- self.step += 1
-
- return denoised
-
-
-class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
-
- def __getattr__(self, item):
- if item == 'randn_like':
- return self.randn_like
-
- if hasattr(torch, item):
- return getattr(torch, item)
-
- raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
-
- def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- if x.device.type == 'mps':
- return torch.randn_like(x, device=devices.cpu).to(x.device)
- else:
- return torch.randn_like(x)
-
-
-# MPS fix for randn in torchsde
-def torchsde_randn(size, dtype, device, seed):
- if device.type == 'mps':
- generator = torch.Generator(devices.cpu).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
- else:
- generator = torch.Generator(device).manual_seed(int(seed))
- return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
-
-
-class KDiffusionSampler:
- def __init__(self, funcname, sd_model):
- denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-
- self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
- self.funcname = funcname
- self.func = getattr(k_diffusion.sampling, self.funcname)
- self.extra_params = sampler_extra_params.get(funcname, [])
- self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
- self.sampler_noises = None
- self.stop_at = None
- self.eta = None
- self.default_eta = 1.0
- self.config = None
- self.last_latent = None
-
- self.conditioning_key = sd_model.model.conditioning_key
-
- def callback_state(self, d):
- step = d['i']
- latent = d["denoised"]
- if opts.live_preview_content == "Combined":
- store_latent(latent)
- self.last_latent = latent
-
- if self.stop_at is not None and step > self.stop_at:
- raise InterruptedException
-
- state.sampling_step = step
- shared.total_tqdm.update()
-
- def launch_sampling(self, steps, func):
- state.sampling_steps = steps
- state.sampling_step = 0
-
- try:
- return func()
- except InterruptedException:
- return self.last_latent
-
- def number_of_needed_noises(self, p):
- return p.steps
-
- def initialize(self, p):
- self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
- self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
- self.model_wrap_cfg.step = 0
- self.eta = p.eta or opts.eta_ancestral
-
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
-
- extra_params_kwargs = {}
- for param_name in self.extra_params:
- if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
- extra_params_kwargs[param_name] = getattr(p, param_name)
-
- if 'eta' in inspect.signature(self.func).parameters:
- extra_params_kwargs['eta'] = self.eta
-
- return extra_params_kwargs
-
- def get_sigmas(self, p, steps):
- discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
- if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
- discard_next_to_last_sigma = True
- p.extra_generation_params["Discard penultimate sigma"] = True
-
- steps += 1 if discard_next_to_last_sigma else 0
-
- if p.sampler_noise_scheduler_override:
- sigmas = p.sampler_noise_scheduler_override(steps)
- elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
- sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
-
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
- else:
- sigmas = self.model_wrap.get_sigmas(steps)
-
- if discard_next_to_last_sigma:
- sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
-
- return sigmas
-
- def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
- steps, t_enc = setup_img2img_steps(p, steps)
-
- sigmas = self.get_sigmas(p, steps)
-
- sigma_sched = sigmas[steps - t_enc - 1:]
- xi = x + noise * sigma_sched[0]
-
- extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
- ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
- extra_params_kwargs['sigma_min'] = sigma_sched[-2]
- if 'sigma_max' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigma_max'] = sigma_sched[0]
- if 'n' in inspect.signature(self.func).parameters:
- extra_params_kwargs['n'] = len(sigma_sched) - 1
- if 'sigma_sched' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigma_sched'] = sigma_sched
- if 'sigmas' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigmas'] = sigma_sched
-
- self.model_wrap_cfg.init_latent = x
- self.last_latent = x
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
- return samples
-
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
- steps = steps or p.steps
-
- sigmas = self.get_sigmas(p, steps)
-
- x = x * sigmas[0]
-
- extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
- extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
- extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in inspect.signature(self.func).parameters:
- extra_params_kwargs['n'] = steps
- else:
- extra_params_kwargs['sigmas'] = sigmas
-
- self.last_latent = x
- samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
-
- return samples
-
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
new file mode 100644
index 00000000..a1aac7cf
--- /dev/null
+++ b/modules/sd_samplers_common.py
@@ -0,0 +1,62 @@
+from collections import namedtuple
+import numpy as np
+import torch
+from PIL import Image
+from modules import devices, processing, images, sd_vae_approx
+
+from modules.shared import opts, state
+import modules.shared as shared
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
+
+
+def setup_img2img_steps(p, steps=None):
+ if opts.img2img_fix_steps or steps is not None:
+ requested_steps = (steps or p.steps)
+ steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0
+ t_enc = requested_steps - 1
+ else:
+ steps = p.steps
+ t_enc = int(min(p.denoising_strength, 0.999) * steps)
+
+ return steps, t_enc
+
+
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+
+
+def single_sample_to_image(sample, approximation=None):
+ if approximation is None:
+ approximation = approximation_indexes.get(opts.show_progress_type, 0)
+
+ if approximation == 2:
+ x_sample = sd_vae_approx.cheap_approximation(sample)
+ elif approximation == 1:
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ else:
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+
+ x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+ return Image.fromarray(x_sample)
+
+
+def sample_to_image(samples, index=0, approximation=None):
+ return single_sample_to_image(samples[index], approximation)
+
+
+def samples_to_image_grid(samples, approximation=None):
+ return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
+
+
+def store_latent(decoded):
+ state.current_latent = decoded
+
+ if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
+ if not shared.parallel_processing_allowed:
+ shared.state.assign_current_image(sample_to_image(decoded))
+
+
+class InterruptedException(BaseException):
+ pass
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
new file mode 100644
index 00000000..d03131cd
--- /dev/null
+++ b/modules/sd_samplers_compvis.py
@@ -0,0 +1,160 @@
+import math
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
+
+import numpy as np
+import torch
+
+from modules.shared import state
+from modules import sd_samplers_common, prompt_parser, shared
+
+
+samplers_data_compvis = [
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+]
+
+
+class VanillaStableDiffusionSampler:
+ def __init__(self, constructor, sd_model):
+ self.sampler = constructor(sd_model)
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.sampler_noises = None
+ self.step = 0
+ self.stop_at = None
+ self.eta = None
+ self.config = None
+ self.last_latent = None
+
+ self.conditioning_key = sd_model.model.conditioning_key
+
+ def number_of_needed_noises(self, p):
+ return 0
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except sd_samplers_common.InterruptedException:
+ return self.last_latent
+
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ if self.stop_at is not None and self.step > self.stop_at:
+ raise sd_samplers_common.InterruptedException
+
+ # Have to unwrap the inpainting conditioning here to perform pre-processing
+ image_conditioning = None
+ if isinstance(cond, dict):
+ image_conditioning = cond["c_concat"][0]
+ cond = cond["c_crossattn"][0]
+ unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ cond = tensor
+
+ # for DDIM, shapes must match, we can't just process cond and uncond independently;
+ # filling unconditional_conditioning with repeats of the last vector to match length is
+ # not 100% correct but should work well enough
+ if unconditional_conditioning.shape[1] < cond.shape[1]:
+ last_vector = unconditional_conditioning[:, -1:]
+ last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
+ unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
+ elif unconditional_conditioning.shape[1] > cond.shape[1]:
+ unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
+
+ if self.mask is not None:
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+ x_dec = img_orig * self.mask + self.nmask * x_dec
+
+ # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
+ # Note that they need to be lists because it just concatenates them later.
+ if image_conditioning is not None:
+ cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ if self.mask is not None:
+ self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+ else:
+ self.last_latent = res[1]
+
+ sd_samplers_common.store_latent(self.last_latent)
+
+ self.step += 1
+ state.sampling_step = self.step
+ shared.total_tqdm.update()
+
+ return res
+
+ def initialize(self, p):
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ if self.eta != 0.0:
+ p.extra_generation_params["Eta DDIM"] = self.eta
+
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+ if hasattr(self.sampler, fieldname):
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
+
+ self.mask = p.mask if hasattr(p, 'mask') else None
+ self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == math.floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+ steps = self.adjust_steps_if_invalid(p, steps)
+ self.initialize(p)
+
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
+ x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
+
+ self.init_latent = x
+ self.last_latent = x
+ self.step = 0
+
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ if image_conditioning is not None:
+ conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
+ unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ self.initialize(p)
+
+ self.init_latent = None
+ self.last_latent = x
+ self.step = 0
+
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
+
+ # Wrap the conditioning models with additional image conditioning for inpainting model
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
+ if image_conditioning is not None:
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
+
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+
+ return samples_ddim
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
new file mode 100644
index 00000000..528f513f
--- /dev/null
+++ b/modules/sd_samplers_kdiffusion.py
@@ -0,0 +1,357 @@
+from collections import deque
+import torch
+import inspect
+import einops
+import k_diffusion.sampling
+from modules import prompt_parser, devices, sd_samplers_common
+
+from modules.shared import opts, state
+import modules.shared as shared
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+
+samplers_k_diffusion = [
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler', 'sample_euler', ['k_euler'], {}),
+ ('LMS', 'sample_lms', ['k_lms'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+]
+
+samplers_data_k_diffusion = [
+ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
+ for label, funcname, aliases, options in samplers_k_diffusion
+ if hasattr(k_diffusion.sampling, funcname)
+]
+
+sampler_extra_params = {
+ 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+ 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+}
+
+
+class CFGDenoiser(torch.nn.Module):
+ """
+ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
+ that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
+ instead of one. Originally, the second prompt is just an empty string, but we use non-empty
+ negative prompt.
+ """
+
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+ self.step = 0
+ self.image_cfg_scale = None
+
+ def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
+ denoised_uncond = x_out[-uncond.shape[0]:]
+ denoised = torch.clone(denoised_uncond)
+
+ for i, conds in enumerate(conds_list):
+ for cond_index, weight in conds:
+ denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+
+ return denoised
+
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+ return denoised
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+ # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+ assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
+
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+
+ if tensor.shape[1] == uncond.shape[1]:
+ if not is_edit_model:
+ cond_in = torch.cat([tensor, uncond])
+ else:
+ cond_in = torch.cat([tensor, uncond, uncond])
+
+ if shared.batch_cond_uncond:
+ x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+
+ if not is_edit_model:
+ c_crossattn = [tensor[a:b]]
+ else:
+ c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
+
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ cfg_denoised_callback(denoised_params)
+
+ devices.test_for_nans(x_out, "unet")
+
+ if opts.live_preview_content == "Prompt":
+ sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
+ elif opts.live_preview_content == "Negative prompt":
+ sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
+
+ if not is_edit_model:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ else:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+
+ if self.mask is not None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+
+ self.step += 1
+
+ return denoised
+
+
+class TorchHijack:
+ def __init__(self, sampler_noises):
+ # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+ # implementation.
+ self.sampler_noises = deque(sampler_noises)
+
+ def __getattr__(self, item):
+ if item == 'randn_like':
+ return self.randn_like
+
+ if hasattr(torch, item):
+ return getattr(torch, item)
+
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+ def randn_like(self, x):
+ if self.sampler_noises:
+ noise = self.sampler_noises.popleft()
+ if noise.shape == x.shape:
+ return noise
+
+ if x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+ else:
+ return torch.randn_like(x)
+
+
+class KDiffusionSampler:
+ def __init__(self, funcname, sd_model):
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
+ self.funcname = funcname
+ self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname, [])
+ self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ self.sampler_noises = None
+ self.stop_at = None
+ self.eta = None
+ self.config = None
+ self.last_latent = None
+
+ self.conditioning_key = sd_model.model.conditioning_key
+
+ def callback_state(self, d):
+ step = d['i']
+ latent = d["denoised"]
+ if opts.live_preview_content == "Combined":
+ sd_samplers_common.store_latent(latent)
+ self.last_latent = latent
+
+ if self.stop_at is not None and step > self.stop_at:
+ raise sd_samplers_common.InterruptedException
+
+ state.sampling_step = step
+ shared.total_tqdm.update()
+
+ def launch_sampling(self, steps, func):
+ state.sampling_steps = steps
+ state.sampling_step = 0
+
+ try:
+ return func()
+ except sd_samplers_common.InterruptedException:
+ return self.last_latent
+
+ def number_of_needed_noises(self, p):
+ return p.steps
+
+ def initialize(self, p):
+ self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
+ self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
+ self.model_wrap_cfg.step = 0
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
+ self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+
+ extra_params_kwargs = {}
+ for param_name in self.extra_params:
+ if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
+ extra_params_kwargs[param_name] = getattr(p, param_name)
+
+ if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params["Eta"] = self.eta
+
+ extra_params_kwargs['eta'] = self.eta
+
+ return extra_params_kwargs
+
+ def get_sigmas(self, p, steps):
+ discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
+ if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
+ discard_next_to_last_sigma = True
+ p.extra_generation_params["Discard penultimate sigma"] = True
+
+ steps += 1 if discard_next_to_last_sigma else 0
+
+ if p.sampler_noise_scheduler_override:
+ sigmas = p.sampler_noise_scheduler_override(steps)
+ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+ else:
+ sigmas = self.model_wrap.get_sigmas(steps)
+
+ if discard_next_to_last_sigma:
+ sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+
+ return sigmas
+
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
+
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
+ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
+
+ sigmas = self.get_sigmas(p, steps)
+
+ sigma_sched = sigmas[steps - t_enc - 1:]
+ xi = x + noise * sigma_sched[0]
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'sigma_min' in parameters:
+ ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
+ extra_params_kwargs['sigma_min'] = sigma_sched[-2]
+ if 'sigma_max' in parameters:
+ extra_params_kwargs['sigma_max'] = sigma_sched[0]
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = len(sigma_sched) - 1
+ if 'sigma_sched' in parameters:
+ extra_params_kwargs['sigma_sched'] = sigma_sched
+ if 'sigmas' in parameters:
+ extra_params_kwargs['sigmas'] = sigma_sched
+
+ if self.funcname == 'sample_dpmpp_sde':
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+
+ self.model_wrap_cfg.init_latent = x
+ self.last_latent = x
+ extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale,
+ }
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ return samples
+
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+ steps = steps or p.steps
+
+ sigmas = self.get_sigmas(p, steps)
+
+ x = x * sigmas[0]
+
+ extra_params_kwargs = self.initialize(p)
+ parameters = inspect.signature(self.func).parameters
+
+ if 'sigma_min' in parameters:
+ extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
+ extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
+ if 'n' in parameters:
+ extra_params_kwargs['n'] = steps
+ else:
+ extra_params_kwargs['sigmas'] = sigmas
+
+ if self.funcname == 'sample_dpmpp_sde':
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+
+ self.last_latent = x
+ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
+ 'cond_scale': p.cfg_scale
+ }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+
+ return samples
+
diff --git a/modules/shared.py b/modules/shared.py
index 36e9762f..2c2edfbd 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -106,6 +106,8 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
+parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
script_loading.preload_extensions(extensions.extensions_dir, parser)
@@ -128,12 +130,13 @@ restricted_opts = {
ui_reorder_categories = [
"inpaint",
"sampler",
+ "checkboxes",
+ "hires_fix",
"dimensions",
"cfg",
"seed",
- "checkboxes",
- "hires_fix",
"batch",
+ "override_settings",
"scripts",
]
@@ -323,9 +326,11 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
- "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
+ "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
+ "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
+ "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
- "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
+ "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
@@ -347,10 +352,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
- "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
- "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
+ "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
+ "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("", "Directory name pattern", component_args=hide_dirs),
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
@@ -362,7 +367,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
- "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
+ "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
@@ -406,13 +411,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
- "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
+ "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
}))
@@ -432,7 +437,9 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
}))
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
- "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }),
+ "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
+ "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
@@ -440,7 +447,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
- "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
@@ -605,11 +612,37 @@ class Options:
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+ def cast_value(self, key, value):
+ """casts an arbitrary to the same type as this setting's value with key
+ Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
+ """
+
+ if value is None:
+ return None
+
+ default_value = self.data_labels[key].default
+ if default_value is None:
+ default_value = getattr(self, key, None)
+ if default_value is None:
+ return None
+
+ expected_type = type(default_value)
+ if expected_type == bool and value == "False":
+ value = False
+ else:
+ value = expected_type(value)
+
+ return value
+
+
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
+settings_components = None
+"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
+
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
"Latent": {"mode": "bilinear", "antialias": False},
diff --git a/modules/shared_items.py b/modules/shared_items.py
index 8b5ec96d..e792a134 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -20,4 +20,4 @@ def sd_vae_items():
def refresh_vae_list():
import modules.sd_vae
- return modules.sd_vae.refresh_vae_list
+ modules.sd_vae.refresh_vae_list()
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index d31963d4..af9fbcf2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
- def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
+ def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):
self.filename = filename
self.filename_text = filename_text
+ self.weight = weight
self.latent_dist = latent_dist
self.latent_sample = latent_sample
self.cond = cond
@@ -30,7 +31,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
+ alpha_channel = None
if shared.state.interrupted:
raise Exception("interrupted")
try:
- image = Image.open(path).convert('RGB')
+ image = Image.open(path)
+ #Currently does not work for single color transparency
+ #We would need to read image.info['transparency'] for that
+ if use_weight and 'A' in image.getbands():
+ alpha_channel = image.getchannel('A')
+ image = image.convert('RGB')
if not varsize:
image = image.resize((width, height), PIL.Image.BICUBIC)
except Exception:
@@ -87,17 +94,35 @@ class PersonalizedBase(Dataset):
with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
- if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- latent_sampling_method = "once"
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
- elif latent_sampling_method == "deterministic":
- # Works only for DiagonalGaussianDistribution
- latent_dist.std = 0
- latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
- elif latent_sampling_method == "random":
- entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
+ #Perform latent sampling, even for random sampling.
+ #We need the sample dimensions for the weights
+ if latent_sampling_method == "deterministic":
+ if isinstance(latent_dist, DiagonalGaussianDistribution):
+ # Works only for DiagonalGaussianDistribution
+ latent_dist.std = 0
+ else:
+ latent_sampling_method = "once"
+ latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
+
+ if use_weight and alpha_channel is not None:
+ channels, *latent_size = latent_sample.shape
+ weight_img = alpha_channel.resize(latent_size)
+ npweight = np.array(weight_img).astype(np.float32)
+ #Repeat for every channel in the latent sample
+ weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)
+ #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
+ weight -= weight.min()
+ weight /= weight.mean()
+ elif use_weight:
+ #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
+ weight = torch.ones(latent_sample.shape)
+ else:
+ weight = None
+
+ if latent_sampling_method == "random":
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
+ else:
+ entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)
if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text)
@@ -110,6 +135,7 @@ class PersonalizedBase(Dataset):
del torchdata
del latent_dist
del latent_sample
+ del weight
self.length = len(self.dataset)
self.groups = list(groups.values())
@@ -195,6 +221,10 @@ class BatchLoader:
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ if all(entry.weight is not None for entry in data):
+ self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)
+ else:
+ self.weight = None
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6cf00e65..c63c7d1d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -112,6 +112,7 @@ class EmbeddingDatabase:
self.skipped_embeddings = {}
self.expected_shape = -1
self.embedding_dirs = {}
+ self.previously_displayed_embeddings = ()
def add_embedding_dir(self, path):
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
@@ -228,9 +229,12 @@ class EmbeddingDatabase:
self.load_from_dir(embdir)
embdir.update()
- print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
- if len(self.skipped_embeddings) > 0:
- print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
+ displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
+ if self.previously_displayed_embeddings != displayed_embeddings:
+ self.previously_displayed_embeddings = displayed_embeddings
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -347,7 +351,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert log_directory, "Log directory is empty"
-def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
template_file = textual_inversion_templates.get(template_filename, None)
@@ -406,7 +410,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
pin_memory = shared.opts.pin_memory
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
if shared.opts.save_training_settings_to_txt:
save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
@@ -476,6 +480,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ if use_weight:
+ w = batch.weight.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
if is_training_inpainting_model:
@@ -486,7 +492,11 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
else:
cond = c
- loss = shared.sd_model(x, cond)[0] / gradient_step
+ if use_weight:
+ loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
+ del w
+ else:
+ loss = shared.sd_model.forward(x, cond)[0] / gradient_step
del x
_loss_step += loss.item()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index e945fd69..16841d0f 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,5 +1,6 @@
import modules.scripts
from modules import sd_samplers
+from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
@@ -8,7 +9,9 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
+ override_settings = create_override_settings_dict(override_settings_texts)
+
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -38,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_second_pass_steps=hr_second_pass_steps,
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
+ override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
diff --git a/modules/ui.py b/modules/ui.py
index 9f4cfda1..0516c643 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -380,6 +380,7 @@ def apply_setting(key, value):
opts.save(shared.config_filename)
return getattr(opts, key)
+
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
@@ -433,6 +434,18 @@ def get_value_for_setting(key):
return gr.update(value=value, **args)
+def create_override_settings_dropdown(tabname, row):
+ dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
+
+ dropdown.change(
+ fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
+ inputs=[dropdown],
+ outputs=[dropdown],
+ )
+
+ return dropdown
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -466,8 +479,8 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
if opts.dimensions_and_batch_together:
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
@@ -503,6 +516,10 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
+ elif category == "override_settings":
+ with FormRow(elem_id="txt2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('txt2img', row)
+
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
@@ -524,7 +541,6 @@ def create_ui():
)
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
- parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -555,6 +571,7 @@ def create_ui():
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
+ override_settings,
] + custom_inputs,
outputs=[
@@ -614,7 +631,10 @@ def create_ui():
(hr_resize_y, "Hires resize-2"),
*modules.scripts.scripts_txt2img.infotext_fields
]
- parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields)
+ parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+ ))
txt2img_preview_params = [
txt2img_prompt,
@@ -737,15 +757,17 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
if opts.dimensions_and_batch_together:
- res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
elif category == "cfg":
with FormGroup():
- cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ with FormRow():
+ cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
+ image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed":
@@ -762,6 +784,10 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
+ elif category == "override_settings":
+ with FormRow(elem_id="img2img_override_settings_row") as row:
+ override_settings = create_override_settings_dropdown('img2img', row)
+
elif category == "scripts":
with FormGroup(elem_id="img2img_script_container"):
custom_inputs = modules.scripts.scripts_img2img.setup_ui()
@@ -796,7 +822,6 @@ def create_ui():
)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
- parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
@@ -838,6 +863,7 @@ def create_ui():
batch_count,
batch_size,
cfg_scale,
+ image_cfg_scale,
denoising_strength,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
@@ -849,7 +875,8 @@ def create_ui():
inpainting_mask_invert,
img2img_batch_input_dir,
img2img_batch_output_dir,
- img2img_batch_inpaint_mask_dir
+ img2img_batch_inpaint_mask_dir,
+ override_settings,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -923,6 +950,7 @@ def create_ui():
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"),
+ (image_cfg_scale, "Image CFG scale"),
(seed, "Seed"),
(width, "Size-1"),
(height, "Size-2"),
@@ -935,8 +963,11 @@ def create_ui():
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
]
- parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
- parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)
+ parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
+ parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+ ))
modules.scripts.scripts_current = None
@@ -954,7 +985,11 @@ def create_ui():
html2 = gr.HTML()
with gr.Row():
buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"])
- parameters_copypaste.bind_buttons(buttons, image, generation_info)
+
+ for tabname, button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,
+ ))
image.change(
fn=wrap_gradio_call(modules.extras.run_pnginfo),
@@ -1156,6 +1191,8 @@ def create_ui():
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every")
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every")
+ use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight")
+
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding")
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img")
@@ -1269,6 +1306,7 @@ def create_ui():
shuffle_tags,
tag_drop_out,
latent_sampling_method,
+ use_weight,
create_image_every,
save_embedding_every,
template_file,
@@ -1302,6 +1340,7 @@ def create_ui():
shuffle_tags,
tag_drop_out,
latent_sampling_method,
+ use_weight,
create_image_every,
save_embedding_every,
template_file,
@@ -1363,6 +1402,7 @@ def create_ui():
components = []
component_dict = {}
+ shared.settings_components = component_dict
script_callbacks.ui_settings_callback()
opts.reorder()
@@ -1529,8 +1569,7 @@ def create_ui():
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component
- parameters_copypaste.integrate_settings_paste_fields(component_dict)
- parameters_copypaste.run_bind()
+ parameters_copypaste.connect_paste_params_buttons()
with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces:
@@ -1560,6 +1599,20 @@ def create_ui():
outputs=[component, text_settings],
)
+ text_settings.change(
+ fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
+ inputs=[],
+ outputs=[image_cfg_scale],
+ )
+
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[component_dict['sd_model_checkpoint'], dummy_component],
+ outputs=[component_dict['sd_model_checkpoint'], text_settings],
+ )
+
component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_settings_values():
@@ -1692,14 +1745,14 @@ def create_ui():
def reload_javascript():
- head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n'
+ head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}?{os.path.getmtime("script.js")}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
inline += f"set_theme('{cmd_opts.theme}');"
for script in modules.scripts.list_scripts("javascript", ".js"):
- head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
+ head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
head += f'<script type="text/javascript">{inline}</script>\n'
@@ -1733,7 +1786,7 @@ def versions_html():
return f"""
python: <span title="{sys.version}">{python_version}</span>
 • 
-torch: {torch.__version__}
+torch: {getattr(torch, '__long_version__',torch.__version__)}
 • 
xformers: {xformers_version}
 • 
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 9405ac1f..fd047f31 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -198,5 +198,9 @@ Requested path was: {f}
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
- parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+ for paste_tabname, paste_button in buttons.items():
+ parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
+ paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
+ ))
+
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index 66a41865..bd4308ef 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -13,7 +13,7 @@ import shutil
import errno
from modules import extensions, shared, paths
-
+from modules.call_queue import wrap_gradio_gpu_call
available_extensions = {"extensions": []}
@@ -50,12 +50,17 @@ def apply_and_restart(disable_list, update_list):
shared.state.need_restart = True
-def check_updates():
+def check_updates(id_task, disable_list):
check_access()
- for ext in extensions.extensions:
- if ext.remote is None:
- continue
+ disabled = json.loads(disable_list)
+ assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
+
+ exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]
+ shared.state.job_count = len(exts)
+
+ for ext in exts:
+ shared.state.textinfo = ext.name
try:
ext.check_updates()
@@ -63,7 +68,9 @@ def check_updates():
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- return extension_table()
+ shared.state.nextjob()
+
+ return extension_table(), ""
def extension_table():
@@ -73,6 +80,7 @@ def extension_table():
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
+ <th><abbr title="Extension version">Version</abbr></th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
@@ -80,11 +88,7 @@ def extension_table():
"""
for ext in extensions.extensions:
- remote = ""
- if ext.is_builtin:
- remote = "built-in"
- elif ext.remote:
- remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
+ remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
@@ -95,6 +99,7 @@ def extension_table():
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
+ <td>{ext.version}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -273,12 +278,13 @@ def create_ui():
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
- with gr.Row():
+ with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+ info = gr.HTML()
extensions_table = gr.HTML(lambda: extension_table())
apply.click(
@@ -289,10 +295,10 @@ def create_ui():
)
check.click(
- fn=check_updates,
+ fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),
_js="extensions_check",
- inputs=[],
- outputs=[extensions_table],
+ inputs=[info, extensions_disabled_list],
+ outputs=[extensions_table, info],
)
with gr.TabItem("Available"):
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index c6ff889a..71f1d81f 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,4 +1,7 @@
+import glob
import os.path
+import urllib.parse
+from pathlib import Path
from modules import shared
import gradio as gr
@@ -8,12 +11,32 @@ import html
from modules.generation_parameters_copypaste import image_from_url_text
extra_pages = []
+allowed_dirs = set()
def register_page(page):
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
extra_pages.append(page)
+ allowed_dirs.clear()
+ allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
+
+
+def add_pages_to_demo(app):
+ def fetch_file(filename: str = ""):
+ from starlette.responses import FileResponse
+
+ if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
+
+ ext = os.path.splitext(filename)[1].lower()
+ if ext not in (".png", ".jpg"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg.")
+
+ # would profit from returning 304
+ return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
+
+ app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
class ExtraNetworksPage:
@@ -26,10 +49,48 @@ class ExtraNetworksPage:
def refresh(self):
pass
+ def link_preview(self, filename):
+ return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
+
+ def search_terms_from_path(self, filename, possible_directories=None):
+ abspath = os.path.abspath(filename)
+
+ for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
+ parentdir = os.path.abspath(parentdir)
+ if abspath.startswith(parentdir):
+ return abspath[len(parentdir):].replace('\\', '/')
+
+ return ""
+
def create_html(self, tabname):
view = shared.opts.extra_networks_default_view
items_html = ''
+ subdirs = {}
+ for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
+ for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
+ if not os.path.isdir(x):
+ continue
+
+ subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
+ while subdir.startswith("/"):
+ subdir = subdir[1:]
+
+ is_empty = len(os.listdir(x)) == 0
+ if not is_empty and not subdir.endswith("/"):
+ subdir = subdir + "/"
+
+ subdirs[subdir] = 1
+
+ if subdirs:
+ subdirs = {"": 1, **subdirs}
+
+ subdirs_html = "".join([f"""
+<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
+{html.escape(subdir if subdir!="" else "all")}
+</button>
+""" for subdir in subdirs])
+
for item in self.list_items():
items_html += self.create_html_for_item(item, tabname)
@@ -37,8 +98,13 @@ class ExtraNetworksPage:
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+ self_name_id = self.name.replace(" ", "_")
+
res = f"""
-<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
+<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-{view}'>
+{subdirs_html}
+</div>
+<div id='{tabname}_{self_name_id}_cards' class='extra-network-{view}'>
{items_html}
</div>
"""
@@ -54,14 +120,19 @@ class ExtraNetworksPage:
def create_html_for_item(self, item, tabname):
preview = item.get("preview", None)
+ onclick = item.get("onclick", None)
+ if onclick is None:
+ onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
+
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
- "prompt": item["prompt"],
+ "prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
"name": item["name"],
- "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
+ "card_clicked": onclick,
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
+ "search_term": item.get("search_term", ""),
}
return self.card_page.format(**args)
@@ -143,7 +214,7 @@ def path_is_parent(parent_path, child_path):
parent_path = os.path.abspath(parent_path)
child_path = os.path.abspath(child_path)
- return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+ return child_path.startswith(parent_path)
def setup_ui(ui, gallery):
@@ -173,7 +244,8 @@ def setup_ui(ui, gallery):
ui.button_save_preview.click(
fn=save_preview,
- _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+ _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
outputs=[*ui.pages]
)
+
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
new file mode 100644
index 00000000..04097a79
--- /dev/null
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -0,0 +1,39 @@
+import html
+import json
+import os
+import urllib.parse
+
+from modules import shared, ui_extra_networks, sd_models
+
+
+class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Checkpoints')
+
+ def refresh(self):
+ shared.refresh_checkpoints()
+
+ def list_items(self):
+ checkpoint: sd_models.CheckpointInfo
+ for name, checkpoint in sd_models.checkpoints_list.items():
+ path, ext = os.path.splitext(checkpoint.filename)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = self.link_preview(file)
+ break
+
+ yield {
+ "name": checkpoint.name_for_extra,
+ "filename": path,
+ "preview": preview,
+ "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
+ "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
+
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 65d000cf..57851088 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -19,13 +19,14 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
preview = None
for file in previews:
if os.path.isfile(file):
- preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ preview = self.link_preview(file)
break
yield {
"name": name,
"filename": path,
"preview": preview,
+ "search_term": self.search_terms_from_path(path),
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index dbd23d2d..bb64eb81 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -19,12 +19,13 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
preview = None
if os.path.isfile(preview_file):
- preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+ preview = self.link_preview(preview_file)
yield {
"name": embedding.name,
"filename": embedding.filename,
"preview": preview,
+ "search_term": self.search_terms_from_path(embedding.filename),
"prompt": json.dumps(embedding.name),
"local_preview": path + ".preview.png",
}
diff --git a/requirements_versions.txt b/requirements_versions.txt
index eaa08806..331d0fe8 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -27,3 +27,4 @@ GitPython==3.1.27
torchsde==0.2.5
safetensors==0.2.7
httpcore<=0.15
+fastapi==0.90.1
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index cbdfc6b3..2572443f 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -6,7 +6,7 @@ from tqdm import trange
import modules.scripts as scripts
import gradio as gr
-from modules import processing, shared, sd_samplers, prompt_parser
+from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
@@ -50,7 +50,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
x = x + d * dt
- sd_samplers.store_latent(x)
+ sd_samplers_common.store_latent(x)
# This shouldn't be necessary, but solved some VRAM issues
del x_in, sigma_in, cond_in, c_out, c_in, t,
@@ -104,7 +104,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
dt = sigmas[i] - sigmas[i - 1]
x = x + d * dt
- sd_samplers.store_latent(x)
+ sd_samplers_common.store_latent(x)
# This shouldn't be necessary, but solved some VRAM issues
del x_in, sigma_in, cond_in, c_out, c_in, t,
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index dd95e588..b1c486d4 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -44,16 +44,34 @@ class Script(scripts.Script):
def title(self):
return "Prompt matrix"
- def ui(self, is_img2img):
- put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
- different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
-
- return [put_at_start, different_seeds]
-
- def run(self, p, put_at_start, different_seeds):
+ def ui(self, is_img2img):
+ gr.HTML('<br />')
+ with gr.Row():
+ with gr.Column():
+ put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start"))
+ different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds"))
+ with gr.Column():
+ prompt_type = gr.Radio(["positive", "negative"], label="Select prompt", elem_id=self.elem_id("prompt_type"), value="positive")
+ variations_delimiter = gr.Radio(["comma", "space"], label="Select joining char", elem_id=self.elem_id("variations_delimiter"), value="comma")
+ with gr.Column():
+ margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+ return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]
+
+ def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):
modules.processing.fix_seed(p)
+ # Raise error if promp type is not positive or negative
+ if prompt_type not in ["positive", "negative"]:
+ raise ValueError(f"Unknown prompt type {prompt_type}")
+ # Raise error if variations delimiter is not comma or space
+ if variations_delimiter not in ["comma", "space"]:
+ raise ValueError(f"Unknown variations delimiter {variations_delimiter}")
+
+ prompt = p.prompt if prompt_type == "positive" else p.negative_prompt
+ original_prompt = prompt[0] if type(prompt) == list else prompt
+ positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
- original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
+ delimiter = ", " if variations_delimiter == "comma" else " "
all_prompts = []
prompt_matrix_parts = original_prompt.split("|")
@@ -66,20 +84,23 @@ class Script(scripts.Script):
else:
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
- all_prompts.append(", ".join(selected_prompts))
+ all_prompts.append(delimiter.join(selected_prompts))
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
p.do_not_save_grid = True
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
- p.prompt = all_prompts
+ if prompt_type == "positive":
+ p.prompt = all_prompts
+ else:
+ p.negative_prompt = all_prompts
p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]
- p.prompt_for_display = original_prompt
+ p.prompt_for_display = positive_prompt
processed = process_images(p)
- grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
- grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+ grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
+ grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[1].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid)
processed.index_of_first_image = 1
processed.infotexts.insert(0, processed.infotexts[0])
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index f0116055..4afcbbfd 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -25,6 +25,8 @@ from modules.ui_components import ToolButton
fill_values_symbol = "\U0001f4d2" # 📒
+AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])
+
def apply_field(field):
def fun(p, x, xs):
@@ -186,6 +188,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps")),
AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
AxisOption("CFG Scale", float, apply_field("cfg_scale")),
+ AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
@@ -205,7 +208,7 @@ axis_options = [
]
-def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed):
+def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
title_texts = [[images.GridAnnotation(z)] for z in z_labels]
@@ -286,23 +289,24 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
return Processed(p, [])
- grids = [None] * len(zs)
+ sub_grids = [None] * len(zs)
for i in range(len(zs)):
start_index = i * len(xs) * len(ys)
end_index = start_index + len(xs) * len(ys)
grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
if draw_legend:
- grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
-
- grids[i] = grid
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
+ sub_grids[i] = grid
if include_sub_grids and len(zs) > 1:
processed_result.images.insert(i+1, grid)
- original_grid_size = grids[0].size
- grids = images.image_grid(grids, rows=1)
- processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]])
+ sub_grid_size = sub_grids[0].size
+ z_grid = images.image_grid(sub_grids, rows=1)
+ if draw_legend:
+ z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
+ processed_result.images[0] = z_grid
- return processed_result
+ return processed_result, sub_grids
class SharedSettingsStackHelper(object):
@@ -350,10 +354,16 @@ class Script(scripts.Script):
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"):
- draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
- include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
- include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
- no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ with gr.Column():
+ draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
+ no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
+ with gr.Column():
+ include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
+ include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
+ with gr.Column():
+ margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
+
+ with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
@@ -383,9 +393,18 @@ class Script(scripts.Script):
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
- return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds]
+ self.infotext_fields = (
+ (x_type, "X Type"),
+ (x_values, "X Values"),
+ (y_type, "Y Type"),
+ (y_values, "Y Values"),
+ (z_type, "Z Type"),
+ (z_values, "Z Values"),
+ )
- def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds):
+ return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
+
+ def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
@@ -504,6 +523,10 @@ class Script(scripts.Script):
grid_infotext = [None]
+ state.xyz_plot_x = AxisInfo(x_opt, xs)
+ state.xyz_plot_y = AxisInfo(y_opt, ys)
+ state.xyz_plot_z = AxisInfo(z_opt, zs)
+
# If one of the axes is very slow to change between (like SD model
# checkpoint), then make sure it is in the outer iteration of the nested
# `for` loop.
@@ -542,6 +565,7 @@ class Script(scripts.Script):
if grid_infotext[0] is None:
pc.extra_generation_params = copy(pc.extra_generation_params)
+ pc.extra_generation_params['Script'] = self.title()
if x_opt.label != 'Nothing':
pc.extra_generation_params["X Type"] = x_opt.label
@@ -566,7 +590,7 @@ class Script(scripts.Script):
return res
with SharedSettingsStackHelper():
- processed = draw_xyz_grid(
+ processed, sub_grids = draw_xyz_grid(
p,
xs=xs,
ys=ys,
@@ -579,9 +603,14 @@ class Script(scripts.Script):
include_lone_images=include_lone_images,
include_sub_grids=include_sub_grids,
first_axes_processed=first_axes_processed,
- second_axes_processed=second_axes_processed
+ second_axes_processed=second_axes_processed,
+ margin_size=margin_size
)
+ if opts.grid_save and len(sub_grids) > 1:
+ for sub_grid in sub_grids:
+ images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+
if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
diff --git a/style.css b/style.css
index dd914104..05572f66 100644
--- a/style.css
+++ b/style.css
@@ -74,7 +74,12 @@
#txt2img_gallery img, #img2img_gallery img{
object-fit: scale-down;
}
-
+#txt2img_actions_column, #img2img_actions_column {
+ margin: 0.35rem 0.75rem 0.35rem 0;
+}
+#script_list {
+ padding: .625rem .75rem 0 .625rem;
+}
.justify-center.overflow-x-scroll {
justify-content: left;
}
@@ -126,6 +131,7 @@
#txt2img_actions_column, #img2img_actions_column{
gap: 0;
+ margin-right: .75rem;
}
#txt2img_tools, #img2img_tools{
@@ -150,6 +156,7 @@
#txt2img_styles_row, #img2img_styles_row{
gap: 0.25em;
+ margin-top: 0.3em;
}
#txt2img_styles_row > button, #img2img_styles_row > button{
@@ -311,11 +318,11 @@ input[type="range"]{
.min-h-\[6rem\] { min-height: unset !important; }
.progressDiv{
- position: absolute;
+ position: relative;
height: 20px;
- top: -20px;
background: #b4c0cc;
border-radius: 3px !important;
+ margin-bottom: -3px;
}
.dark .progressDiv{
@@ -535,7 +542,7 @@ input[type="range"]{
}
#quicksettings {
- gap: 0.4em;
+ width: fit-content;
}
#quicksettings > div, #quicksettings > fieldset{
@@ -545,6 +552,7 @@ input[type="range"]{
border: none;
box-shadow: none;
background: none;
+ margin-right: 10px;
}
#quicksettings > div > div > div > label > span {
@@ -567,7 +575,7 @@ canvas[key="mask"] {
right: 0.5em;
top: -0.6em;
z-index: 400;
- width: 8em;
+ width: 6em;
}
#quicksettings .gr-box > div > div > input.gr-text-input {
top: -1.12em;
@@ -665,11 +673,27 @@ canvas[key="mask"] {
#quicksettings .gr-button-tool{
margin: 0;
+ border-color: unset;
+ background-color: unset;
}
-
+#modelmerger_interp_description>p {
+ margin: 0!important;
+ text-align: center;
+}
+#modelmerger_interp_description {
+ margin: 0.35rem 0.75rem 1.23rem;
+}
#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form {
padding-top: 0.9em;
+ padding-bottom: 0.9em;
+}
+#txt2img_settings {
+ padding-top: 1.16em;
+ padding-bottom: 0.9em;
+}
+#img2img_settings {
+ padding-bottom: 0.9em;
}
#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{
@@ -741,6 +765,7 @@ footer {
.dark .gr-compact{
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
+ margin-left: 0;
}
.gr-compact{
@@ -782,7 +807,13 @@ footer {
margin: 0.3em;
}
+.extra-network-subdirs{
+ padding: 0.2em 0.35em;
+}
+.extra-network-subdirs button{
+ margin: 0 0.15em;
+}
#txt2img_extra_networks .search, #img2img_extra_networks .search{
display: inline-block;
@@ -925,3 +956,6 @@ footer {
color: red;
}
+[id*='_prompt_container'] > div {
+ margin: 0!important;
+}
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index fa187dd1..37cac4fb 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -10,7 +10,7 @@ then
fi
export install_dir="$HOME"
-export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
+export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
diff --git a/webui.py b/webui.py
index 0e2b28b9..69312a19 100644
--- a/webui.py
+++ b/webui.py
@@ -12,7 +12,7 @@ from packaging import version
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import import_hook, errors, extra_networks
+from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
@@ -20,6 +20,7 @@ import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
+ torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
@@ -52,6 +53,9 @@ else:
def check_versions():
+ if shared.cmd_opts.skip_version_check:
+ return
+
expected_torch_version = "1.13.1"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
@@ -59,7 +63,10 @@ def check_versions():
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
""".strip())
expected_xformers_version = "0.0.16rc425"
@@ -71,6 +78,8 @@ Beware that this will cause a lot of large files to be downloaded.
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
""".strip())
@@ -89,7 +98,6 @@ def initialize():
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
- shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.list_builtin_upscalers()
modules.scripts.load_scripts()
@@ -119,6 +127,7 @@ def initialize():
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+ ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
@@ -227,6 +236,8 @@ def webui():
if launch_api:
create_api(app)
+ ui_extra_networks.add_pages_to_demo(app)
+
modules.script_callbacks.app_started_callback(shared.demo, app)
wait_on_server(shared.demo)
@@ -254,6 +265,7 @@ def webui():
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
+ ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())