diff options
31 files changed, 518 insertions, 455 deletions
diff --git a/.eslintrc.js b/.eslintrc.js index 9e7ab34d..944cc869 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -1,3 +1,4 @@ +/* global module */ module.exports = { env: { browser: true, @@ -25,9 +26,14 @@ module.exports = { "linebreak-style": ["error", "unix"], "no-extra-semi": "error", "no-mixed-spaces-and-tabs": "error", + "no-multi-spaces": "error", + "no-redeclare": ["error", {builtinGlobals: false}], "no-trailing-spaces": "error", + "no-unused-vars": "off", "no-whitespace-before-property": "error", "object-curly-newline": ["error", {consistent: true, multiline: true}], + "object-curly-spacing": ["error", "never"], + "operator-linebreak": ["error", "after"], "quote-props": ["error", "consistent-as-needed"], "semi": ["error", "always"], "semi-spacing": "error", @@ -40,51 +46,43 @@ module.exports = { "switch-colon-spacing": "error", "template-curly-spacing": ["error", "never"], "unicode-bom": "error", - "no-multi-spaces": "error", - "object-curly-spacing": ["error", "never"], - "operator-linebreak": ["error", "after"], - "no-unused-vars": "off", - "no-redeclare": "off", }, globals: { - // this file - module: "writable", //script.js - gradioApp: "writable", - onUiLoaded: "writable", - onUiUpdate: "writable", - onOptionsChanged: "writable", + gradioApp: "readonly", + onUiLoaded: "readonly", + onUiUpdate: "readonly", + onOptionsChanged: "readonly", uiCurrentTab: "writable", - uiElementIsVisible: "writable", - uiElementInSight: "writable", - executeCallbacks: "writable", + uiElementIsVisible: "readonly", + uiElementInSight: "readonly", + executeCallbacks: "readonly", //ui.js opts: "writable", - all_gallery_buttons: "writable", - selected_gallery_button: "writable", - selected_gallery_index: "writable", - args_to_array: "writable", - switch_to_txt2img: "writable", - switch_to_img2img_tab: "writable", - switch_to_img2img: "writable", - switch_to_sketch: "writable", - switch_to_inpaint: "writable", - switch_to_inpaint_sketch: "writable", - switch_to_extras: "writable", - get_tab_index: "writable", - create_submit_args: "writable", - restart_reload: "writable", - updateInput: "writable", + all_gallery_buttons: "readonly", + selected_gallery_button: "readonly", + selected_gallery_index: "readonly", + switch_to_txt2img: "readonly", + switch_to_img2img_tab: "readonly", + switch_to_img2img: "readonly", + switch_to_sketch: "readonly", + switch_to_inpaint: "readonly", + switch_to_inpaint_sketch: "readonly", + switch_to_extras: "readonly", + get_tab_index: "readonly", + create_submit_args: "readonly", + restart_reload: "readonly", + updateInput: "readonly", //extraNetworks.js - requestGet: "writable", - popup: "writable", + requestGet: "readonly", + popup: "readonly", // from python - localization: "writable", + localization: "readonly", // progrssbar.js - randomId: "writable", - requestProgress: "writable", + randomId: "readonly", + requestProgress: "readonly", // imageviewer.js - modalPrevImage: "writable", - modalNextImage: "writable", + modalPrevImage: "readonly", + modalNextImage: "readonly", } }; diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..4104da63 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Apply ESlint +9c54b78d9dde5601e916f308d9a9d6953ec39430
\ No newline at end of file diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 69056331..c9fcda2e 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,28 +1,15 @@ -# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! +## Description -If you have a large change, pay special attention to this paragraph: +* a simple description of what you're trying to accomplish +* a summary of changes in code +* which issues it fixes, if any -> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. +## Screenshots/videos: -Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. -**Describe what this pull request is trying to achieve.** +## Checklist: -A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. - -**Additional notes and description of your changes** - -More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. - -**Environment this was tested in** - -List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box. - - OS: [e.g. Windows, Linux] - - Browser: [e.g. chrome, safari] - - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] - -**Screenshots or videos of your changes** - -If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. - -This is **required** for anything that touches the user interface.
\ No newline at end of file +- [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) +- [ ] I have performed a self-review of my own code +- [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style) +- [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 0708398b..226cf759 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -18,18 +18,53 @@ jobs: cache-dependency-path: | **/requirements*txt launch.py - - name: Run tests - run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + - name: Install test dependencies + run: pip install wait-for-it -r requirements-test.txt + env: + PIP_DISABLE_PIP_VERSION_CHECK: "1" + PIP_PROGRESS_BAR: "off" + - name: Setup environment + run: python launch.py --skip-torch-cuda-test --exit env: PIP_DISABLE_PIP_VERSION_CHECK: "1" PIP_PROGRESS_BAR: "off" TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu WEBUI_LAUNCH_LIVE_OUTPUT: "1" - - name: Upload main app stdout-stderr + PYTHONUNBUFFERED: "1" + - name: Start test server + run: > + python -m coverage run + --data-file=.coverage.server + launch.py + --skip-prepare-environment + --skip-torch-cuda-test + --test-server + --no-half + --disable-opt-split-attention + --use-cpu all + --add-stop-route + 2>&1 | tee output.txt & + - name: Run tests + run: | + wait-for-it --service 127.0.0.1:7860 -t 600 + python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test + - name: Kill test server + if: always() + run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10 + - name: Show coverage + run: | + python -m coverage combine .coverage* + python -m coverage report -i + python -m coverage html -i + - name: Upload main app output + uses: actions/upload-artifact@v3 + if: always() + with: + name: output + path: output.txt + - name: Upload coverage HTML uses: actions/upload-artifact@v3 if: always() with: - name: stdout-stderr - path: | - test/stdout.txt - test/stderr.txt + name: htmlcov + path: htmlcov @@ -35,4 +35,5 @@ notification.mp3 /cache.json* /config_states/ /node_modules -/package-lock.json
\ No newline at end of file +/package-lock.json +/.coverage* diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ccb249ac..b5fea4d2 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -23,5 +23,23 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): lora.load_loras(names, multipliers)
+ if shared.opts.lora_add_hashes_to_infotext:
+ lora_hashes = []
+ for item in lora.loaded_loras:
+ shorthash = item.lora_on_disk.shorthash
+ if not shorthash:
+ continue
+
+ alias = item.mentioned_name
+ if not alias:
+ continue
+
+ alias = alias.replace(":", "").replace(",", "")
+
+ lora_hashes.append(f"{alias}: {shorthash}")
+
+ if lora_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+
def deactivate(self, p):
pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index fa57d466..eec14712 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,7 @@ import re import torch
from typing import Union
-from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -76,9 +76,9 @@ class LoraOnDisk: self.name = name
self.filename = filename
self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- _, ext = os.path.splitext(filename)
- if ext.lower() == ".safetensors":
+ if self.is_safetensors:
try:
self.metadata = sd_models.read_metadata_from_safetensors(filename)
except Exception as e:
@@ -94,14 +94,43 @@ class LoraOnDisk: self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
+ self.hash = None
+ self.shorthash = None
+ self.set_hash(
+ self.metadata.get('sshs_model_hash') or
+ hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
+ ''
+ )
+
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
+ if self.shorthash:
+ available_lora_hash_lookup[self.shorthash] = self
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
+ return self.name
+ else:
+ return self.alias
+
class LoraModule:
- def __init__(self, name):
+ def __init__(self, name, lora_on_disk: LoraOnDisk):
self.name = name
+ self.lora_on_disk = lora_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
+ self.mentioned_name = None
+ """the text that was used to add lora to prompt - can be either name or an alias"""
+
class LoraUpDownModule:
def __init__(self):
@@ -126,11 +155,11 @@ def assign_lora_names_to_compvis_modules(sd_model): sd_model.lora_layer_mapping = lora_layer_mapping
-def load_lora(name, filename):
- lora = LoraModule(name)
- lora.mtime = os.path.getmtime(filename)
+def load_lora(name, lora_on_disk):
+ lora = LoraModule(name, lora_on_disk)
+ lora.mtime = os.path.getmtime(lora_on_disk.filename)
- sd = sd_models.read_state_dict(filename)
+ sd = sd_models.read_state_dict(lora_on_disk.filename)
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
@@ -191,7 +220,7 @@ def load_lora(name, filename): raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
- print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+ print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora
@@ -217,14 +246,19 @@ def load_loras(names, multipliers=None): lora = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
+
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
try:
- lora = load_lora(name, lora_on_disk.filename)
+ lora = load_lora(name, lora_on_disk)
except Exception as e:
errors.display(e, f"loading Lora {lora_on_disk.filename}")
continue
+ lora.mentioned_name = name
+
+ lora_on_disk.read_hash()
+
if lora is None:
failed_to_load_loras.append(name)
print(f"Couldn't find Lora with name {name}")
@@ -403,7 +437,8 @@ def list_available_loras(): available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
- forbidden_lora_aliases.update({"none": 1})
+ available_lora_hash_lookup.clear()
+ forbidden_lora_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
@@ -457,8 +492,10 @@ def infotext_pasted(infotext, params): if added:
params["Prompt"] += "\n" + "".join(added)
+
available_loras = {}
available_lora_aliases = {}
+available_lora_hash_lookup = {}
forbidden_lora_aliases = {}
loaded_loras = []
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 5eafbe86..e650f469 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,3 +1,5 @@ +import re
+
import torch
import gradio as gr
from fastapi import FastAPI
@@ -54,7 +56,8 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
- "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
@@ -76,7 +79,7 @@ def api_loras(_: gr.Blocks, app: FastAPI): @app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
-
+
@app.post("/sdapi/v1/refresh-loras")
async def refresh_loras():
return lora.list_available_loras()
@@ -84,3 +87,30 @@ def api_loras(_: gr.Blocks, app: FastAPI): script_callbacks.on_app_started(api_loras)
+re_lora = re.compile("<lora:([^:]+):")
+
+
+def infotext_pasted(infotext, d):
+ hashes = d.get("Lora hashes")
+ if not hashes:
+ return
+
+ hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
+ hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
+
+ def lora_replacement(m):
+ alias = m.group(1)
+ shorthash = hashes.get(alias)
+ if shorthash is None:
+ return m.group(0)
+
+ lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
+ if lora_on_disk is None:
+ return m.group(0)
+
+ return f'<lora:{lora_on_disk.get_alias()}:'
+
+ d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+
+
+script_callbacks.on_infotext_pasted(infotext_pasted)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 2050e3fa..259e99ac 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -16,10 +16,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename)
- if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
- alias = name
- else:
- alias = lora_on_disk.alias
+ alias = lora_on_disk.get_alias()
yield {
"name": name,
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 37e3d075..20443fcc 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -9,7 +9,7 @@ function start_training_textual_inversion() { gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; }); - var res = args_to_array(arguments); + var res = Array.from(arguments); res[0] = id; diff --git a/javascript/ui.js b/javascript/ui.js index c7316ddb..648a5290 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -61,18 +61,12 @@ function extract_image_from_gallery(gallery) { return [gallery[index]]; } -function args_to_array(args) { - var res = []; - for (var i = 0; i < args.length; i++) { - res.push(args[i]); - } - return res; -} +window.args_to_array = Array.from; // Compatibility with e.g. extensions that may expect this to be around function switch_to_txt2img() { gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click(); - return args_to_array(arguments); + return Array.from(arguments); } function switch_to_img2img_tab(no) { @@ -81,65 +75,55 @@ function switch_to_img2img_tab(no) { } function switch_to_img2img() { switch_to_img2img_tab(0); - return args_to_array(arguments); + return Array.from(arguments); } function switch_to_sketch() { switch_to_img2img_tab(1); - return args_to_array(arguments); + return Array.from(arguments); } function switch_to_inpaint() { switch_to_img2img_tab(2); - return args_to_array(arguments); + return Array.from(arguments); } function switch_to_inpaint_sketch() { switch_to_img2img_tab(3); - return args_to_array(arguments); + return Array.from(arguments); } function switch_to_extras() { gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click(); - return args_to_array(arguments); + return Array.from(arguments); } function get_tab_index(tabId) { - var res = 0; - - gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i) { - if (button.className.indexOf('selected') != -1) { - res = i; + let buttons = gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button'); + for (let i = 0; i < buttons.length; i++) { + if (buttons[i].classList.contains('selected')) { + return i; } - }); - - return res; + } + return 0; } function create_tab_index_args(tabId, args) { - var res = []; - for (var i = 0; i < args.length; i++) { - res.push(args[i]); - } - + var res = Array.from(args); res[0] = get_tab_index(tabId); - return res; } function get_img2img_tab_index() { - let res = args_to_array(arguments); + let res = Array.from(arguments); res.splice(-2); res[0] = get_tab_index('mode_img2img'); return res; } function create_submit_args(args) { - var res = []; - for (var i = 0; i < args.length; i++) { - res.push(args[i]); - } + var res = Array.from(args); // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image. // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate. @@ -275,13 +259,13 @@ function recalculatePromptTokens(name) { function recalculate_prompts_txt2img() { recalculatePromptTokens('txt2img_prompt'); recalculatePromptTokens('txt2img_neg_prompt'); - return args_to_array(arguments); + return Array.from(arguments); } function recalculate_prompts_img2img() { recalculatePromptTokens('img2img_prompt'); recalculatePromptTokens('img2img_neg_prompt'); - return args_to_array(arguments); + return Array.from(arguments); } @@ -310,12 +310,8 @@ def prepare_environment(): print("Exiting because of --exit argument")
exit(0)
- if args.tests and not args.no_tests:
- exitcode = tests(args.tests)
- exit(exitcode)
-
-def tests(test_dir):
+def configure_for_tests():
if "--api" not in sys.argv:
sys.argv.append("--api")
if "--ckpt" not in sys.argv:
@@ -325,21 +321,8 @@ def tests(test_dir): sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")
- if "--no-tests" not in sys.argv:
- sys.argv.append("--no-tests")
-
- print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
os.environ['COMMANDLINE_ARGS'] = ""
- with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
- proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
-
- import test.server_poll
- exitcode = test.server_poll.run_tests(proc, test_dir)
-
- print(f"Stopping Web UI process with id {proc.pid}")
- proc.kill()
- return exitcode
def start():
@@ -351,6 +334,15 @@ def start(): webui.webui()
-if __name__ == "__main__":
- prepare_environment()
+def main():
+ if not args.skip_prepare_environment:
+ prepare_environment()
+
+ if args.test_server:
+ configure_for_tests()
+
start()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 7bde161e..8625690b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -12,8 +12,8 @@ parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch. parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
-parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
-parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
+parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 94347275..34a3ba63 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -22,6 +22,15 @@ def register_default_extra_networks(): class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
+ self.positional = []
+ self.named = {}
+
+ for item in self.items:
+ parts = item.split('=', 2)
+ if len(parts) == 2:
+ self.named[parts[0]] = parts[1]
+ else:
+ self.positional.append(item)
class ExtraNetwork:
diff --git a/modules/hashes.py b/modules/hashes.py index 032120f4..8b7ea0ac 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -46,8 +46,8 @@ def calculate_sha256(filename): return hash_sha256.hexdigest()
-def sha256_from_cache(filename, title):
- hashes = cache("hashes")
+def sha256_from_cache(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
ondisk_mtime = os.path.getmtime(filename)
if title not in hashes:
@@ -62,10 +62,10 @@ def sha256_from_cache(filename, title): return cached_sha256
-def sha256(filename, title):
- hashes = cache("hashes")
+def sha256(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
- sha256_value = sha256_from_cache(filename, title)
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
if sha256_value is not None:
return sha256_value
@@ -73,7 +73,11 @@ def sha256(filename, title): return None
print(f"Calculating sha256 for {filename}: ", end='')
- sha256_value = calculate_sha256(filename)
+ if use_addnet_hash:
+ with open(filename, "rb") as file:
+ sha256_value = addnet_hash_safetensors(file)
+ else:
+ sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {
@@ -86,6 +90,19 @@ def sha256(filename, title): return sha256_value
+def addnet_hash_safetensors(b):
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
diff --git a/modules/shared.py b/modules/shared.py index fa080458..e53b1e11 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -15,6 +15,7 @@ import modules.devices as devices from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import Optional
demo = None
@@ -113,7 +114,7 @@ class State: time_start = None
server_start = None
_server_command_signal = threading.Event()
- _server_command: str | None = None
+ _server_command: Optional[str] = None
@property
def need_restart(self) -> bool:
@@ -131,14 +132,14 @@ class State: return self._server_command
@server_command.setter
- def server_command(self, value: str | None) -> None:
+ def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
- def wait_for_server_command(self, timeout: float | None = None) -> str | None:
+ def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
diff --git a/modules/ui.py b/modules/ui.py index beccf78b..82820ab5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -463,8 +463,8 @@ def create_ui(): elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="txt2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
@@ -792,8 +792,8 @@ def create_ui(): with gr.Tab(label="Resize to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
- width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
- height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
+ width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
+ height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
@@ -1183,8 +1183,8 @@ def create_ui(): with gr.Tab(label="Preprocess images", id="preprocess_images"):
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
- process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
- process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
+ process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_process_width")
+ process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_process_height")
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
with gr.Row():
@@ -1276,8 +1276,8 @@ def create_ui(): template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names())
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file")
- training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width")
- training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height")
+ training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="train_training_width")
+ training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="train_training_height")
varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize")
steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps")
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 728fec9e..0052a5cc 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -55,7 +55,7 @@ class UiLoadsave: if field == 'value' and key not in self.component_mapping:
self.component_mapping[key] = x
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@@ -109,6 +109,8 @@ class UiLoadsave: self.add_block(c, path)
elif x.label is not None:
self.add_component(f"{path}/{x.label}", x)
+ elif isinstance(x, gr.Button) and x.value is not None:
+ self.add_component(f"{path}/{x.value}", x)
def read_from_file(self):
with open(self.filename, "r", encoding="utf8") as file:
diff --git a/pyproject.toml b/pyproject.toml index d4a1bbf4..80541a8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,3 +30,6 @@ ignore = [ [tool.ruff.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] + +[tool.pytest.ini_options] +base_url = "http://127.0.0.1:7860" diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 00000000..37838ca2 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,3 @@ +pytest-base-url~=2.0 +pytest-cov~=4.0 +pytest~=7.3 diff --git a/test/basic_features/__init__.py b/test/basic_features/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/test/basic_features/__init__.py +++ /dev/null diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py deleted file mode 100644 index 8ed98747..00000000 --- a/test/basic_features/extras_test.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -import unittest -import requests -from gradio.processing_utils import encode_pil_to_base64 -from PIL import Image -from modules.paths import script_path - -class TestExtrasWorking(unittest.TestCase): - def setUp(self): - self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image" - self.extras_single = { - "resize_mode": 0, - "show_extras_results": True, - "gfpgan_visibility": 0, - "codeformer_visibility": 0, - "codeformer_weight": 0, - "upscaling_resize": 2, - "upscaling_resize_w": 128, - "upscaling_resize_h": 128, - "upscaling_crop": True, - "upscaler_1": "None", - "upscaler_2": "None", - "extras_upscaler_2_visibility": 0, - "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) - } - - def test_simple_upscaling_performed(self): - self.extras_single["upscaler_1"] = "Lanczos" - self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200) - - -class TestPngInfoWorking(unittest.TestCase): - def setUp(self): - self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" - self.png_info = { - "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) - } - - def test_png_info_performed(self): - self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200) - - -class TestInterrogateWorking(unittest.TestCase): - def setUp(self): - self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" - self.interrogate = { - "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))), - "model": "clip" - } - - def test_interrogate_performed(self): - self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py deleted file mode 100644 index 5240ec36..00000000 --- a/test/basic_features/img2img_test.py +++ /dev/null @@ -1,68 +0,0 @@ -import os -import unittest -import requests -from gradio.processing_utils import encode_pil_to_base64 -from PIL import Image -from modules.paths import script_path - - -class TestImg2ImgWorking(unittest.TestCase): - def setUp(self): - self.url_img2img = "http://localhost:7860/sdapi/v1/img2img" - self.simple_img2img = { - "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))], - "resize_mode": 0, - "denoising_strength": 0.75, - "mask": None, - "mask_blur": 4, - "inpainting_fill": 0, - "inpaint_full_res": False, - "inpaint_full_res_padding": 0, - "inpainting_mask_invert": False, - "prompt": "example prompt", - "styles": [], - "seed": -1, - "subseed": -1, - "subseed_strength": 0, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "batch_size": 1, - "n_iter": 1, - "steps": 3, - "cfg_scale": 7, - "width": 64, - "height": 64, - "restore_faces": False, - "tiling": False, - "negative_prompt": "", - "eta": 0, - "s_churn": 0, - "s_tmax": 0, - "s_tmin": 0, - "s_noise": 1, - "override_settings": {}, - "sampler_index": "Euler a", - "include_init_images": False - } - - def test_img2img_simple_performed(self): - self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) - - def test_inpainting_masked_performed(self): - self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) - self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) - - def test_inpainting_with_inverted_masked_performed(self): - self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))) - self.simple_img2img["inpainting_mask_invert"] = True - self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) - - def test_img2img_sd_upscale_performed(self): - self.simple_img2img["script_name"] = "sd upscale" - self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] - - self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py deleted file mode 100644 index cb525fbb..00000000 --- a/test/basic_features/txt2img_test.py +++ /dev/null @@ -1,82 +0,0 @@ -import unittest -import requests - - -class TestTxt2ImgWorking(unittest.TestCase): - def setUp(self): - self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" - self.simple_txt2img = { - "enable_hr": False, - "denoising_strength": 0, - "firstphase_width": 0, - "firstphase_height": 0, - "prompt": "example prompt", - "styles": [], - "seed": -1, - "subseed": -1, - "subseed_strength": 0, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "batch_size": 1, - "n_iter": 1, - "steps": 3, - "cfg_scale": 7, - "width": 64, - "height": 64, - "restore_faces": False, - "tiling": False, - "negative_prompt": "", - "eta": 0, - "s_churn": 0, - "s_tmax": 0, - "s_tmin": 0, - "s_noise": 1, - "sampler_index": "Euler a" - } - - def test_txt2img_simple_performed(self): - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_with_negative_prompt_performed(self): - self.simple_txt2img["negative_prompt"] = "example negative prompt" - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_with_complex_prompt_performed(self): - self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_not_square_image_performed(self): - self.simple_txt2img["height"] = 128 - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_with_hrfix_performed(self): - self.simple_txt2img["enable_hr"] = True - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_with_tiling_performed(self): - self.simple_txt2img["tiling"] = True - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_with_restore_faces_performed(self): - self.simple_txt2img["restore_faces"] = True - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_with_vanilla_sampler_performed(self): - self.simple_txt2img["sampler_index"] = "PLMS" - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - self.simple_txt2img["sampler_index"] = "DDIM" - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - self.simple_txt2img["sampler_index"] = "UniPC" - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_multiple_batches_performed(self): - self.simple_txt2img["n_iter"] = 2 - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - def test_txt2img_batch_performed(self): - self.simple_txt2img["batch_size"] = 2 - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py deleted file mode 100644 index d9e46b5e..00000000 --- a/test/basic_features/utils_test.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest -import requests - - -class UtilsTests(unittest.TestCase): - def setUp(self): - self.url_options = "http://localhost:7860/sdapi/v1/options" - self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" - self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" - self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" - self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" - self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" - self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" - self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" - self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" - self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" - - def test_options_get(self): - self.assertEqual(requests.get(self.url_options).status_code, 200) - - def test_options_write(self): - response = requests.get(self.url_options) - self.assertEqual(response.status_code, 200) - - pre_value = response.json()["send_seed"] - - self.assertEqual(requests.post(self.url_options, json={"send_seed": not pre_value}).status_code, 200) - - response = requests.get(self.url_options) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json()["send_seed"], not pre_value) - - requests.post(self.url_options, json={"send_seed": pre_value}) - - def test_cmd_flags(self): - self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) - - def test_samplers(self): - self.assertEqual(requests.get(self.url_samplers).status_code, 200) - - def test_upscalers(self): - self.assertEqual(requests.get(self.url_upscalers).status_code, 200) - - def test_sd_models(self): - self.assertEqual(requests.get(self.url_sd_models).status_code, 200) - - def test_hypernetworks(self): - self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200) - - def test_face_restorers(self): - self.assertEqual(requests.get(self.url_face_restorers).status_code, 200) - - def test_realesrgan_models(self): - self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200) - - def test_prompt_styles(self): - self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200) - - def test_embeddings(self): - self.assertEqual(requests.get(self.url_embeddings).status_code, 200) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..0723f62a --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,17 @@ +import os + +import pytest +from PIL import Image +from gradio.processing_utils import encode_pil_to_base64 + +test_files_path = os.path.dirname(__file__) + "/test_files" + + +@pytest.fixture(scope="session") # session so we don't read this over and over +def img2img_basic_image_base64() -> str: + return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png"))) + + +@pytest.fixture(scope="session") # session so we don't read this over and over +def mask_basic_image_base64() -> str: + return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "mask_basic.png"))) diff --git a/test/server_poll.py b/test/server_poll.py deleted file mode 100644 index c732630f..00000000 --- a/test/server_poll.py +++ /dev/null @@ -1,26 +0,0 @@ -import unittest -import requests -import time -import os -from modules.paths import script_path - - -def run_tests(proc, test_dir): - timeout_threshold = 240 - start_time = time.time() - while time.time()-start_time < timeout_threshold: - try: - requests.head("http://localhost:7860/") - break - except requests.exceptions.ConnectionError: - if proc.poll() is not None: - break - if proc.poll() is None: - if test_dir is None: - test_dir = os.path.join(script_path, "test") - suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir) - result = unittest.TextTestRunner(verbosity=2).run(suite) - return len(result.failures) + len(result.errors) - else: - print("Launch unsuccessful") - return 1 diff --git a/test/test_extras.py b/test/test_extras.py new file mode 100644 index 00000000..799d9fad --- /dev/null +++ b/test/test_extras.py @@ -0,0 +1,35 @@ +import requests + + +def test_simple_upscaling_performed(base_url, img2img_basic_image_base64): + payload = { + "resize_mode": 0, + "show_extras_results": True, + "gfpgan_visibility": 0, + "codeformer_visibility": 0, + "codeformer_weight": 0, + "upscaling_resize": 2, + "upscaling_resize_w": 128, + "upscaling_resize_h": 128, + "upscaling_crop": True, + "upscaler_1": "Lanczos", + "upscaler_2": "None", + "extras_upscaler_2_visibility": 0, + "image": img2img_basic_image_base64, + } + assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 + + +def test_png_info_performed(base_url, img2img_basic_image_base64): + payload = { + "image": img2img_basic_image_base64, + } + assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 + + +def test_interrogate_performed(base_url, img2img_basic_image_base64): + payload = { + "image": img2img_basic_image_base64, + "model": "clip", + } + assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 diff --git a/test/test_img2img.py b/test/test_img2img.py new file mode 100644 index 00000000..117d2d1e --- /dev/null +++ b/test/test_img2img.py @@ -0,0 +1,68 @@ + +import pytest +import requests + + +@pytest.fixture() +def url_img2img(base_url): + return f"{base_url}/sdapi/v1/img2img" + + +@pytest.fixture() +def simple_img2img_request(img2img_basic_image_base64): + return { + "batch_size": 1, + "cfg_scale": 7, + "denoising_strength": 0.75, + "eta": 0, + "height": 64, + "include_init_images": False, + "init_images": [img2img_basic_image_base64], + "inpaint_full_res": False, + "inpaint_full_res_padding": 0, + "inpainting_fill": 0, + "inpainting_mask_invert": False, + "mask": None, + "mask_blur": 4, + "n_iter": 1, + "negative_prompt": "", + "override_settings": {}, + "prompt": "example prompt", + "resize_mode": 0, + "restore_faces": False, + "s_churn": 0, + "s_noise": 1, + "s_tmax": 0, + "s_tmin": 0, + "sampler_index": "Euler a", + "seed": -1, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "steps": 3, + "styles": [], + "subseed": -1, + "subseed_strength": 0, + "tiling": False, + "width": 64, + } + + +def test_img2img_simple_performed(url_img2img, simple_img2img_request): + assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 + + +def test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64): + simple_img2img_request["mask"] = mask_basic_image_base64 + assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 + + +def test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64): + simple_img2img_request["mask"] = mask_basic_image_base64 + simple_img2img_request["inpainting_mask_invert"] = True + assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 + + +def test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request): + simple_img2img_request["script_name"] = "sd upscale" + simple_img2img_request["script_args"] = ["", 8, "Lanczos", 2.0] + assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 diff --git a/test/test_txt2img.py b/test/test_txt2img.py new file mode 100644 index 00000000..6eb94f0a --- /dev/null +++ b/test/test_txt2img.py @@ -0,0 +1,90 @@ + +import pytest +import requests + + +@pytest.fixture() +def url_txt2img(base_url): + return f"{base_url}/sdapi/v1/txt2img" + + +@pytest.fixture() +def simple_txt2img_request(): + return { + "batch_size": 1, + "cfg_scale": 7, + "denoising_strength": 0, + "enable_hr": False, + "eta": 0, + "firstphase_height": 0, + "firstphase_width": 0, + "height": 64, + "n_iter": 1, + "negative_prompt": "", + "prompt": "example prompt", + "restore_faces": False, + "s_churn": 0, + "s_noise": 1, + "s_tmax": 0, + "s_tmin": 0, + "sampler_index": "Euler a", + "seed": -1, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "steps": 3, + "styles": [], + "subseed": -1, + "subseed_strength": 0, + "tiling": False, + "width": 64, + } + + +def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request): + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_with_negative_prompt_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["negative_prompt"] = "example negative prompt" + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_with_complex_prompt_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_not_square_image_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["height"] = 128 + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_with_hrfix_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["enable_hr"] = True + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_with_tiling_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["tiling"] = True + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_with_restore_faces_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["restore_faces"] = True + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +@pytest.mark.parametrize("sampler", ["PLMS", "DDIM", "UniPC"]) +def test_txt2img_with_vanilla_sampler_performed(url_txt2img, simple_txt2img_request, sampler): + simple_txt2img_request["sampler_index"] = sampler + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_multiple_batches_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["n_iter"] = 2 + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 + + +def test_txt2img_batch_performed(url_txt2img, simple_txt2img_request): + simple_txt2img_request["batch_size"] = 2 + assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200 diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..edba0b18 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,33 @@ +import pytest +import requests + + +def test_options_write(base_url): + url_options = f"{base_url}/sdapi/v1/options" + response = requests.get(url_options) + assert response.status_code == 200 + + pre_value = response.json()["send_seed"] + + assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200 + + response = requests.get(url_options) + assert response.status_code == 200 + assert response.json()['send_seed'] == (not pre_value) + + requests.post(url_options, json={"send_seed": pre_value}) + + +@pytest.mark.parametrize("url", [ + "sdapi/v1/cmd-flags", + "sdapi/v1/samplers", + "sdapi/v1/upscalers", + "sdapi/v1/sd-models", + "sdapi/v1/hypernetworks", + "sdapi/v1/face-restorers", + "sdapi/v1/realesrgan-models", + "sdapi/v1/prompt-styles", + "sdapi/v1/embeddings", +]) +def test_get_api_url(base_url, url): + assert requests.get(f"{base_url}/{url}").status_code == 200 |