aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.eslintignore4
-rw-r--r--.eslintrc.js91
-rw-r--r--.git-blame-ignore-revs2
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.yml42
-rw-r--r--.github/pull_request_template.md33
-rw-r--r--.github/workflows/on_pull_request.yaml49
-rw-r--r--.github/workflows/run_tests.yaml53
-rw-r--r--.gitignore3
-rw-r--r--CHANGELOG.md215
-rw-r--r--README.md9
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py13
-rw-r--r--extensions-builtin/LDSR/scripts/ldsr_model.py29
-rw-r--r--extensions-builtin/LDSR/sd_hijack_autoencoder.py33
-rw-r--r--extensions-builtin/LDSR/sd_hijack_ddpm_v1.py70
-rw-r--r--extensions-builtin/LDSR/vqvae_quantize.py147
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py22
-rw-r--r--extensions-builtin/Lora/lora.py90
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py38
-rw-r--r--extensions-builtin/Lora/ui_extra_networks_lora.py9
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py48
-rw-r--r--extensions-builtin/ScuNET/scunet_model_arch.py11
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py64
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch.py6
-rw-r--r--extensions-builtin/SwinIR/swinir_model_arch_v2.py58
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js776
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py14
-rw-r--r--extensions-builtin/canvas-zoom-and-pan/style.css63
-rw-r--r--extensions-builtin/extra-options-section/scripts/extra_options_section.py48
-rw-r--r--extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js52
-rw-r--r--html/extra-networks-card.html7
-rw-r--r--html/footer.html4
-rw-r--r--html/licenses.html26
-rw-r--r--javascript/aspectRatioOverlay.js224
-rw-r--r--javascript/contextMenus.js342
-rw-r--r--javascript/dragdrop.js99
-rw-r--r--javascript/edit-attention.js241
-rw-r--r--javascript/edit-order.js41
-rw-r--r--javascript/extensions.js163
-rw-r--r--javascript/extraNetworks.js461
-rw-r--r--javascript/generationParams.js48
-rw-r--r--javascript/hints.js137
-rw-r--r--javascript/hires_fix.js36
-rw-r--r--javascript/imageMaskFix.js21
-rw-r--r--javascript/imageParams.js18
-rw-r--r--javascript/imageviewer.js228
-rw-r--r--javascript/imageviewerGamepad.js10
-rw-r--r--javascript/localization.js353
-rw-r--r--javascript/notification.js12
-rw-r--r--javascript/profilerVisualization.js153
-rw-r--r--javascript/progressbar.js184
-rw-r--r--javascript/textualInversion.js34
-rw-r--r--javascript/token-counters.js83
-rw-r--r--javascript/ui.js444
-rw-r--r--javascript/ui_settings_hints.js103
-rw-r--r--launch.py384
-rw-r--r--modules/Roboto-Regular.ttfbin0 -> 305608 bytes
-rw-r--r--modules/api/api.py298
-rw-r--r--modules/api/models.py37
-rw-r--r--modules/call_queue.py30
-rw-r--r--modules/cmd_args.py30
-rw-r--r--modules/codeformer/codeformer_arch.py24
-rw-r--r--modules/codeformer/vqgan_arch.py44
-rw-r--r--modules/codeformer_model.py25
-rw-r--r--modules/config_states.py19
-rw-r--r--modules/deepbooru.py3
-rw-r--r--modules/devices.py27
-rw-r--r--modules/errors.py46
-rw-r--r--modules/esrgan_model.py35
-rw-r--r--modules/esrgan_model_arch.py7
-rw-r--r--modules/extensions.py48
-rw-r--r--modules/extra_networks.py22
-rw-r--r--modules/extra_networks_hypernet.py6
-rw-r--r--modules/extras.py21
-rw-r--r--modules/generation_parameters_copypaste.py115
-rw-r--r--modules/gfpgan_model.py16
-rw-r--r--modules/gitpython_hack.py42
-rw-r--r--modules/hashes.py29
-rw-r--r--modules/hypernetworks/hypernetwork.py68
-rw-r--r--modules/hypernetworks/ui.py6
-rw-r--r--modules/images.py190
-rw-r--r--modules/img2img.py82
-rw-r--r--modules/interrogate.py14
-rw-r--r--modules/launch_utils.py344
-rw-r--r--modules/localization.py6
-rw-r--r--modules/lowvram.py6
-rw-r--r--modules/mac_specific.py28
-rw-r--r--modules/masking.py2
-rw-r--r--modules/modelloader.py48
-rw-r--r--modules/models/diffusion/ddpm_edit.py56
-rw-r--r--modules/models/diffusion/uni_pc/__init__.py2
-rw-r--r--modules/models/diffusion/uni_pc/sampler.py3
-rw-r--r--modules/models/diffusion/uni_pc/uni_pc.py82
-rw-r--r--modules/ngrok.py38
-rw-r--r--modules/paths.py19
-rw-r--r--modules/paths_internal.py12
-rw-r--r--modules/postprocessing.py7
-rw-r--r--modules/processing.py359
-rw-r--r--modules/progress.py15
-rw-r--r--modules/prompt_parser.py33
-rw-r--r--modules/realesrgan_model.py53
-rw-r--r--modules/restart.py23
-rw-r--r--modules/safe.py33
-rw-r--r--modules/script_callbacks.py102
-rw-r--r--modules/script_loading.py8
-rw-r--r--modules/scripts.py225
-rw-r--r--modules/scripts_auto_postprocessing.py2
-rw-r--r--modules/scripts_postprocessing.py8
-rw-r--r--modules/sd_disable_initialization.py2
-rw-r--r--modules/sd_hijack.py129
-rw-r--r--modules/sd_hijack_clip.py4
-rw-r--r--modules/sd_hijack_clip_old.py2
-rw-r--r--modules/sd_hijack_inpainting.py10
-rw-r--r--modules/sd_hijack_ip2p.py7
-rw-r--r--modules/sd_hijack_optimizations.py180
-rw-r--r--modules/sd_hijack_xlmr.py2
-rw-r--r--modules/sd_models.py89
-rw-r--r--modules/sd_models_config.py1
-rw-r--r--modules/sd_samplers.py10
-rw-r--r--modules/sd_samplers_common.py35
-rw-r--r--modules/sd_samplers_compvis.py12
-rw-r--r--modules/sd_samplers_kdiffusion.py119
-rw-r--r--modules/sd_unet.py92
-rw-r--r--modules/sd_vae.py11
-rw-r--r--modules/sd_vae_taesd.py88
-rw-r--r--modules/shared.py283
-rw-r--r--modules/shared_items.py46
-rw-r--r--modules/styles.py78
-rw-r--r--modules/sub_quadratic_attention.py17
-rw-r--r--modules/sysinfo.py162
-rw-r--r--modules/textual_inversion/autocrop.py221
-rw-r--r--modules/textual_inversion/dataset.py6
-rw-r--r--modules/textual_inversion/image_embedding.py28
-rw-r--r--modules/textual_inversion/learn_schedule.py6
-rw-r--r--modules/textual_inversion/logging.py48
-rw-r--r--modules/textual_inversion/preprocess.py12
-rw-r--r--modules/textual_inversion/textual_inversion.py69
-rw-r--r--modules/timer.py46
-rw-r--r--modules/txt2img.py17
-rw-r--r--modules/ui.py638
-rw-r--r--modules/ui_common.py32
-rw-r--r--modules/ui_extensions.py109
-rw-r--r--modules/ui_extra_networks.py99
-rw-r--r--modules/ui_extra_networks_checkpoints.py4
-rw-r--r--modules/ui_extra_networks_hypernets.py4
-rw-r--r--modules/ui_extra_networks_textual_inversion.py4
-rw-r--r--modules/ui_gradio_extensions.py69
-rw-r--r--modules/ui_loadsave.py210
-rw-r--r--modules/ui_postprocessing.py2
-rw-r--r--modules/ui_settings.py289
-rw-r--r--modules/ui_tempdir.py19
-rw-r--r--modules/upscaler.py15
-rw-r--r--modules/xlmr.py8
-rw-r--r--package.json11
-rw-r--r--pyproject.toml35
-rw-r--r--requirements-test.txt3
-rw-r--r--requirements.txt43
-rw-r--r--requirements_versions.txt42
-rw-r--r--script.js153
-rw-r--r--scripts/custom_code.py2
-rw-r--r--scripts/img2imgalt.py14
-rw-r--r--scripts/loopback.py8
-rw-r--r--scripts/outpainting_mk_2.py36
-rw-r--r--scripts/poor_mans_outpainting.py6
-rw-r--r--scripts/postprocessing_upscale.py6
-rw-r--r--scripts/prompt_matrix.py9
-rw-r--r--scripts/prompts_from_file.py19
-rw-r--r--scripts/sd_upscale.py10
-rw-r--r--scripts/xyz_grid.py23
-rw-r--r--style.css122
-rw-r--r--test/basic_features/__init__.py0
-rw-r--r--test/basic_features/extras_test.py56
-rw-r--r--test/basic_features/img2img_test.py68
-rw-r--r--test/basic_features/txt2img_test.py82
-rw-r--r--test/basic_features/utils_test.py62
-rw-r--r--test/conftest.py17
-rw-r--r--test/server_poll.py26
-rw-r--r--test/test_extras.py35
-rw-r--r--test/test_img2img.py68
-rw-r--r--test/test_txt2img.py90
-rw-r--r--test/test_utils.py33
-rw-r--r--webui-macos-env.sh2
-rw-r--r--webui-user.sh1
-rw-r--r--webui.bat4
-rw-r--r--webui.py378
-rwxr-xr-xwebui.sh86
185 files changed, 8850 insertions, 4780 deletions
diff --git a/.eslintignore b/.eslintignore
new file mode 100644
index 00000000..1cfd9487
--- /dev/null
+++ b/.eslintignore
@@ -0,0 +1,4 @@
+extensions
+extensions-disabled
+repositories
+venv \ No newline at end of file
diff --git a/.eslintrc.js b/.eslintrc.js
new file mode 100644
index 00000000..f33aca09
--- /dev/null
+++ b/.eslintrc.js
@@ -0,0 +1,91 @@
+/* global module */
+module.exports = {
+ env: {
+ browser: true,
+ es2021: true,
+ },
+ extends: "eslint:recommended",
+ parserOptions: {
+ ecmaVersion: "latest",
+ },
+ rules: {
+ "arrow-spacing": "error",
+ "block-spacing": "error",
+ "brace-style": "error",
+ "comma-dangle": ["error", "only-multiline"],
+ "comma-spacing": "error",
+ "comma-style": ["error", "last"],
+ "curly": ["error", "multi-line", "consistent"],
+ "eol-last": "error",
+ "func-call-spacing": "error",
+ "function-call-argument-newline": ["error", "consistent"],
+ "function-paren-newline": ["error", "consistent"],
+ "indent": ["error", 4],
+ "key-spacing": "error",
+ "keyword-spacing": "error",
+ "linebreak-style": ["error", "unix"],
+ "no-extra-semi": "error",
+ "no-mixed-spaces-and-tabs": "error",
+ "no-multi-spaces": "error",
+ "no-redeclare": ["error", {builtinGlobals: false}],
+ "no-trailing-spaces": "error",
+ "no-unused-vars": "off",
+ "no-whitespace-before-property": "error",
+ "object-curly-newline": ["error", {consistent: true, multiline: true}],
+ "object-curly-spacing": ["error", "never"],
+ "operator-linebreak": ["error", "after"],
+ "quote-props": ["error", "consistent-as-needed"],
+ "semi": ["error", "always"],
+ "semi-spacing": "error",
+ "semi-style": ["error", "last"],
+ "space-before-blocks": "error",
+ "space-before-function-paren": ["error", "never"],
+ "space-in-parens": ["error", "never"],
+ "space-infix-ops": "error",
+ "space-unary-ops": "error",
+ "switch-colon-spacing": "error",
+ "template-curly-spacing": ["error", "never"],
+ "unicode-bom": "error",
+ },
+ globals: {
+ //script.js
+ gradioApp: "readonly",
+ executeCallbacks: "readonly",
+ onAfterUiUpdate: "readonly",
+ onOptionsChanged: "readonly",
+ onUiLoaded: "readonly",
+ onUiUpdate: "readonly",
+ uiCurrentTab: "writable",
+ uiElementInSight: "readonly",
+ uiElementIsVisible: "readonly",
+ //ui.js
+ opts: "writable",
+ all_gallery_buttons: "readonly",
+ selected_gallery_button: "readonly",
+ selected_gallery_index: "readonly",
+ switch_to_txt2img: "readonly",
+ switch_to_img2img_tab: "readonly",
+ switch_to_img2img: "readonly",
+ switch_to_sketch: "readonly",
+ switch_to_inpaint: "readonly",
+ switch_to_inpaint_sketch: "readonly",
+ switch_to_extras: "readonly",
+ get_tab_index: "readonly",
+ create_submit_args: "readonly",
+ restart_reload: "readonly",
+ updateInput: "readonly",
+ //extraNetworks.js
+ requestGet: "readonly",
+ popup: "readonly",
+ // from python
+ localization: "readonly",
+ // progrssbar.js
+ randomId: "readonly",
+ requestProgress: "readonly",
+ // imageviewer.js
+ modalPrevImage: "readonly",
+ modalNextImage: "readonly",
+ // token-counters.js
+ setupTokenCounters: "readonly",
+ }
+};
diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 00000000..4104da63
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Apply ESlint
+9c54b78d9dde5601e916f308d9a9d6953ec39430 \ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index 7d435297..d80b24e2 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -43,11 +43,20 @@ body:
- type: input
id: commit
attributes:
- label: Commit where the problem happens
- description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
+ label: Version or Commit where the problem happens
+ description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
validations:
required: true
- type: dropdown
+ id: py-version
+ attributes:
+ label: What Python version are you running on ?
+ multiple: false
+ options:
+ - Python 3.10.x
+ - Python 3.11.x (above, no supported yet)
+ - Python 3.9.x (below, no recommended)
+ - type: dropdown
id: platforms
attributes:
label: What platforms do you use to access the UI ?
@@ -60,6 +69,35 @@ body:
- Android
- Other/Cloud
- type: dropdown
+ id: device
+ attributes:
+ label: What device are you running WebUI on?
+ multiple: true
+ options:
+ - Nvidia GPUs (RTX 20 above)
+ - Nvidia GPUs (GTX 16 below)
+ - AMD GPUs (RX 6000 above)
+ - AMD GPUs (RX 5000 below)
+ - CPU
+ - Other GPUs
+ - type: dropdown
+ id: cross_attention_opt
+ attributes:
+ label: Cross attention optimization
+ description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization
+ multiple: false
+ options:
+ - Automatic
+ - xformers
+ - sdp-no-mem
+ - sdp
+ - Doggettx
+ - V1
+ - InvokeAI
+ - "None "
+ validations:
+ required: true
+ - type: dropdown
id: browsers
attributes:
label: What browsers do you use to access the UI ?
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 69056331..c9fcda2e 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,28 +1,15 @@
-# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request!
+## Description
-If you have a large change, pay special attention to this paragraph:
+* a simple description of what you're trying to accomplish
+* a summary of changes in code
+* which issues it fixes, if any
-> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature.
+## Screenshots/videos:
-Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on.
-**Describe what this pull request is trying to achieve.**
+## Checklist:
-A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code.
-
-**Additional notes and description of your changes**
-
-More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of.
-
-**Environment this was tested in**
-
-List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box.
- - OS: [e.g. Windows, Linux]
- - Browser: [e.g. chrome, safari]
- - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB]
-
-**Screenshots or videos of your changes**
-
-If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made.
-
-This is **required** for anything that touches the user interface. \ No newline at end of file
+- [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
+- [ ] I have performed a self-review of my own code
+- [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style)
+- [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests)
diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml
index a168be5b..8ebf5918 100644
--- a/.github/workflows/on_pull_request.yaml
+++ b/.github/workflows/on_pull_request.yaml
@@ -1,39 +1,34 @@
-# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
name: Run Linting/Formatting on Pull Requests
on:
- push
- pull_request
- # See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
- # if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
- # pull_request:
- # branches:
- # - master
- # branches-ignore:
- # - development
jobs:
- lint:
+ lint-python:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- - name: Set up Python 3.10
- uses: actions/setup-python@v4
+ - uses: actions/setup-python@v4
with:
- python-version: 3.10.6
- cache: pip
- cache-dependency-path: |
- **/requirements*txt
- - name: Install PyLint
- run: |
- python -m pip install --upgrade pip
- pip install pylint
- # This lets PyLint check to see if it can resolve imports
- - name: Install dependencies
- run: |
- export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
- python launch.py
- - name: Analysing the code with pylint
- run: |
- pylint $(git ls-files '*.py')
+ python-version: 3.11
+ # NB: there's no cache: pip here since we're not installing anything
+ # from the requirements.txt file(s) in the repository; it's faster
+ # not to have GHA download an (at the time of writing) 4 GB cache
+ # of PyTorch and other dependencies.
+ - name: Install Ruff
+ run: pip install ruff==0.0.272
+ - name: Run Ruff
+ run: ruff .
+ lint-js:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v3
+ - name: Install Node.js
+ uses: actions/setup-node@v3
+ with:
+ node-version: 18
+ - run: npm i --ci
+ - run: npm run lint
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index 9a0b8d22..178c026a 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -17,13 +17,54 @@ jobs:
cache: pip
cache-dependency-path: |
**/requirements*txt
+ launch.py
+ - name: Install test dependencies
+ run: pip install wait-for-it -r requirements-test.txt
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: "1"
+ PIP_PROGRESS_BAR: "off"
+ - name: Setup environment
+ run: python launch.py --skip-torch-cuda-test --exit
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: "1"
+ PIP_PROGRESS_BAR: "off"
+ TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
+ WEBUI_LAUNCH_LIVE_OUTPUT: "1"
+ PYTHONUNBUFFERED: "1"
+ - name: Start test server
+ run: >
+ python -m coverage run
+ --data-file=.coverage.server
+ launch.py
+ --skip-prepare-environment
+ --skip-torch-cuda-test
+ --test-server
+ --no-half
+ --disable-opt-split-attention
+ --use-cpu all
+ --api-server-stop
+ 2>&1 | tee output.txt &
- name: Run tests
- run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- - name: Upload main app stdout-stderr
+ run: |
+ wait-for-it --service 127.0.0.1:7860 -t 600
+ python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
+ - name: Kill test server
+ if: always()
+ run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
+ - name: Show coverage
+ run: |
+ python -m coverage combine .coverage*
+ python -m coverage report -i
+ python -m coverage html -i
+ - name: Upload main app output
+ uses: actions/upload-artifact@v3
+ if: always()
+ with:
+ name: output
+ path: output.txt
+ - name: Upload coverage HTML
uses: actions/upload-artifact@v3
if: always()
with:
- name: stdout-stderr
- path: |
- test/stdout.txt
- test/stderr.txt
+ name: htmlcov
+ path: htmlcov
diff --git a/.gitignore b/.gitignore
index 7328401f..09734267 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,6 @@ notification.mp3
/test/stderr.txt
/cache.json*
/config_states/
+/node_modules
+/package-lock.json
+/.coverage*
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8cf444ca..6a31f35b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,56 +1,193 @@
+## 1.4.0
+
+### Features:
+ * zoom controls for inpainting
+ * run basic torch calculation at startup in parallel to reduce the performance impact of first generation
+ * option to pad prompt/neg prompt to be same length
+ * remove taming_transformers dependency
+ * custom k-diffusion scheduler settings
+ * add an option to show selected settings in main txt2img/img2img UI
+ * sysinfo tab in settings
+ * infer styles from prompts when pasting params into the UI
+ * an option to control the behavior of the above
+
+### Minor:
+ * bump Gradio to 3.32.0
+ * bump xformers to 0.0.20
+ * Add option to disable token counters
+ * tooltip fixes & optimizations
+ * make it possible to configure filename for the zip download
+ * `[vae_filename]` pattern for filenames
+ * Revert discarding penultimate sigma for DPM-Solver++(2M) SDE
+ * change UI reorder setting to multiselect
+ * read version info form CHANGELOG.md if git version info is not available
+ * link footer API to Wiki when API is not active
+ * persistent conds cache (opt-in optimization)
+
+### Extensions:
+ * After installing extensions, webui properly restarts the process rather than reloads the UI
+ * Added VAE listing to web API. Via: /sdapi/v1/sd-vae
+ * custom unet support
+ * Add onAfterUiUpdate callback
+ * refactor EmbeddingDatabase.register_embedding() to allow unregistering
+ * add before_process callback for scripts
+ * add ability for alwayson scripts to specify section and let user reorder those sections
+
+### Bug Fixes:
+ * Fix dragging text to prompt
+ * fix incorrect quoting for infotext values with colon in them
+ * fix "hires. fix" prompt sharing same labels with txt2img_prompt
+ * Fix s_min_uncond default type int
+ * Fix for #10643 (Inpainting mask sometimes not working)
+ * fix bad styling for thumbs view in extra networks #10639
+ * fix for empty list of optimizations #10605
+ * small fixes to prepare_tcmalloc for Debian/Ubuntu compatibility
+ * fix --ui-debug-mode exit
+ * patch GitPython to not use leaky persistent processes
+ * fix duplicate Cross attention optimization after UI reload
+ * torch.cuda.is_available() check for SdOptimizationXformers
+ * fix hires fix using wrong conds in second pass if using Loras.
+ * handle exception when parsing generation parameters from png info
+ * fix upcast attention dtype error
+ * forcing Torch Version to 1.13.1 for RX 5000 series GPUs
+ * split mask blur into X and Y components, patch Outpainting MK2 accordingly
+ * don't die when a LoRA is a broken symlink
+ * allow activation of Generate Forever during generation
+
+
+## 1.3.2
+
+### Bug Fixes:
+ * fix files served out of tmp directory even if they are saved to disk
+ * fix postprocessing overwriting parameters
+
+## 1.3.1
+
+### Features:
+ * revert default cross attention optimization to Doggettx
+
+### Bug Fixes:
+ * fix bug: LoRA don't apply on dropdown list sd_lora
+ * fix png info always added even if setting is not enabled
+ * fix some fields not applying in xyz plot
+ * fix "hires. fix" prompt sharing same labels with txt2img_prompt
+ * fix lora hashes not being added properly to infotex if there is only one lora
+ * fix --use-cpu failing to work properly at startup
+ * make --disable-opt-split-attention command line option work again
+
+## 1.3.0
+
+### Features:
+ * add UI to edit defaults
+ * token merging (via dbolya/tomesd)
+ * settings tab rework: add a lot of additional explanations and links
+ * load extensions' Git metadata in parallel to loading the main program to save a ton of time during startup
+ * update extensions table: show branch, show date in separate column, and show version from tags if available
+ * TAESD - another option for cheap live previews
+ * allow choosing sampler and prompts for second pass of hires fix - hidden by default, enabled in settings
+ * calculate hashes for Lora
+ * add lora hashes to infotext
+ * when pasting infotext, use infotext's lora hashes to find local loras for `<lora:xxx:1>` entries whose hashes match loras the user has
+ * select cross attention optimization from UI
+
+### Minor:
+ * bump Gradio to 3.31.0
+ * bump PyTorch to 2.0.1 for macOS and Linux AMD
+ * allow setting defaults for elements in extensions' tabs
+ * allow selecting file type for live previews
+ * show "Loading..." for extra networks when displaying for the first time
+ * suppress ENSD infotext for samplers that don't use it
+ * clientside optimizations
+ * add options to show/hide hidden files and dirs in extra networks, and to not list models/files in hidden directories
+ * allow whitespace in styles.csv
+ * add option to reorder tabs
+ * move some functionality (swap resolution and set seed to -1) to client
+ * option to specify editor height for img2img
+ * button to copy image resolution into img2img width/height sliders
+ * switch from pyngrok to ngrok-py
+ * lazy-load images in extra networks UI
+ * set "Navigate image viewer with gamepad" option to false by default, by request
+ * change upscalers to download models into user-specified directory (from commandline args) rather than the default models/<...>
+ * allow hiding buttons in ui-config.json
+
+### Extensions:
+ * add /sdapi/v1/script-info api
+ * use Ruff to lint Python code
+ * use ESlint to lint Javascript code
+ * add/modify CFG callbacks for Self-Attention Guidance extension
+ * add command and endpoint for graceful server stopping
+ * add some locals (prompts/seeds/etc) from processing function into the Processing class as fields
+ * rework quoting for infotext items that have commas in them to use JSON (should be backwards compatible except for cases where it didn't work previously)
+ * add /sdapi/v1/refresh-loras api checkpoint post request
+ * tests overhaul
+
+### Bug Fixes:
+ * fix an issue preventing the program from starting if the user specifies a bad Gradio theme
+ * fix broken prompts from file script
+ * fix symlink scanning for extra networks
+ * fix --data-dir ignored when launching via webui-user.bat COMMANDLINE_ARGS
+ * allow web UI to be ran fully offline
+ * fix inability to run with --freeze-settings
+ * fix inability to merge checkpoint without adding metadata
+ * fix extra networks' save preview image not adding infotext for jpeg/webm
+ * remove blinking effect from text in hires fix and scale resolution preview
+ * make links to `http://<...>.git` extensions work in the extension tab
+ * fix bug with webui hanging at startup due to hanging git process
+
+
## 1.2.1
### Features:
- * add an option to always refer to lora by filenames
+ * add an option to always refer to LoRA by filenames
### Bug Fixes:
- * never refer to lora by an alias if multiple loras have same alias or the alias is called none
+ * never refer to LoRA by an alias if multiple LoRAs have same alias or the alias is called none
* fix upscalers disappearing after the user reloads UI
- * allow bf16 in safe unpickler (resolves problems with loading some loras)
+ * allow bf16 in safe unpickler (resolves problems with loading some LoRAs)
* allow web UI to be ran fully offline
* fix localizations not working
- * fix error for loras: 'LatentDiffusion' object has no attribute 'lora_layer_mapping'
+ * fix error for LoRAs: `'LatentDiffusion' object has no attribute 'lora_layer_mapping'`
## 1.2.0
### Features:
- * do not wait for stable diffusion model to load at startup
- * add filename patterns: [denoising]
- * directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
- * Lora: for the `<...>` text in prompt, use name of Lora that is in the metdata of the file, if present, instead of filename (both can be used to activate lora)
- * Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
- * Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
- * Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
+ * do not wait for Stable Diffusion model to load at startup
+ * add filename patterns: `[denoising]`
+ * directory hiding for extra networks: dirs starting with `.` will hide their cards on extra network tabs unless specifically searched for
+ * LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metdata of the file, if present, instead of filename (both can be used to activate LoRA)
+ * LoRA: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
+ * LoRA: fix some LoRAs not working (ones that have 3x3 convolution layer)
+ * LoRA: add an option to use old method of applying LoRAs (producing same results as with kohya-ss)
* add version to infotext, footer and console output when starting
* add links to wiki for filename pattern settings
* add extended info for quicksettings setting and use multiselect input instead of a text field
### Minor:
- * gradio bumped to 3.29.0
- * torch bumped to 2.0.1
- * --subpath option for gradio for use with reverse proxy
- * linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
- * possible frontend optimization: do not apply localizations if there are none
- * Add extra `None` option for VAE in XYZ plot
+ * bump Gradio to 3.29.0
+ * bump PyTorch to 2.0.1
+ * `--subpath` option for gradio for use with reverse proxy
+ * Linux/macOS: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
+ * do not apply localizations if there are none (possible frontend optimization)
+ * add extra `None` option for VAE in XYZ plot
* print error to console when batch processing in img2img fails
* create HTML for extra network pages only on demand
- * allow directories starting with . to still list their models for lora, checkpoints, etc
+ * allow directories starting with `.` to still list their models for LoRA, checkpoints, etc
* put infotext options into their own category in settings tab
* do not show licenses page when user selects Show all pages in settings
### Extensions:
- * Tooltip localization support
- * Add api method to get LoRA models with prompt
+ * tooltip localization support
+ * add API method to get LoRA models with prompt
### Bug Fixes:
- * re-add /docs endpoint
+ * re-add `/docs` endpoint
* fix gamepad navigation
* make the lightbox fullscreen image function properly
* fix squished thumbnails in extras tab
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
* fix webui showing the same image if you configure the generation to always save results into same file
* fix bug with upscalers not working properly
- * Fix MPS on PyTorch 2.0.1, Intel Macs
+ * fix MPS on PyTorch 2.0.1, Intel Macs
* make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
* prevent Reload UI button/link from reloading the page when it's not yet ready
* fix prompts from file script failing to read contents from a drag/drop file
@@ -58,20 +195,20 @@
## 1.1.1
### Bug Fixes:
- * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
+ * fix an error that prevents running webui on PyTorch<2.0 without --disable-safe-unpickle
## 1.1.0
### Features:
- * switch to torch 2.0.0 (except for AMD GPUs)
+ * switch to PyTorch 2.0.0 (except for AMD GPUs)
* visual improvements to custom code scripts
- * add filename patterns: [clip_skip], [hasprompt<>], [batch_number], [generation_number]
+ * add filename patterns: `[clip_skip]`, `[hasprompt<>]`, `[batch_number]`, `[generation_number]`
* add support for saving init images in img2img, and record their hashes in infotext for reproducability
* automatically select current word when adjusting weight with ctrl+up/down
* add dropdowns for X/Y/Z plot
- * setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
+ * add setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
* support Gradio's theme API
* use TCMalloc on Linux by default; possible fix for memory leaks
- * (optimization) option to remove negative conditioning at low sigma values #9177
+ * add optimization option to remove negative conditioning at low sigma values #9177
* embed model merge metadata in .safetensors file
* extension settings backup/restore feature #9169
* add "resize by" and "resize to" tabs to img2img
@@ -80,22 +217,22 @@
* button to restore the progress from session lost / tab reload
### Minor:
- * gradio bumped to 3.28.1
- * in extra tab, change extras "scale to" to sliders
+ * bump Gradio to 3.28.1
+ * change "scale to" to sliders in Extras tab
* add labels to tool buttons to make it possible to hide them
* add tiled inference support for ScuNET
* add branch support for extension installation
- * change linux installation script to insall into current directory rather than /home/username
- * sort textual inversion embeddings by name (case insensitive)
+ * change Linux installation script to install into current directory rather than `/home/username`
+ * sort textual inversion embeddings by name (case-insensitive)
* allow styles.csv to be symlinked or mounted in docker
* remove the "do not add watermark to images" option
* make selected tab configurable with UI config
- * extra networks UI in now fixed height and scrollable
- * add disable_tls_verify arg for use with self-signed certs
+ * make the extra networks UI fixed height and scrollable
+ * add `disable_tls_verify` arg for use with self-signed certs
### Extensions:
- * Add reload callback
- * add is_hr_pass field for processing
+ * add reload callback
+ * add `is_hr_pass` field for processing
### Bug Fixes:
* fix broken batch image processing on 'Extras/Batch Process' tab
@@ -111,10 +248,10 @@
* one broken image in img2img batch won't stop all processing
* fix image orientation bug in train/preprocess
* fix Ngrok recreating tunnels every reload
- * fix --realesrgan-models-path and --ldsr-models-path not working
- * fix --skip-install not working
- * outpainting Mk2 & Poorman should use the SAMPLE file format to save images, not GRID file format
- * do not fail all Loras if some have failed to load when making a picture
+ * fix `--realesrgan-models-path` and `--ldsr-models-path` not working
+ * fix `--skip-install` not working
+ * use SAMPLE file format in Outpainting Mk2 & Poorman
+ * do not fail all LoRAs if some have failed to load when making a picture
## 1.0.0
* everything
diff --git a/README.md b/README.md
index 67a1a83a..73d94960 100644
--- a/README.md
+++ b/README.md
@@ -15,7 +15,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- Attention, specify parts of text that the model should pay more attention to
- a man in a `((tuxedo))` - will pay more attention to tuxedo
- a man in a `(tuxedo:1.21)` - alternative syntax
- - select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
+ - select text and press `Ctrl+Up` or `Ctrl+Down` (or `Command+Up` or `Command+Down` if you're on a MacOS) to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
@@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
+### Installation on Windows 10/11 with NVidia-GPUs using release package
+1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
+2. Run `update.bat`.
+3. Run `run.bat`.
+> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
+
### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
@@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
+- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index bc11cc6e..7f450086 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -88,7 +88,7 @@ class LDSR:
x_t = None
logs = None
- for n in range(n_runs):
+ for _ in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps)
eta = 1.0
- down_sample_method = 'Lanczos'
gc.collect()
if torch.cuda.is_available:
@@ -131,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
-
+
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
-
+
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"]
@@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path):
- example = dict()
+ example = {}
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
- log = dict()
+ log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
- except:
+ except Exception:
pass
log["sample"] = x_sample
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index da19cff1..bd78dece 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,13 +1,11 @@
import os
-import sys
-import traceback
-
-from basicsr.utils.download_util import load_file_from_url
+from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
-from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+from modules import shared, script_callbacks, errors
+import sd_hijack_autoencoder # noqa: F401
+import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler):
@@ -44,22 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path
else:
- model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
-
- yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
+ model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
- try:
- return LDSR(model, yaml)
+ yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
- except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- return None
+ return LDSR(model, yaml)
def do_upscale(self, img, path):
- ldsr = self.load_model(path)
- if ldsr is None:
- print("NO LDSR!")
+ try:
+ ldsr = self.load_model(path)
+ except Exception:
+ errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
index 8e03c7f8..c29d274d 100644
--- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py
+++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py
@@ -1,16 +1,21 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
-
+import numpy as np
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
+from vqvae_quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config
import ldm.models.autoencoder
+from packaging import version
class VQModel(pl.LightningModule):
def __init__(self,
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed,
embed_dim,
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@@ -76,18 +81,19 @@ class VQModel(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list()):
+ def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
+ if missing:
print(f"Missing Keys: {missing}")
+ if unexpected:
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
@@ -165,7 +171,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +238,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
@@ -249,7 +255,8 @@ class VQModel(pl.LightningModule):
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ if x.shape[1] > 3:
+ xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
@@ -264,7 +271,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
@@ -282,5 +289,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant)
return dec
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
index 5c0488e5..04adc5eb 100644
--- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
+++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py
@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,22 +182,22 @@ class DDPMV1(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
- for ik in ignore_keys:
+ for ik in ignore_keys or []:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
+ if missing:
print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
+ if unexpected:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
+ except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
+ if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
+ assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
return img, intermediates
@torch.no_grad()
@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
if return_intermediates:
return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None
- log = dict()
+ log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows:
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint:
# make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img
return logs
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
diff --git a/extensions-builtin/LDSR/vqvae_quantize.py b/extensions-builtin/LDSR/vqvae_quantize.py
new file mode 100644
index 00000000..dd14b8fd
--- /dev/null
+++ b/extensions-builtin/LDSR/vqvae_quantize.py
@@ -0,0 +1,147 @@
+# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,
+# where the license is as follows:
+#
+# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+# OR OTHER DEALINGS IN THE SOFTWARE./
+
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
+ assert return_logits is False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index ccb249ac..66ee9c85 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -9,19 +9,37 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
- if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
multipliers = []
for params in params_list:
- assert len(params.items) > 0
+ assert params.items
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
lora.load_loras(names, multipliers)
+ if shared.opts.lora_add_hashes_to_infotext:
+ lora_hashes = []
+ for item in lora.loaded_loras:
+ shorthash = item.lora_on_disk.shorthash
+ if not shorthash:
+ continue
+
+ alias = item.mentioned_name
+ if not alias:
+ continue
+
+ alias = alias.replace(":", "").replace(",", "")
+
+ lora_hashes.append(f"{alias}: {shorthash}")
+
+ if lora_hashes:
+ p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+
def deactivate(self, p):
pass
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index b5d0c98f..cd46e6c7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,10 +1,9 @@
-import glob
import os
import re
import torch
from typing import Union
-from modules import shared, devices, sd_models, errors, scripts
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@@ -77,9 +76,9 @@ class LoraOnDisk:
self.name = name
self.filename = filename
self.metadata = {}
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
- _, ext = os.path.splitext(filename)
- if ext.lower() == ".safetensors":
+ if self.is_safetensors:
try:
self.metadata = sd_models.read_metadata_from_safetensors(filename)
except Exception as e:
@@ -95,14 +94,43 @@ class LoraOnDisk:
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
+ self.hash = None
+ self.shorthash = None
+ self.set_hash(
+ self.metadata.get('sshs_model_hash') or
+ hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
+ ''
+ )
+
+ def set_hash(self, v):
+ self.hash = v
+ self.shorthash = self.hash[0:12]
+
+ if self.shorthash:
+ available_lora_hash_lookup[self.shorthash] = self
+
+ def read_hash(self):
+ if not self.hash:
+ self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
+
+ def get_alias(self):
+ if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
+ return self.name
+ else:
+ return self.alias
+
class LoraModule:
- def __init__(self, name):
+ def __init__(self, name, lora_on_disk: LoraOnDisk):
self.name = name
+ self.lora_on_disk = lora_on_disk
self.multiplier = 1.0
self.modules = {}
self.mtime = None
+ self.mentioned_name = None
+ """the text that was used to add lora to prompt - can be either name or an alias"""
+
class LoraUpDownModule:
def __init__(self):
@@ -127,11 +155,11 @@ def assign_lora_names_to_compvis_modules(sd_model):
sd_model.lora_layer_mapping = lora_layer_mapping
-def load_lora(name, filename):
- lora = LoraModule(name)
- lora.mtime = os.path.getmtime(filename)
+def load_lora(name, lora_on_disk):
+ lora = LoraModule(name, lora_on_disk)
+ lora.mtime = os.path.getmtime(lora_on_disk.filename)
- sd = sd_models.read_state_dict(filename)
+ sd = sd_models.read_state_dict(lora_on_disk.filename)
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
@@ -177,7 +205,7 @@ def load_lora(name, filename):
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
@@ -189,10 +217,10 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
- if len(keys_failed_to_match) > 0:
- print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
+ if keys_failed_to_match:
+ print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
return lora
@@ -207,30 +235,41 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+ failed_to_load_loras = []
+
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
lora_on_disk = loras_on_disk[i]
+
if lora_on_disk is not None:
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
try:
- lora = load_lora(name, lora_on_disk.filename)
+ lora = load_lora(name, lora_on_disk)
except Exception as e:
errors.display(e, f"loading Lora {lora_on_disk.filename}")
continue
+ lora.mentioned_name = name
+
+ lora_on_disk.read_hash()
+
if lora is None:
+ failed_to_load_loras.append(name)
print(f"Couldn't find Lora with name {name}")
continue
lora.multiplier = multipliers[i] if multipliers else 1.0
loaded_loras.append(lora)
+ if failed_to_load_loras:
+ sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
+
def lora_calc_updown(lora, module, target):
with torch.no_grad():
@@ -314,7 +353,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
@@ -348,8 +387,8 @@ def lora_forward(module, input, original_forward):
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
@@ -398,17 +437,22 @@ def list_available_loras():
available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
- forbidden_lora_aliases.update({"none": 1})
+ available_lora_hash_lookup.clear()
+ forbidden_lora_aliases.update({"none": 1, "Addams": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
- for filename in sorted(candidates, key=str.lower):
+ for filename in candidates:
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
- entry = LoraOnDisk(name, filename)
+ try:
+ entry = LoraOnDisk(name, filename)
+ except OSError: # should catch FileNotFoundError and PermissionError etc.
+ errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
+ continue
available_loras[name] = entry
@@ -428,7 +472,7 @@ def infotext_pasted(infotext, params):
added = []
- for k, v in params.items():
+ for k in params:
if not k.startswith("AddNet Model "):
continue
@@ -452,8 +496,10 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
+
available_loras = {}
available_lora_aliases = {}
+available_lora_hash_lookup = {}
forbidden_lora_aliases = {}
loaded_loras = []
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 060bda05..e650f469 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -1,3 +1,5 @@
+import re
+
import torch
import gradio as gr
from fastapi import FastAPI
@@ -53,8 +55,9 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
- "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+ "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
+ "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
@@ -77,6 +80,37 @@ def api_loras(_: gr.Blocks, app: FastAPI):
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
+ @app.post("/sdapi/v1/refresh-loras")
+ async def refresh_loras():
+ return lora.list_available_loras()
+
script_callbacks.on_app_started(api_loras)
+re_lora = re.compile("<lora:([^:]+):")
+
+
+def infotext_pasted(infotext, d):
+ hashes = d.get("Lora hashes")
+ if not hashes:
+ return
+
+ hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
+ hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
+
+ def lora_replacement(m):
+ alias = m.group(1)
+ shorthash = hashes.get(alias)
+ if shorthash is None:
+ return m.group(0)
+
+ lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
+ if lora_on_disk is None:
+ return m.group(0)
+
+ return f'<lora:{lora_on_disk.get_alias()}:'
+
+ d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+
+
+script_callbacks.on_infotext_pasted(infotext_pasted)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index 2050e3fa..da49790b 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -13,13 +13,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
lora.list_available_loras()
def list_items(self):
- for name, lora_on_disk in lora.available_loras.items():
+ for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()):
path, ext = os.path.splitext(lora_on_disk.filename)
- if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
- alias = name
- else:
- alias = lora_on_disk.alias
+ alias = lora_on_disk.get_alias()
yield {
"name": name,
@@ -30,6 +27,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
+ "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+
}
def allowed_directories_for_previews(self):
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index c7fd5739..ffef26b2 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,19 +1,16 @@
-import os.path
import sys
-import traceback
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
-from basicsr.utils.download_util import load_file_from_url
-
import modules.upscaler
-from modules import devices, modelloader
-from scunet_model_arch import SCUNet as net
+from modules import devices, modelloader, script_callbacks, errors
+from scunet_model_arch import SCUNet
+
+from modules.modelloader import load_file_from_url
from modules.shared import opts
-from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -29,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = []
add_model2 = True
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -39,8 +36,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
- print(f"Error loading ScuNET model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading ScuNET model: {file}", exc_info=True)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
@@ -91,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache()
- model = self.load_model(selected_file)
- if model is None:
- print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+ try:
+ model = self.load_model(selected_file)
+ except Exception as e:
+ print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
@@ -121,20 +118,27 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str):
device = devices.get_device_for('scunet')
- if "http" in path:
- filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
- progress=True)
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
- if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
- print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
- return None
-
- model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+ model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
- for k, v in model.named_parameters():
+ for _, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
+
+
+def on_ui_settings():
+ import gradio as gr
+ from modules import shared
+
+ shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
+ shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
+
+
+script_callbacks.on_ui_settings(on_ui_settings)
diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py
index 43ca8d36..b51a8806 100644
--- a/extensions-builtin/ScuNET/scunet_model_arch.py
+++ b/extensions-builtin/ScuNET/scunet_model_arch.py
@@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns:
output: tensor shape [b h w c]
"""
- if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+ if self.type != 'W':
+ x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
@@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
- if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
- dims=(1, 2))
+ if self.type != 'W':
+ output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
+
return output
def relative_embedding(self):
@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0) \ No newline at end of file
+ nn.init.constant_(m.weight, 1.0)
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index e8783bca..c6bc53a8 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,18 +1,17 @@
-import contextlib
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from modules.shared import opts, state
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir')
@@ -20,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
self.name = "SwinIR"
- self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
- "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
- "-L_x4_GAN.pth "
+ self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x"
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
- if "http" in model:
+ if model.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(model)
@@ -38,42 +35,45 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers
def do_upscale(self, img, model_file):
- model = self.load_model(model_file)
- if model is None:
+ try:
+ model = self.load_model(model_file)
+ except Exception as e:
+ print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
- except:
+ except Exception:
pass
return img
def load_model(self, path, scale=4):
- if "http" in path:
- dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
+ if path.startswith("http"):
+ filename = modelloader.load_file_from_url(
+ url=path,
+ model_dir=self.model_download_path,
+ file_name=f"{self.model_name.replace(' ', '_')}.pth",
+ )
else:
filename = path
- if filename is None or not os.path.exists(filename):
- return None
if filename.endswith(".v2.pth"):
- model = net2(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6],
- embed_dim=180,
- num_heads=[6, 6, 6, 6, 6, 6],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="1conv",
+ model = Swin2SR(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="1conv",
)
params = None
else:
- model = net(
+ model = SwinIR(
upscale=scale,
in_chans=3,
img_size=64,
@@ -151,7 +151,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break
-
+
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py
index 863f42db..93b93274 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch.py
@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
-
+
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
index 0e28ae6e..dad22cca 100644
--- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py
+++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py
@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
-
+
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
- return attn_mask
+ return attn_mask
def forward(self, x, x_size):
H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
+
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
- return flops
+ return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
-
+
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
- return flops
+ return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
-
+
class Upsample_hf(nn.Sequential):
"""Upsample module.
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
+ super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
-
-
+
+
class Swin2SR(nn.Module):
r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
)
self.layers.append(layer)
-
+
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
)
self.layers_hf.append(layer)
-
+
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
+ nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size)
return x
-
+
def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
- return x
+ return x
def forward(self, x):
H, W = x.shape[2:]
@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before))
-
+
x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf)
@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
-
+
x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
+
elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
+
else:
return x[:, :, :H*self.upscale, :W*self.upscale]
@@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
- for i, layer in enumerate(self.layers):
+ for layer in self.layers:
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
@@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width))
x = model(x)
- print(x.shape) \ No newline at end of file
+ print(x.shape)
diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
new file mode 100644
index 00000000..30199dcd
--- /dev/null
+++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js
@@ -0,0 +1,776 @@
+onUiLoaded(async() => {
+ const elementIDs = {
+ img2imgTabs: "#mode_img2img .tab-nav",
+ inpaint: "#img2maskimg",
+ inpaintSketch: "#inpaint_sketch",
+ rangeGroup: "#img2img_column_size",
+ sketch: "#img2img_sketch"
+ };
+ const tabNameToElementId = {
+ "Inpaint sketch": elementIDs.inpaintSketch,
+ "Inpaint": elementIDs.inpaint,
+ "Sketch": elementIDs.sketch
+ };
+
+ // Helper functions
+ // Get active tab
+ function getActiveTab(elements, all = false) {
+ const tabs = elements.img2imgTabs.querySelectorAll("button");
+
+ if (all) return tabs;
+
+ for (let tab of tabs) {
+ if (tab.classList.contains("selected")) {
+ return tab;
+ }
+ }
+ }
+
+ // Get tab ID
+ function getTabId(elements) {
+ const activeTab = getActiveTab(elements);
+ return tabNameToElementId[activeTab.innerText];
+ }
+
+ // Wait until opts loaded
+ async function waitForOpts() {
+ for (;;) {
+ if (window.opts && Object.keys(window.opts).length) {
+ return window.opts;
+ }
+ await new Promise(resolve => setTimeout(resolve, 100));
+ }
+ }
+
+ // Function for defining the "Ctrl", "Shift" and "Alt" keys
+ function isModifierKey(event, key) {
+ switch (key) {
+ case "Ctrl":
+ return event.ctrlKey;
+ case "Shift":
+ return event.shiftKey;
+ case "Alt":
+ return event.altKey;
+ default:
+ return false;
+ }
+ }
+
+ // Check if hotkey is valid
+ function isValidHotkey(value) {
+ const specialKeys = ["Ctrl", "Alt", "Shift", "Disable"];
+ return (
+ (typeof value === "string" &&
+ value.length === 1 &&
+ /[a-z]/i.test(value)) ||
+ specialKeys.includes(value)
+ );
+ }
+
+ // Normalize hotkey
+ function normalizeHotkey(hotkey) {
+ return hotkey.length === 1 ? "Key" + hotkey.toUpperCase() : hotkey;
+ }
+
+ // Format hotkey for display
+ function formatHotkeyForDisplay(hotkey) {
+ return hotkey.startsWith("Key") ? hotkey.slice(3) : hotkey;
+ }
+
+ // Create hotkey configuration with the provided options
+ function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {
+ const result = {}; // Resulting hotkey configuration
+ const usedKeys = new Set(); // Set of used hotkeys
+
+ // Iterate through defaultHotkeysConfig keys
+ for (const key in defaultHotkeysConfig) {
+ const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value
+ const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value
+
+ // Apply appropriate value for undefined, boolean, or object userValue
+ if (
+ userValue === undefined ||
+ typeof userValue === "boolean" ||
+ typeof userValue === "object" ||
+ userValue === "disable"
+ ) {
+ result[key] =
+ userValue === undefined ? defaultValue : userValue;
+ } else if (isValidHotkey(userValue)) {
+ const normalizedUserValue = normalizeHotkey(userValue);
+
+ // Check for conflicting hotkeys
+ if (!usedKeys.has(normalizedUserValue)) {
+ usedKeys.add(normalizedUserValue);
+ result[key] = normalizedUserValue;
+ } else {
+ console.error(
+ `Hotkey: ${formatHotkeyForDisplay(
+ userValue
+ )} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay(
+ defaultValue
+ )}`
+ );
+ result[key] = defaultValue;
+ }
+ } else {
+ console.error(
+ `Hotkey: ${formatHotkeyForDisplay(
+ userValue
+ )} for ${key} is not valid. The default hotkey is used: ${formatHotkeyForDisplay(
+ defaultValue
+ )}`
+ );
+ result[key] = defaultValue;
+ }
+ }
+
+ return result;
+ }
+
+ // Disables functions in the config object based on the provided list of function names
+ function disableFunctions(config, disabledFunctions) {
+ // Bind the hasOwnProperty method to the functionMap object to avoid errors
+ const hasOwnProperty =
+ Object.prototype.hasOwnProperty.bind(functionMap);
+
+ // Loop through the disabledFunctions array and disable the corresponding functions in the config object
+ disabledFunctions.forEach(funcName => {
+ if (hasOwnProperty(funcName)) {
+ const key = functionMap[funcName];
+ config[key] = "disable";
+ }
+ });
+
+ // Return the updated config object
+ return config;
+ }
+
+ /**
+ * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.
+ * If the image display property is set to 'none', the mask breaks. To fix this, the function
+ * temporarily sets the display property to 'block' and then hides the mask again after 300 milliseconds
+ * to avoid breaking the canvas. Additionally, the function adjusts the mask to work correctly on
+ * very long images.
+ */
+ function restoreImgRedMask(elements) {
+ const mainTabId = getTabId(elements);
+
+ if (!mainTabId) return;
+
+ const mainTab = gradioApp().querySelector(mainTabId);
+ const img = mainTab.querySelector("img");
+ const imageARPreview = gradioApp().querySelector("#imageARPreview");
+
+ if (!img || !imageARPreview) return;
+
+ imageARPreview.style.transform = "";
+ if (parseFloat(mainTab.style.width) > 865) {
+ const transformString = mainTab.style.transform;
+ const scaleMatch = transformString.match(
+ /scale\(([-+]?[0-9]*\.?[0-9]+)\)/
+ );
+ let zoom = 1; // default zoom
+
+ if (scaleMatch && scaleMatch[1]) {
+ zoom = Number(scaleMatch[1]);
+ }
+
+ imageARPreview.style.transformOrigin = "0 0";
+ imageARPreview.style.transform = `scale(${zoom})`;
+ }
+
+ if (img.style.display !== "none") return;
+
+ img.style.display = "block";
+
+ setTimeout(() => {
+ img.style.display = "none";
+ }, 400);
+ }
+
+ const hotkeysConfigOpts = await waitForOpts();
+
+ // Default config
+ const defaultHotkeysConfig = {
+ canvas_hotkey_zoom: "Alt",
+ canvas_hotkey_adjust: "Ctrl",
+ canvas_hotkey_reset: "KeyR",
+ canvas_hotkey_fullscreen: "KeyS",
+ canvas_hotkey_move: "KeyF",
+ canvas_hotkey_overlap: "KeyO",
+ canvas_disabled_functions: [],
+ canvas_show_tooltip: true,
+ canvas_blur_prompt: false
+ };
+
+ const functionMap = {
+ "Zoom": "canvas_hotkey_zoom",
+ "Adjust brush size": "canvas_hotkey_adjust",
+ "Moving canvas": "canvas_hotkey_move",
+ "Fullscreen": "canvas_hotkey_fullscreen",
+ "Reset Zoom": "canvas_hotkey_reset",
+ "Overlap": "canvas_hotkey_overlap"
+ };
+
+ // Loading the configuration from opts
+ const preHotkeysConfig = createHotkeyConfig(
+ defaultHotkeysConfig,
+ hotkeysConfigOpts
+ );
+
+ // Disable functions that are not needed by the user
+ const hotkeysConfig = disableFunctions(
+ preHotkeysConfig,
+ preHotkeysConfig.canvas_disabled_functions
+ );
+
+ let isMoving = false;
+ let mouseX, mouseY;
+ let activeElement;
+
+ const elements = Object.fromEntries(
+ Object.keys(elementIDs).map(id => [
+ id,
+ gradioApp().querySelector(elementIDs[id])
+ ])
+ );
+ const elemData = {};
+
+ // Apply functionality to the range inputs. Restore redmask and correct for long images.
+ const rangeInputs = elements.rangeGroup ?
+ Array.from(elements.rangeGroup.querySelectorAll("input")) :
+ [
+ gradioApp().querySelector("#img2img_width input[type='range']"),
+ gradioApp().querySelector("#img2img_height input[type='range']")
+ ];
+
+ for (const input of rangeInputs) {
+ input?.addEventListener("input", () => restoreImgRedMask(elements));
+ }
+
+ function applyZoomAndPan(elemId) {
+ const targetElement = gradioApp().querySelector(elemId);
+
+ if (!targetElement) {
+ console.log("Element not found");
+ return;
+ }
+
+ targetElement.style.transformOrigin = "0 0";
+
+ elemData[elemId] = {
+ zoom: 1,
+ panX: 0,
+ panY: 0
+ };
+ let fullScreenMode = false;
+
+ // Create tooltip
+ function createTooltip() {
+ const toolTipElemnt =
+ targetElement.querySelector(".image-container");
+ const tooltip = document.createElement("div");
+ tooltip.className = "canvas-tooltip";
+
+ // Creating an item of information
+ const info = document.createElement("i");
+ info.className = "canvas-tooltip-info";
+ info.textContent = "";
+
+ // Create a container for the contents of the tooltip
+ const tooltipContent = document.createElement("div");
+ tooltipContent.className = "canvas-tooltip-content";
+
+ // Define an array with hotkey information and their actions
+ const hotkeysInfo = [
+ {
+ configKey: "canvas_hotkey_zoom",
+ action: "Zoom canvas",
+ keySuffix: " + wheel"
+ },
+ {
+ configKey: "canvas_hotkey_adjust",
+ action: "Adjust brush size",
+ keySuffix: " + wheel"
+ },
+ {configKey: "canvas_hotkey_reset", action: "Reset zoom"},
+ {
+ configKey: "canvas_hotkey_fullscreen",
+ action: "Fullscreen mode"
+ },
+ {configKey: "canvas_hotkey_move", action: "Move canvas"},
+ {configKey: "canvas_hotkey_overlap", action: "Overlap"}
+ ];
+
+ // Create hotkeys array with disabled property based on the config values
+ const hotkeys = hotkeysInfo.map(info => {
+ const configValue = hotkeysConfig[info.configKey];
+ const key = info.keySuffix ?
+ `${configValue}${info.keySuffix}` :
+ configValue.charAt(configValue.length - 1);
+ return {
+ key,
+ action: info.action,
+ disabled: configValue === "disable"
+ };
+ });
+
+ for (const hotkey of hotkeys) {
+ if (hotkey.disabled) {
+ continue;
+ }
+
+ const p = document.createElement("p");
+ p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;
+ tooltipContent.appendChild(p);
+ }
+
+ // Add information and content elements to the tooltip element
+ tooltip.appendChild(info);
+ tooltip.appendChild(tooltipContent);
+
+ // Add a hint element to the target element
+ toolTipElemnt.appendChild(tooltip);
+ }
+
+ //Show tool tip if setting enable
+ if (hotkeysConfig.canvas_show_tooltip) {
+ createTooltip();
+ }
+
+ // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui.
+ function fixCanvas() {
+ const activeTab = getActiveTab(elements).textContent.trim();
+
+ if (activeTab !== "img2img") {
+ const img = targetElement.querySelector(`${elemId} img`);
+
+ if (img && img.style.display !== "none") {
+ img.style.display = "none";
+ img.style.visibility = "hidden";
+ }
+ }
+ }
+
+ // Reset the zoom level and pan position of the target element to their initial values
+ function resetZoom() {
+ elemData[elemId] = {
+ zoomLevel: 1,
+ panX: 0,
+ panY: 0
+ };
+
+ fixCanvas();
+ targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
+
+ const canvas = gradioApp().querySelector(
+ `${elemId} canvas[key="interface"]`
+ );
+
+ toggleOverlap("off");
+ fullScreenMode = false;
+
+ if (
+ canvas &&
+ parseFloat(canvas.style.width) > 865 &&
+ parseFloat(targetElement.style.width) > 865
+ ) {
+ fitToElement();
+ return;
+ }
+
+ targetElement.style.width = "";
+ if (canvas) {
+ targetElement.style.height = canvas.style.height;
+ }
+ }
+
+ // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
+ function toggleOverlap(forced = "") {
+ const zIndex1 = "0";
+ const zIndex2 = "998";
+
+ targetElement.style.zIndex =
+ targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;
+
+ if (forced === "off") {
+ targetElement.style.zIndex = zIndex1;
+ } else if (forced === "on") {
+ targetElement.style.zIndex = zIndex2;
+ }
+ }
+
+ // Adjust the brush size based on the deltaY value from a mouse wheel event
+ function adjustBrushSize(
+ elemId,
+ deltaY,
+ withoutValue = false,
+ percentage = 5
+ ) {
+ const input =
+ gradioApp().querySelector(
+ `${elemId} input[aria-label='Brush radius']`
+ ) ||
+ gradioApp().querySelector(
+ `${elemId} button[aria-label="Use brush"]`
+ );
+
+ if (input) {
+ input.click();
+ if (!withoutValue) {
+ const maxValue =
+ parseFloat(input.getAttribute("max")) || 100;
+ const changeAmount = maxValue * (percentage / 100);
+ const newValue =
+ parseFloat(input.value) +
+ (deltaY > 0 ? -changeAmount : changeAmount);
+ input.value = Math.min(Math.max(newValue, 0), maxValue);
+ input.dispatchEvent(new Event("change"));
+ }
+ }
+ }
+
+ // Reset zoom when uploading a new image
+ const fileInput = gradioApp().querySelector(
+ `${elemId} input[type="file"][accept="image/*"].svelte-116rqfv`
+ );
+ fileInput.addEventListener("click", resetZoom);
+
+ // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
+ function updateZoom(newZoomLevel, mouseX, mouseY) {
+ newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
+
+ elemData[elemId].panX +=
+ mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
+ elemData[elemId].panY +=
+ mouseY - (mouseY * newZoomLevel) / elemData[elemId].zoomLevel;
+
+ targetElement.style.transformOrigin = "0 0";
+ targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
+
+ toggleOverlap("on");
+ return newZoomLevel;
+ }
+
+ // Change the zoom level based on user interaction
+ function changeZoomLevel(operation, e) {
+ if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
+ e.preventDefault();
+
+ let zoomPosX, zoomPosY;
+ let delta = 0.2;
+ if (elemData[elemId].zoomLevel > 7) {
+ delta = 0.9;
+ } else if (elemData[elemId].zoomLevel > 2) {
+ delta = 0.6;
+ }
+
+ zoomPosX = e.clientX;
+ zoomPosY = e.clientY;
+
+ fullScreenMode = false;
+ elemData[elemId].zoomLevel = updateZoom(
+ elemData[elemId].zoomLevel +
+ (operation === "+" ? delta : -delta),
+ zoomPosX - targetElement.getBoundingClientRect().left,
+ zoomPosY - targetElement.getBoundingClientRect().top
+ );
+ }
+ }
+
+ /**
+ * This function fits the target element to the screen by calculating
+ * the required scale and offsets. It also updates the global variables
+ * zoomLevel, panX, and panY to reflect the new state.
+ */
+
+ function fitToElement() {
+ //Reset Zoom
+ targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
+
+ // Get element and screen dimensions
+ const elementWidth = targetElement.offsetWidth;
+ const elementHeight = targetElement.offsetHeight;
+ const parentElement = targetElement.parentElement;
+ const screenWidth = parentElement.clientWidth;
+ const screenHeight = parentElement.clientHeight;
+
+ // Get element's coordinates relative to the parent element
+ const elementRect = targetElement.getBoundingClientRect();
+ const parentRect = parentElement.getBoundingClientRect();
+ const elementX = elementRect.x - parentRect.x;
+
+ // Calculate scale and offsets
+ const scaleX = screenWidth / elementWidth;
+ const scaleY = screenHeight / elementHeight;
+ const scale = Math.min(scaleX, scaleY);
+
+ const transformOrigin =
+ window.getComputedStyle(targetElement).transformOrigin;
+ const [originX, originY] = transformOrigin.split(" ");
+ const originXValue = parseFloat(originX);
+ const originYValue = parseFloat(originY);
+
+ const offsetX =
+ (screenWidth - elementWidth * scale) / 2 -
+ originXValue * (1 - scale);
+ const offsetY =
+ (screenHeight - elementHeight * scale) / 2.5 -
+ originYValue * (1 - scale);
+
+ // Apply scale and offsets to the element
+ targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
+
+ // Update global variables
+ elemData[elemId].zoomLevel = scale;
+ elemData[elemId].panX = offsetX;
+ elemData[elemId].panY = offsetY;
+
+ fullScreenMode = false;
+ toggleOverlap("off");
+ }
+
+ /**
+ * This function fits the target element to the screen by calculating
+ * the required scale and offsets. It also updates the global variables
+ * zoomLevel, panX, and panY to reflect the new state.
+ */
+
+ // Fullscreen mode
+ function fitToScreen() {
+ const canvas = gradioApp().querySelector(
+ `${elemId} canvas[key="interface"]`
+ );
+
+ if (!canvas) return;
+
+ if (canvas.offsetWidth > 862) {
+ targetElement.style.width = canvas.offsetWidth + "px";
+ }
+
+ if (fullScreenMode) {
+ resetZoom();
+ fullScreenMode = false;
+ return;
+ }
+
+ //Reset Zoom
+ targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
+
+ // Get scrollbar width to right-align the image
+ const scrollbarWidth =
+ window.innerWidth - document.documentElement.clientWidth;
+
+ // Get element and screen dimensions
+ const elementWidth = targetElement.offsetWidth;
+ const elementHeight = targetElement.offsetHeight;
+ const screenWidth = window.innerWidth - scrollbarWidth;
+ const screenHeight = window.innerHeight;
+
+ // Get element's coordinates relative to the page
+ const elementRect = targetElement.getBoundingClientRect();
+ const elementY = elementRect.y;
+ const elementX = elementRect.x;
+
+ // Calculate scale and offsets
+ const scaleX = screenWidth / elementWidth;
+ const scaleY = screenHeight / elementHeight;
+ const scale = Math.min(scaleX, scaleY);
+
+ // Get the current transformOrigin
+ const computedStyle = window.getComputedStyle(targetElement);
+ const transformOrigin = computedStyle.transformOrigin;
+ const [originX, originY] = transformOrigin.split(" ");
+ const originXValue = parseFloat(originX);
+ const originYValue = parseFloat(originY);
+
+ // Calculate offsets with respect to the transformOrigin
+ const offsetX =
+ (screenWidth - elementWidth * scale) / 2 -
+ elementX -
+ originXValue * (1 - scale);
+ const offsetY =
+ (screenHeight - elementHeight * scale) / 2 -
+ elementY -
+ originYValue * (1 - scale);
+
+ // Apply scale and offsets to the element
+ targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;
+
+ // Update global variables
+ elemData[elemId].zoomLevel = scale;
+ elemData[elemId].panX = offsetX;
+ elemData[elemId].panY = offsetY;
+
+ fullScreenMode = true;
+ toggleOverlap("on");
+ }
+
+ // Handle keydown events
+ function handleKeyDown(event) {
+ // Disable key locks to make pasting from the buffer work correctly
+ if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
+ return;
+ }
+
+ // before activating shortcut, ensure user is not actively typing in an input field
+ if (!hotkeysConfig.canvas_blur_prompt) {
+ if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
+ return;
+ }
+ }
+
+
+ const hotkeyActions = {
+ [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
+ [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
+ [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
+ };
+
+ const action = hotkeyActions[event.code];
+ if (action) {
+ event.preventDefault();
+ action(event);
+ }
+
+ if (
+ isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) ||
+ isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust)
+ ) {
+ event.preventDefault();
+ }
+ }
+
+ // Get Mouse position
+ function getMousePosition(e) {
+ mouseX = e.offsetX;
+ mouseY = e.offsetY;
+ }
+
+ targetElement.addEventListener("mousemove", getMousePosition);
+
+ // Handle events only inside the targetElement
+ let isKeyDownHandlerAttached = false;
+
+ function handleMouseMove() {
+ if (!isKeyDownHandlerAttached) {
+ document.addEventListener("keydown", handleKeyDown);
+ isKeyDownHandlerAttached = true;
+
+ activeElement = elemId;
+ }
+ }
+
+ function handleMouseLeave() {
+ if (isKeyDownHandlerAttached) {
+ document.removeEventListener("keydown", handleKeyDown);
+ isKeyDownHandlerAttached = false;
+
+ activeElement = null;
+ }
+ }
+
+ // Add mouse event handlers
+ targetElement.addEventListener("mousemove", handleMouseMove);
+ targetElement.addEventListener("mouseleave", handleMouseLeave);
+
+ // Reset zoom when click on another tab
+ elements.img2imgTabs.addEventListener("click", resetZoom);
+ elements.img2imgTabs.addEventListener("click", () => {
+ // targetElement.style.width = "";
+ if (parseInt(targetElement.style.width) > 865) {
+ setTimeout(fitToElement, 0);
+ }
+ });
+
+ targetElement.addEventListener("wheel", e => {
+ // change zoom level
+ const operation = e.deltaY > 0 ? "-" : "+";
+ changeZoomLevel(operation, e);
+
+ // Handle brush size adjustment with ctrl key pressed
+ if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
+ e.preventDefault();
+
+ // Increase or decrease brush size based on scroll direction
+ adjustBrushSize(elemId, e.deltaY);
+ }
+ });
+
+ // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
+ function handleMoveKeyDown(e) {
+
+ // Disable key locks to make pasting from the buffer work correctly
+ if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
+ return;
+ }
+
+ // before activating shortcut, ensure user is not actively typing in an input field
+ if (!hotkeysConfig.canvas_blur_prompt) {
+ if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
+ return;
+ }
+ }
+
+
+ if (e.code === hotkeysConfig.canvas_hotkey_move) {
+ if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
+ e.preventDefault();
+ document.activeElement.blur();
+ isMoving = true;
+ }
+ }
+ }
+
+ function handleMoveKeyUp(e) {
+ if (e.code === hotkeysConfig.canvas_hotkey_move) {
+ isMoving = false;
+ }
+ }
+
+ document.addEventListener("keydown", handleMoveKeyDown);
+ document.addEventListener("keyup", handleMoveKeyUp);
+
+ // Detect zoom level and update the pan speed.
+ function updatePanPosition(movementX, movementY) {
+ let panSpeed = 2;
+
+ if (elemData[elemId].zoomLevel > 8) {
+ panSpeed = 3.5;
+ }
+
+ elemData[elemId].panX += movementX * panSpeed;
+ elemData[elemId].panY += movementY * panSpeed;
+
+ // Delayed redraw of an element
+ requestAnimationFrame(() => {
+ targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${elemData[elemId].zoomLevel})`;
+ toggleOverlap("on");
+ });
+ }
+
+ function handleMoveByKey(e) {
+ if (isMoving && elemId === activeElement) {
+ updatePanPosition(e.movementX, e.movementY);
+ targetElement.style.pointerEvents = "none";
+ } else {
+ targetElement.style.pointerEvents = "auto";
+ }
+ }
+
+ // Prevents sticking to the mouse
+ window.onblur = function() {
+ isMoving = false;
+ };
+
+ gradioApp().addEventListener("mousemove", handleMoveByKey);
+ }
+
+ applyZoomAndPan(elementIDs.sketch);
+ applyZoomAndPan(elementIDs.inpaint);
+ applyZoomAndPan(elementIDs.inpaintSketch);
+
+ // Make the function global so that other extensions can take advantage of this solution
+ window.applyZoomAndPan = applyZoomAndPan;
+});
diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
new file mode 100644
index 00000000..380176ce
--- /dev/null
+++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py
@@ -0,0 +1,14 @@
+import gradio as gr
+from modules import shared
+
+shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
+ "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
+ "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
+ "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
+ "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
+ "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
+ "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
+ "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
+ "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
+ "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
+}))
diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css
new file mode 100644
index 00000000..6bcc9570
--- /dev/null
+++ b/extensions-builtin/canvas-zoom-and-pan/style.css
@@ -0,0 +1,63 @@
+.canvas-tooltip-info {
+ position: absolute;
+ top: 10px;
+ left: 10px;
+ cursor: help;
+ background-color: rgba(0, 0, 0, 0.3);
+ width: 20px;
+ height: 20px;
+ border-radius: 50%;
+ display: flex;
+ align-items: center;
+ justify-content: center;
+ flex-direction: column;
+
+ z-index: 100;
+}
+
+.canvas-tooltip-info::after {
+ content: '';
+ display: block;
+ width: 2px;
+ height: 7px;
+ background-color: white;
+ margin-top: 2px;
+}
+
+.canvas-tooltip-info::before {
+ content: '';
+ display: block;
+ width: 2px;
+ height: 2px;
+ background-color: white;
+}
+
+.canvas-tooltip-content {
+ display: none;
+ background-color: #f9f9f9;
+ color: #333;
+ border: 1px solid #ddd;
+ padding: 15px;
+ position: absolute;
+ top: 40px;
+ left: 10px;
+ width: 250px;
+ font-size: 16px;
+ opacity: 0;
+ border-radius: 8px;
+ box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
+
+ z-index: 100;
+}
+
+.canvas-tooltip:hover .canvas-tooltip-content {
+ display: block;
+ animation: fadeIn 0.5s;
+ opacity: 1;
+}
+
+@keyframes fadeIn {
+ from {opacity: 0;}
+ to {opacity: 1;}
+}
+
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
new file mode 100644
index 00000000..a05e10d8
--- /dev/null
+++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py
@@ -0,0 +1,48 @@
+import gradio as gr
+from modules import scripts, shared, ui_components, ui_settings
+from modules.ui_components import FormColumn
+
+
+class ExtraOptionsSection(scripts.Script):
+ section = "extra_options"
+
+ def __init__(self):
+ self.comps = None
+ self.setting_names = None
+
+ def title(self):
+ return "Extra options"
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ self.comps = []
+ self.setting_names = []
+
+ with gr.Blocks() as interface:
+ with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
+ for setting_name in shared.opts.extra_options:
+ with FormColumn():
+ comp = ui_settings.create_setting_component(setting_name)
+
+ self.comps.append(comp)
+ self.setting_names.append(setting_name)
+
+ def get_settings_values():
+ return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
+
+ interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
+
+ return self.comps
+
+ def before_process(self, p, *args):
+ for name, value in zip(self.setting_names, args):
+ if name not in p.override_settings:
+ p.override_settings[name] = value
+
+
+shared.options_templates.update(shared.options_section(('ui', "User interface"), {
+ "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
+ "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
+}))
diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
index 5c7a836a..114cf94c 100644
--- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
+++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js
@@ -4,39 +4,39 @@
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(textArea, counterElt) {
- var counts = {};
- (textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
- counts[bracket] = (counts[bracket] || 0) + 1;
- });
- var errors = [];
+ var counts = {};
+ (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
+ counts[bracket] = (counts[bracket] || 0) + 1;
+ });
+ var errors = [];
- function checkPair(open, close, kind) {
- if (counts[open] !== counts[close]) {
- errors.push(
- `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
- );
+ function checkPair(open, close, kind) {
+ if (counts[open] !== counts[close]) {
+ errors.push(
+ `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
+ );
+ }
}
- }
- checkPair('(', ')', 'round brackets');
- checkPair('[', ']', 'square brackets');
- checkPair('{', '}', 'curly brackets');
- counterElt.title = errors.join('\n');
- counterElt.classList.toggle('error', errors.length !== 0);
+ checkPair('(', ')', 'round brackets');
+ checkPair('[', ']', 'square brackets');
+ checkPair('{', '}', 'curly brackets');
+ counterElt.title = errors.join('\n');
+ counterElt.classList.toggle('error', errors.length !== 0);
}
function setupBracketChecking(id_prompt, id_counter) {
- var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
- var counter = gradioApp().getElementById(id_counter)
+ var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
+ var counter = gradioApp().getElementById(id_counter);
- if (textarea && counter) {
- textarea.addEventListener("input", () => checkBrackets(textarea, counter));
- }
+ if (textarea && counter) {
+ textarea.addEventListener("input", () => checkBrackets(textarea, counter));
+ }
}
-onUiLoaded(function () {
- setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
- setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
- setupBracketChecking('img2img_prompt', 'img2img_token_counter');
- setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
+onUiLoaded(function() {
+ setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
+ setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
+ setupBracketChecking('img2img_prompt', 'img2img_token_counter');
+ setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
});
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index 1d546217..68a84c3a 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,15 +1,14 @@
-<div class='card' style={style} onclick={card_clicked}>
+<div class='card' style={style} onclick={card_clicked} {sort_keys}>
+ {background_image}
{metadata_button}
-
<div class='actions'>
<div class='additional'>
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
- <span style="display:none" class='search_term{serach_only}'>{search_term}</span>
+ <span style="display:none" class='search_term{search_only}'>{search_term}</span>
</div>
<span class='name'>{name}</span>
<span class='description'>{description}</span>
</div>
</div>
-
diff --git a/html/footer.html b/html/footer.html
index bad87ff6..69b2372c 100644
--- a/html/footer.html
+++ b/html/footer.html
@@ -1,10 +1,12 @@
<div>
- <a href="/docs">API</a>
+ <a href="{api_docs}">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 
+ <a href="#" onclick="showProfile('./internal/profile-startup'); return false;">Startup profile</a>
+  • 
<a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
</div>
<br />
diff --git a/html/licenses.html b/html/licenses.html
index bc995aa0..ef6f2c0a 100644
--- a/html/licenses.html
+++ b/html/licenses.html
@@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
+</pre>
+
+<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
+<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
+<pre>
+MIT License
+
+Copyright (c) 2023 Ollin Boer Bohan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
</pre> \ No newline at end of file
diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js
index 5160081d..2cf2d571 100644
--- a/javascript/aspectRatioOverlay.js
+++ b/javascript/aspectRatioOverlay.js
@@ -1,111 +1,113 @@
-
-let currentWidth = null;
-let currentHeight = null;
-let arFrameTimeout = setTimeout(function(){},0);
-
-function dimensionChange(e, is_width, is_height){
-
- if(is_width){
- currentWidth = e.target.value*1.0
- }
- if(is_height){
- currentHeight = e.target.value*1.0
- }
-
- var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
-
- if(!inImg2img){
- return;
- }
-
- var targetElement = null;
-
- var tabIndex = get_tab_index('mode_img2img')
- if(tabIndex == 0){ // img2img
- targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
- } else if(tabIndex == 1){ //Sketch
- targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
- } else if(tabIndex == 2){ // Inpaint
- targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
- } else if(tabIndex == 3){ // Inpaint sketch
- targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
- }
-
-
- if(targetElement){
-
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
- if(!arPreviewRect){
- arPreviewRect = document.createElement('div')
- arPreviewRect.id = "imageARPreview";
- gradioApp().appendChild(arPreviewRect)
- }
-
-
-
- var viewportOffset = targetElement.getBoundingClientRect();
-
- var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
-
- var scaledx = targetElement.naturalWidth*viewportscale
- var scaledy = targetElement.naturalHeight*viewportscale
-
- var cleintRectTop = (viewportOffset.top+window.scrollY)
- var cleintRectLeft = (viewportOffset.left+window.scrollX)
- var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
- var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
-
- var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
- var arscaledx = currentWidth*arscale
- var arscaledy = currentHeight*arscale
-
- var arRectTop = cleintRectCentreY-(arscaledy/2)
- var arRectLeft = cleintRectCentreX-(arscaledx/2)
- var arRectWidth = arscaledx
- var arRectHeight = arscaledy
-
- arPreviewRect.style.top = arRectTop+'px';
- arPreviewRect.style.left = arRectLeft+'px';
- arPreviewRect.style.width = arRectWidth+'px';
- arPreviewRect.style.height = arRectHeight+'px';
-
- clearTimeout(arFrameTimeout);
- arFrameTimeout = setTimeout(function(){
- arPreviewRect.style.display = 'none';
- },2000);
-
- arPreviewRect.style.display = 'block';
-
- }
-
-}
-
-
-onUiUpdate(function(){
- var arPreviewRect = gradioApp().querySelector('#imageARPreview');
- if(arPreviewRect){
- arPreviewRect.style.display = 'none';
- }
- var tabImg2img = gradioApp().querySelector("#tab_img2img");
- if (tabImg2img) {
- var inImg2img = tabImg2img.style.display == "block";
- if(inImg2img){
- let inputs = gradioApp().querySelectorAll('input');
- inputs.forEach(function(e){
- var is_width = e.parentElement.id == "img2img_width"
- var is_height = e.parentElement.id == "img2img_height"
-
- if((is_width || is_height) && !e.classList.contains('scrollwatch')){
- e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
- e.classList.add('scrollwatch')
- }
- if(is_width){
- currentWidth = e.value*1.0
- }
- if(is_height){
- currentHeight = e.value*1.0
- }
- })
- }
- }
-});
+
+let currentWidth = null;
+let currentHeight = null;
+let arFrameTimeout = setTimeout(function() {}, 0);
+
+function dimensionChange(e, is_width, is_height) {
+
+ if (is_width) {
+ currentWidth = e.target.value * 1.0;
+ }
+ if (is_height) {
+ currentHeight = e.target.value * 1.0;
+ }
+
+ var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
+
+ if (!inImg2img) {
+ return;
+ }
+
+ var targetElement = null;
+
+ var tabIndex = get_tab_index('mode_img2img');
+ if (tabIndex == 0) { // img2img
+ targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
+ } else if (tabIndex == 1) { //Sketch
+ targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
+ } else if (tabIndex == 2) { // Inpaint
+ targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
+ } else if (tabIndex == 3) { // Inpaint sketch
+ targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
+ }
+
+
+ if (targetElement) {
+
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
+ if (!arPreviewRect) {
+ arPreviewRect = document.createElement('div');
+ arPreviewRect.id = "imageARPreview";
+ gradioApp().appendChild(arPreviewRect);
+ }
+
+
+
+ var viewportOffset = targetElement.getBoundingClientRect();
+
+ var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
+
+ var scaledx = targetElement.naturalWidth * viewportscale;
+ var scaledy = targetElement.naturalHeight * viewportscale;
+
+ var cleintRectTop = (viewportOffset.top + window.scrollY);
+ var cleintRectLeft = (viewportOffset.left + window.scrollX);
+ var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
+ var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
+
+ var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
+ var arscaledx = currentWidth * arscale;
+ var arscaledy = currentHeight * arscale;
+
+ var arRectTop = cleintRectCentreY - (arscaledy / 2);
+ var arRectLeft = cleintRectCentreX - (arscaledx / 2);
+ var arRectWidth = arscaledx;
+ var arRectHeight = arscaledy;
+
+ arPreviewRect.style.top = arRectTop + 'px';
+ arPreviewRect.style.left = arRectLeft + 'px';
+ arPreviewRect.style.width = arRectWidth + 'px';
+ arPreviewRect.style.height = arRectHeight + 'px';
+
+ clearTimeout(arFrameTimeout);
+ arFrameTimeout = setTimeout(function() {
+ arPreviewRect.style.display = 'none';
+ }, 2000);
+
+ arPreviewRect.style.display = 'block';
+
+ }
+
+}
+
+
+onAfterUiUpdate(function() {
+ var arPreviewRect = gradioApp().querySelector('#imageARPreview');
+ if (arPreviewRect) {
+ arPreviewRect.style.display = 'none';
+ }
+ var tabImg2img = gradioApp().querySelector("#tab_img2img");
+ if (tabImg2img) {
+ var inImg2img = tabImg2img.style.display == "block";
+ if (inImg2img) {
+ let inputs = gradioApp().querySelectorAll('input');
+ inputs.forEach(function(e) {
+ var is_width = e.parentElement.id == "img2img_width";
+ var is_height = e.parentElement.id == "img2img_height";
+
+ if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
+ e.addEventListener('input', function(e) {
+ dimensionChange(e, is_width, is_height);
+ });
+ e.classList.add('scrollwatch');
+ }
+ if (is_width) {
+ currentWidth = e.value * 1.0;
+ }
+ if (is_height) {
+ currentHeight = e.value * 1.0;
+ }
+ });
+ }
+ }
+});
diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js
index b2bdf053..ccae242f 100644
--- a/javascript/contextMenus.js
+++ b/javascript/contextMenus.js
@@ -1,166 +1,176 @@
-
-contextMenuInit = function(){
- let eventListenerApplied=false;
- let menuSpecs = new Map();
-
- const uid = function(){
- return Date.now().toString(36) + Math.random().toString(36).substring(2);
- }
-
- function showContextMenu(event,element,menuEntries){
- let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
- let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
-
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
-
- let baseStyle = window.getComputedStyle(uiCurrentTab)
-
- const contextMenu = document.createElement('nav')
- contextMenu.id = "context-menu"
- contextMenu.style.background = baseStyle.background
- contextMenu.style.color = baseStyle.color
- contextMenu.style.fontFamily = baseStyle.fontFamily
- contextMenu.style.top = posy+'px'
- contextMenu.style.left = posx+'px'
-
-
-
- const contextMenuList = document.createElement('ul')
- contextMenuList.className = 'context-menu-items';
- contextMenu.append(contextMenuList);
-
- menuEntries.forEach(function(entry){
- let contextMenuEntry = document.createElement('a')
- contextMenuEntry.innerHTML = entry['name']
- contextMenuEntry.addEventListener("click", function() {
- entry['func']();
- })
- contextMenuList.append(contextMenuEntry);
-
- })
-
- gradioApp().appendChild(contextMenu)
-
- let menuWidth = contextMenu.offsetWidth + 4;
- let menuHeight = contextMenu.offsetHeight + 4;
-
- let windowWidth = window.innerWidth;
- let windowHeight = window.innerHeight;
-
- if ( (windowWidth - posx) < menuWidth ) {
- contextMenu.style.left = windowWidth - menuWidth + "px";
- }
-
- if ( (windowHeight - posy) < menuHeight ) {
- contextMenu.style.top = windowHeight - menuHeight + "px";
- }
-
- }
-
- function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
-
- var currentItems = menuSpecs.get(targetElementSelector)
-
- if(!currentItems){
- currentItems = []
- menuSpecs.set(targetElementSelector,currentItems);
- }
- let newItem = {'id':targetElementSelector+'_'+uid(),
- 'name':entryName,
- 'func':entryFunction,
- 'isNew':true}
-
- currentItems.push(newItem)
- return newItem['id']
- }
-
- function removeContextMenuOption(uid){
- menuSpecs.forEach(function(v) {
- let index = -1
- v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
- if(index>=0){
- v.splice(index, 1);
- }
- })
- }
-
- function addContextMenuEventListener(){
- if(eventListenerApplied){
- return;
- }
- gradioApp().addEventListener("click", function(e) {
- if(! e.isTrusted){
- return
- }
-
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
- });
- gradioApp().addEventListener("contextmenu", function(e) {
- let oldMenu = gradioApp().querySelector('#context-menu')
- if(oldMenu){
- oldMenu.remove()
- }
- menuSpecs.forEach(function(v,k) {
- if(e.composedPath()[0].matches(k)){
- showContextMenu(e,e.composedPath()[0],v)
- e.preventDefault()
- }
- })
- });
- eventListenerApplied=true
-
- }
-
- return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
-}
-
-initResponse = contextMenuInit();
-appendContextMenuOption = initResponse[0];
-removeContextMenuOption = initResponse[1];
-addContextMenuEventListener = initResponse[2];
-
-(function(){
- //Start example Context Menu Items
- let generateOnRepeat = function(genbuttonid,interruptbuttonid){
- let genbutton = gradioApp().querySelector(genbuttonid);
- let interruptbutton = gradioApp().querySelector(interruptbuttonid);
- if(!interruptbutton.offsetParent){
- genbutton.click();
- }
- clearInterval(window.generateOnRepeatInterval)
- window.generateOnRepeatInterval = setInterval(function(){
- if(!interruptbutton.offsetParent){
- genbutton.click();
- }
- },
- 500)
- }
-
- appendContextMenuOption('#txt2img_generate','Generate forever',function(){
- generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
- })
- appendContextMenuOption('#img2img_generate','Generate forever',function(){
- generateOnRepeat('#img2img_generate','#img2img_interrupt');
- })
-
- let cancelGenerateForever = function(){
- clearInterval(window.generateOnRepeatInterval)
- }
-
- appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
- appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
-
-})();
-//End example Context Menu Items
-
-onUiUpdate(function(){
- addContextMenuEventListener()
-});
+
+var contextMenuInit = function() {
+ let eventListenerApplied = false;
+ let menuSpecs = new Map();
+
+ const uid = function() {
+ return Date.now().toString(36) + Math.random().toString(36).substring(2);
+ };
+
+ function showContextMenu(event, element, menuEntries) {
+ let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
+ let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
+
+ let oldMenu = gradioApp().querySelector('#context-menu');
+ if (oldMenu) {
+ oldMenu.remove();
+ }
+
+ let baseStyle = window.getComputedStyle(uiCurrentTab);
+
+ const contextMenu = document.createElement('nav');
+ contextMenu.id = "context-menu";
+ contextMenu.style.background = baseStyle.background;
+ contextMenu.style.color = baseStyle.color;
+ contextMenu.style.fontFamily = baseStyle.fontFamily;
+ contextMenu.style.top = posy + 'px';
+ contextMenu.style.left = posx + 'px';
+
+
+
+ const contextMenuList = document.createElement('ul');
+ contextMenuList.className = 'context-menu-items';
+ contextMenu.append(contextMenuList);
+
+ menuEntries.forEach(function(entry) {
+ let contextMenuEntry = document.createElement('a');
+ contextMenuEntry.innerHTML = entry['name'];
+ contextMenuEntry.addEventListener("click", function() {
+ entry['func']();
+ });
+ contextMenuList.append(contextMenuEntry);
+
+ });
+
+ gradioApp().appendChild(contextMenu);
+
+ let menuWidth = contextMenu.offsetWidth + 4;
+ let menuHeight = contextMenu.offsetHeight + 4;
+
+ let windowWidth = window.innerWidth;
+ let windowHeight = window.innerHeight;
+
+ if ((windowWidth - posx) < menuWidth) {
+ contextMenu.style.left = windowWidth - menuWidth + "px";
+ }
+
+ if ((windowHeight - posy) < menuHeight) {
+ contextMenu.style.top = windowHeight - menuHeight + "px";
+ }
+
+ }
+
+ function appendContextMenuOption(targetElementSelector, entryName, entryFunction) {
+
+ var currentItems = menuSpecs.get(targetElementSelector);
+
+ if (!currentItems) {
+ currentItems = [];
+ menuSpecs.set(targetElementSelector, currentItems);
+ }
+ let newItem = {
+ id: targetElementSelector + '_' + uid(),
+ name: entryName,
+ func: entryFunction,
+ isNew: true
+ };
+
+ currentItems.push(newItem);
+ return newItem['id'];
+ }
+
+ function removeContextMenuOption(uid) {
+ menuSpecs.forEach(function(v) {
+ let index = -1;
+ v.forEach(function(e, ei) {
+ if (e['id'] == uid) {
+ index = ei;
+ }
+ });
+ if (index >= 0) {
+ v.splice(index, 1);
+ }
+ });
+ }
+
+ function addContextMenuEventListener() {
+ if (eventListenerApplied) {
+ return;
+ }
+ gradioApp().addEventListener("click", function(e) {
+ if (!e.isTrusted) {
+ return;
+ }
+
+ let oldMenu = gradioApp().querySelector('#context-menu');
+ if (oldMenu) {
+ oldMenu.remove();
+ }
+ });
+ gradioApp().addEventListener("contextmenu", function(e) {
+ let oldMenu = gradioApp().querySelector('#context-menu');
+ if (oldMenu) {
+ oldMenu.remove();
+ }
+ menuSpecs.forEach(function(v, k) {
+ if (e.composedPath()[0].matches(k)) {
+ showContextMenu(e, e.composedPath()[0], v);
+ e.preventDefault();
+ }
+ });
+ });
+ eventListenerApplied = true;
+
+ }
+
+ return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener];
+};
+
+var initResponse = contextMenuInit();
+var appendContextMenuOption = initResponse[0];
+var removeContextMenuOption = initResponse[1];
+var addContextMenuEventListener = initResponse[2];
+
+(function() {
+ //Start example Context Menu Items
+ let generateOnRepeat = function(genbuttonid, interruptbuttonid) {
+ let genbutton = gradioApp().querySelector(genbuttonid);
+ let interruptbutton = gradioApp().querySelector(interruptbuttonid);
+ if (!interruptbutton.offsetParent) {
+ genbutton.click();
+ }
+ clearInterval(window.generateOnRepeatInterval);
+ window.generateOnRepeatInterval = setInterval(function() {
+ if (!interruptbutton.offsetParent) {
+ genbutton.click();
+ }
+ },
+ 500);
+ };
+
+ let generateOnRepeat_txt2img = function() {
+ generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
+ };
+
+ let generateOnRepeat_img2img = function() {
+ generateOnRepeat('#img2img_generate', '#img2img_interrupt');
+ };
+
+ appendContextMenuOption('#txt2img_generate', 'Generate forever', generateOnRepeat_txt2img);
+ appendContextMenuOption('#txt2img_interrupt', 'Generate forever', generateOnRepeat_txt2img);
+ appendContextMenuOption('#img2img_generate', 'Generate forever', generateOnRepeat_img2img);
+ appendContextMenuOption('#img2img_interrupt', 'Generate forever', generateOnRepeat_img2img);
+
+ let cancelGenerateForever = function() {
+ clearInterval(window.generateOnRepeatInterval);
+ };
+
+ appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
+ appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever);
+ appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
+ appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever);
+
+})();
+//End example Context Menu Items
+
+onAfterUiUpdate(addContextMenuEventListener);
diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js
index fe008924..5803daea 100644
--- a/javascript/dragdrop.js
+++ b/javascript/dragdrop.js
@@ -1,11 +1,11 @@
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard
-function isValidImageList( files ) {
+function isValidImageList(files) {
return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
}
-function dropReplaceImage( imgWrap, files ) {
- if ( ! isValidImageList( files ) ) {
+function dropReplaceImage(imgWrap, files) {
+ if (!isValidImageList(files)) {
return;
}
@@ -14,46 +14,61 @@ function dropReplaceImage( imgWrap, files ) {
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
const callback = () => {
const fileInput = imgWrap.querySelector('input[type="file"]');
- if ( fileInput ) {
- if ( files.length === 0 ) {
+ if (fileInput) {
+ if (files.length === 0) {
files = new DataTransfer();
files.items.add(tmpFile);
fileInput.files = files.files;
} else {
fileInput.files = files;
}
- fileInput.dispatchEvent(new Event('change'));
+ fileInput.dispatchEvent(new Event('change'));
}
};
-
- if ( imgWrap.closest('#pnginfo_image') ) {
+
+ if (imgWrap.closest('#pnginfo_image')) {
// special treatment for PNG Info tab, wait for fetch request to finish
const oldFetch = window.fetch;
- window.fetch = async (input, options) => {
+ window.fetch = async(input, options) => {
const response = await oldFetch(input, options);
- if ( 'api/predict/' === input ) {
+ if ('api/predict/' === input) {
const content = await response.text();
window.fetch = oldFetch;
- window.requestAnimationFrame( () => callback() );
+ window.requestAnimationFrame(() => callback());
return new Response(content, {
status: response.status,
statusText: response.statusText,
headers: response.headers
- })
+ });
}
return response;
- };
+ };
} else {
- window.requestAnimationFrame( () => callback() );
+ window.requestAnimationFrame(() => callback());
}
}
+function eventHasFiles(e) {
+ if (!e.dataTransfer || !e.dataTransfer.files) return false;
+ if (e.dataTransfer.files.length > 0) return true;
+ if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true;
+
+ return false;
+}
+
+function dragDropTargetIsPrompt(target) {
+ if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true;
+ if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true;
+ return false;
+}
+
window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0];
- const imgWrap = target.closest('[data-testid="image"]');
- if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
- return;
- }
+ if (!eventHasFiles(e)) return;
+
+ var targetImage = target.closest('[data-testid="image"]');
+ if (!dragDropTargetIsPrompt(target) && !targetImage) return;
+
e.stopPropagation();
e.preventDefault();
e.dataTransfer.dropEffect = 'copy';
@@ -61,37 +76,55 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => {
const target = e.composedPath()[0];
- if (target.placeholder.indexOf("Prompt") == -1) {
- return;
+ if (!eventHasFiles(e)) return;
+
+ if (dragDropTargetIsPrompt(target)) {
+ e.stopPropagation();
+ e.preventDefault();
+
+ let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
+
+ const imgParent = gradioApp().getElementById(prompt_target);
+ const files = e.dataTransfer.files;
+ const fileInput = imgParent.querySelector('input[type="file"]');
+ if (fileInput) {
+ fileInput.files = files;
+ fileInput.dispatchEvent(new Event('change'));
+ }
}
- const imgWrap = target.closest('[data-testid="image"]');
- if ( !imgWrap ) {
+
+ var targetImage = target.closest('[data-testid="image"]');
+ if (targetImage) {
+ e.stopPropagation();
+ e.preventDefault();
+ const files = e.dataTransfer.files;
+ dropReplaceImage(targetImage, files);
return;
}
- e.stopPropagation();
- e.preventDefault();
- const files = e.dataTransfer.files;
- dropReplaceImage( imgWrap, files );
});
window.addEventListener('paste', e => {
const files = e.clipboardData.files;
- if ( ! isValidImageList( files ) ) {
+ if (!isValidImageList(files)) {
return;
}
const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
- .filter(el => uiElementIsVisible(el));
- if ( ! visibleImageFields.length ) {
+ .filter(el => uiElementIsVisible(el))
+ .sort((a, b) => uiElementInSight(b) - uiElementInSight(a));
+
+
+ if (!visibleImageFields.length) {
return;
}
-
+
const firstFreeImageField = visibleImageFields
.filter(el => el.querySelector('input[type=file]'))?.[0];
dropReplaceImage(
firstFreeImageField ?
- firstFreeImageField :
- visibleImageFields[visibleImageFields.length - 1]
- , files );
+ firstFreeImageField :
+ visibleImageFields[visibleImageFields.length - 1]
+ , files
+ );
});
diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js
index d2c2f190..8906c892 100644
--- a/javascript/edit-attention.js
+++ b/javascript/edit-attention.js
@@ -1,120 +1,121 @@
-function keyupEditAttention(event){
- let target = event.originalTarget || event.composedPath()[0];
- if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
- if (! (event.metaKey || event.ctrlKey)) return;
-
- let isPlus = event.key == "ArrowUp"
- let isMinus = event.key == "ArrowDown"
- if (!isPlus && !isMinus) return;
-
- let selectionStart = target.selectionStart;
- let selectionEnd = target.selectionEnd;
- let text = target.value;
-
- function selectCurrentParenthesisBlock(OPEN, CLOSE){
- if (selectionStart !== selectionEnd) return false;
-
- // Find opening parenthesis around current cursor
- const before = text.substring(0, selectionStart);
- let beforeParen = before.lastIndexOf(OPEN);
- if (beforeParen == -1) return false;
- let beforeParenClose = before.lastIndexOf(CLOSE);
- while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
- beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
- beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
- }
-
- // Find closing parenthesis around current cursor
- const after = text.substring(selectionStart);
- let afterParen = after.indexOf(CLOSE);
- if (afterParen == -1) return false;
- let afterParenOpen = after.indexOf(OPEN);
- while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
- afterParen = after.indexOf(CLOSE, afterParen + 1);
- afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
- }
- if (beforeParen === -1 || afterParen === -1) return false;
-
- // Set the selection to the text between the parenthesis
- const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
- const lastColon = parenContent.lastIndexOf(":");
- selectionStart = beforeParen + 1;
- selectionEnd = selectionStart + lastColon;
- target.setSelectionRange(selectionStart, selectionEnd);
- return true;
- }
-
- function selectCurrentWord(){
- if (selectionStart !== selectionEnd) return false;
- const delimiters = opts.keyedit_delimiters + " \r\n\t";
-
- // seek backward until to find beggining
- while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
- selectionStart--;
- }
-
- // seek forward to find end
- while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
- selectionEnd++;
- }
-
- target.setSelectionRange(selectionStart, selectionEnd);
- return true;
- }
-
- // If the user hasn't selected anything, let's select their current parenthesis block or word
- if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
- selectCurrentWord();
- }
-
- event.preventDefault();
-
- var closeCharacter = ')'
- var delta = opts.keyedit_precision_attention
-
- if (selectionStart > 0 && text[selectionStart - 1] == '<'){
- closeCharacter = '>'
- delta = opts.keyedit_precision_extra
- } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
-
- // do not include spaces at the end
- while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
- selectionEnd -= 1;
- }
- if(selectionStart == selectionEnd){
- return
- }
-
- text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
-
- selectionStart += 1;
- selectionEnd += 1;
- }
-
- var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
- var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
- if (isNaN(weight)) return;
-
- weight += isPlus ? delta : -delta;
- weight = parseFloat(weight.toPrecision(12));
- if(String(weight).length == 1) weight += ".0"
-
- if (closeCharacter == ')' && weight == 1) {
- text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
- selectionStart--;
- selectionEnd--;
- } else {
- text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
- }
-
- target.focus();
- target.value = text;
- target.selectionStart = selectionStart;
- target.selectionEnd = selectionEnd;
-
- updateInput(target)
-}
-
-addEventListener('keydown', (event) => {
- keyupEditAttention(event);
-});
+function keyupEditAttention(event) {
+ let target = event.originalTarget || event.composedPath()[0];
+ if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
+ if (!(event.metaKey || event.ctrlKey)) return;
+
+ let isPlus = event.key == "ArrowUp";
+ let isMinus = event.key == "ArrowDown";
+ if (!isPlus && !isMinus) return;
+
+ let selectionStart = target.selectionStart;
+ let selectionEnd = target.selectionEnd;
+ let text = target.value;
+
+ function selectCurrentParenthesisBlock(OPEN, CLOSE) {
+ if (selectionStart !== selectionEnd) return false;
+
+ // Find opening parenthesis around current cursor
+ const before = text.substring(0, selectionStart);
+ let beforeParen = before.lastIndexOf(OPEN);
+ if (beforeParen == -1) return false;
+ let beforeParenClose = before.lastIndexOf(CLOSE);
+ while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
+ beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
+ beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
+ }
+
+ // Find closing parenthesis around current cursor
+ const after = text.substring(selectionStart);
+ let afterParen = after.indexOf(CLOSE);
+ if (afterParen == -1) return false;
+ let afterParenOpen = after.indexOf(OPEN);
+ while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
+ afterParen = after.indexOf(CLOSE, afterParen + 1);
+ afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
+ }
+ if (beforeParen === -1 || afterParen === -1) return false;
+
+ // Set the selection to the text between the parenthesis
+ const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
+ const lastColon = parenContent.lastIndexOf(":");
+ selectionStart = beforeParen + 1;
+ selectionEnd = selectionStart + lastColon;
+ target.setSelectionRange(selectionStart, selectionEnd);
+ return true;
+ }
+
+ function selectCurrentWord() {
+ if (selectionStart !== selectionEnd) return false;
+ const delimiters = opts.keyedit_delimiters + " \r\n\t";
+
+ // seek backward until to find beggining
+ while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
+ selectionStart--;
+ }
+
+ // seek forward to find end
+ while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
+ selectionEnd++;
+ }
+
+ target.setSelectionRange(selectionStart, selectionEnd);
+ return true;
+ }
+
+ // If the user hasn't selected anything, let's select their current parenthesis block or word
+ if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
+ selectCurrentWord();
+ }
+
+ event.preventDefault();
+
+ var closeCharacter = ')';
+ var delta = opts.keyedit_precision_attention;
+
+ if (selectionStart > 0 && text[selectionStart - 1] == '<') {
+ closeCharacter = '>';
+ delta = opts.keyedit_precision_extra;
+ } else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
+
+ // do not include spaces at the end
+ while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
+ selectionEnd -= 1;
+ }
+ if (selectionStart == selectionEnd) {
+ return;
+ }
+
+ text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
+
+ selectionStart += 1;
+ selectionEnd += 1;
+ }
+
+ var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
+ var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
+ if (isNaN(weight)) return;
+
+ weight += isPlus ? delta : -delta;
+ weight = parseFloat(weight.toPrecision(12));
+ if (String(weight).length == 1) weight += ".0";
+
+ if (closeCharacter == ')' && weight == 1) {
+ var endParenPos = text.substring(selectionEnd).indexOf(')');
+ text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
+ selectionStart--;
+ selectionEnd--;
+ } else {
+ text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
+ }
+
+ target.focus();
+ target.value = text;
+ target.selectionStart = selectionStart;
+ target.selectionEnd = selectionEnd;
+
+ updateInput(target);
+}
+
+addEventListener('keydown', (event) => {
+ keyupEditAttention(event);
+});
diff --git a/javascript/edit-order.js b/javascript/edit-order.js
new file mode 100644
index 00000000..ad983d33
--- /dev/null
+++ b/javascript/edit-order.js
@@ -0,0 +1,41 @@
+/* alt+left/right moves text in prompt */
+
+function keyupEditOrder(event) {
+ if (!opts.keyedit_move) return;
+
+ let target = event.originalTarget || event.composedPath()[0];
+ if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
+ if (!event.altKey) return;
+ event.preventDefault()
+
+ let isLeft = event.key == "ArrowLeft";
+ let isRight = event.key == "ArrowRight";
+ if (!isLeft && !isRight) return;
+
+ let selectionStart = target.selectionStart;
+ let selectionEnd = target.selectionEnd;
+ let text = target.value;
+ let items = text.split(",");
+ let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
+ let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
+ let range = indexEnd - indexStart + 1;
+
+ if (isLeft && indexStart > 0) {
+ items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
+ target.value = items.join();
+ target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
+ target.selectionEnd = items.slice(0, indexEnd).join().length;
+ } else if (isRight && indexEnd < items.length - 1) {
+ items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
+ target.value = items.join();
+ target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
+ target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
+ }
+
+ event.preventDefault();
+ updateInput(target);
+}
+
+addEventListener('keydown', (event) => {
+ keyupEditOrder(event);
+});
diff --git a/javascript/extensions.js b/javascript/extensions.js
index 2a2d2f8e..1f7254c5 100644
--- a/javascript/extensions.js
+++ b/javascript/extensions.js
@@ -1,71 +1,92 @@
-
-function extensions_apply(_disabled_list, _update_list, disable_all){
- var disable = []
- var update = []
-
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
- if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substring(7))
-
- if(x.name.startsWith("update_") && x.checked)
- update.push(x.name.substring(7))
- })
-
- restart_reload()
-
- return [JSON.stringify(disable), JSON.stringify(update), disable_all]
-}
-
-function extensions_check(){
- var disable = []
-
- gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
- if(x.name.startsWith("enable_") && ! x.checked)
- disable.push(x.name.substring(7))
- })
-
- gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
- x.innerHTML = "Loading..."
- })
-
-
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
-
- })
-
- return [id, JSON.stringify(disable)]
-}
-
-function install_extension_from_index(button, url){
- button.disabled = "disabled"
- button.value = "Installing..."
-
- var textarea = gradioApp().querySelector('#extension_to_install textarea')
- textarea.value = url
- updateInput(textarea)
-
- gradioApp().querySelector('#install_extension_button').click()
-}
-
-function config_state_confirm_restore(_, config_state_name, config_restore_type) {
- if (config_state_name == "Current") {
- return [false, config_state_name, config_restore_type];
- }
- let restored = "";
- if (config_restore_type == "extensions") {
- restored = "all saved extension versions";
- } else if (config_restore_type == "webui") {
- restored = "the webui version";
- } else {
- restored = "the webui version and all saved extension versions";
- }
- let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
- if (confirmed) {
- restart_reload();
- gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
- x.innerHTML = "Loading..."
- })
- }
- return [confirmed, config_state_name, config_restore_type];
-}
+
+function extensions_apply(_disabled_list, _update_list, disable_all) {
+ var disable = [];
+ var update = [];
+
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
+ if (x.name.startsWith("enable_") && !x.checked) {
+ disable.push(x.name.substring(7));
+ }
+
+ if (x.name.startsWith("update_") && x.checked) {
+ update.push(x.name.substring(7));
+ }
+ });
+
+ restart_reload();
+
+ return [JSON.stringify(disable), JSON.stringify(update), disable_all];
+}
+
+function extensions_check() {
+ var disable = [];
+
+ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
+ if (x.name.startsWith("enable_") && !x.checked) {
+ disable.push(x.name.substring(7));
+ }
+ });
+
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
+ x.innerHTML = "Loading...";
+ });
+
+
+ var id = randomId();
+ requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
+
+ });
+
+ return [id, JSON.stringify(disable)];
+}
+
+function install_extension_from_index(button, url) {
+ button.disabled = "disabled";
+ button.value = "Installing...";
+
+ var textarea = gradioApp().querySelector('#extension_to_install textarea');
+ textarea.value = url;
+ updateInput(textarea);
+
+ gradioApp().querySelector('#install_extension_button').click();
+}
+
+function config_state_confirm_restore(_, config_state_name, config_restore_type) {
+ if (config_state_name == "Current") {
+ return [false, config_state_name, config_restore_type];
+ }
+ let restored = "";
+ if (config_restore_type == "extensions") {
+ restored = "all saved extension versions";
+ } else if (config_restore_type == "webui") {
+ restored = "the webui version";
+ } else {
+ restored = "the webui version and all saved extension versions";
+ }
+ let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
+ if (confirmed) {
+ restart_reload();
+ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
+ x.innerHTML = "Loading...";
+ });
+ }
+ return [confirmed, config_state_name, config_restore_type];
+}
+
+function toggle_all_extensions(event) {
+ gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
+ checkbox_el.checked = event.target.checked;
+ });
+}
+
+function toggle_extension() {
+ let all_extensions_toggled = true;
+ for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
+ if (!checkbox_el.checked) {
+ all_extensions_toggled = false;
+ break;
+ }
+ }
+
+ gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
+}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index c85bc79a..b87bca3e 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -1,196 +1,265 @@
-function setupExtraNetworksForTab(tabname){
- gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
-
- var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
- var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
- var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
-
- search.classList.add('search')
- tabs.appendChild(search)
- tabs.appendChild(refresh)
-
- var applyFilter = function(){
- var searchTerm = search.value.toLowerCase()
-
- gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
- var searchOnly = elem.querySelector('.search_only')
- var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
-
- var visible = text.indexOf(searchTerm) != -1
-
- if(searchOnly && searchTerm.length < 4){
- visible = false
- }
-
- elem.style.display = visible ? "" : "none"
- })
- }
-
- search.addEventListener("input", applyFilter);
- applyFilter();
-
- extraNetworksApplyFilter[tabname] = applyFilter;
-}
-
-function applyExtraNetworkFilter(tabname){
- setTimeout(extraNetworksApplyFilter[tabname], 1);
-}
-
-var extraNetworksApplyFilter = {}
-var activePromptTextarea = {};
-
-function setupExtraNetworks(){
- setupExtraNetworksForTab('txt2img')
- setupExtraNetworksForTab('img2img')
-
- function registerPrompt(tabname, id){
- var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
- if (! activePromptTextarea[tabname]){
- activePromptTextarea[tabname] = textarea
- }
-
- textarea.addEventListener("focus", function(){
- activePromptTextarea[tabname] = textarea;
- });
- }
-
- registerPrompt('txt2img', 'txt2img_prompt')
- registerPrompt('txt2img', 'txt2img_neg_prompt')
- registerPrompt('img2img', 'img2img_prompt')
- registerPrompt('img2img', 'img2img_neg_prompt')
-}
-
-onUiLoaded(setupExtraNetworks)
-
-var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
-var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
-
-function tryToRemoveExtraNetworkFromPrompt(textarea, text){
- var m = text.match(re_extranet)
- if(! m) return false
-
- var partToSearch = m[1]
- var replaced = false
- var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
- m = found.match(re_extranet);
- if(m[1] == partToSearch){
- replaced = true;
- return ""
- }
- return found;
- })
-
- if(replaced){
- textarea.value = newTextareaText
- return true;
- }
-
- return false
-}
-
-function cardClicked(tabname, textToAdd, allowNegativePrompt){
- var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
-
- if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
- textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
- }
-
- updateInput(textarea)
-}
-
-function saveCardPreview(event, tabname, filename){
- var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
- var button = gradioApp().getElementById(tabname + '_save_preview')
-
- textarea.value = filename
- updateInput(textarea)
-
- button.click()
-
- event.stopPropagation()
- event.preventDefault()
-}
-
-function extraNetworksSearchButton(tabs_id, event){
- var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
- var button = event.target
- var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
-
- searchTextarea.value = text
- updateInput(searchTextarea)
-}
-
-var globalPopup = null;
-var globalPopupInner = null;
-function popup(contents){
- if(! globalPopup){
- globalPopup = document.createElement('div')
- globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
- globalPopup.classList.add('global-popup');
-
- var close = document.createElement('div')
- close.classList.add('global-popup-close');
- close.onclick = function(){ globalPopup.style.display = "none"; };
- close.title = "Close";
- globalPopup.appendChild(close)
-
- globalPopupInner = document.createElement('div')
- globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
- globalPopupInner.classList.add('global-popup-inner');
- globalPopup.appendChild(globalPopupInner)
-
- gradioApp().appendChild(globalPopup);
- }
-
- globalPopupInner.innerHTML = '';
- globalPopupInner.appendChild(contents);
-
- globalPopup.style.display = "flex";
-}
-
-function extraNetworksShowMetadata(text){
- var elem = document.createElement('pre')
- elem.classList.add('popup-metadata');
- elem.textContent = text;
-
- popup(elem);
-}
-
-function requestGet(url, data, handler, errorHandler){
- var xhr = new XMLHttpRequest();
- var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
- xhr.open("GET", url + "?" + args, true);
-
- xhr.onreadystatechange = function () {
- if (xhr.readyState === 4) {
- if (xhr.status === 200) {
- try {
- var js = JSON.parse(xhr.responseText);
- handler(js)
- } catch (error) {
- console.error(error);
- errorHandler()
- }
- } else{
- errorHandler()
- }
- }
- };
- var js = JSON.stringify(data);
- xhr.send(js);
-}
-
-function extraNetworksRequestMetadata(event, extraPage, cardName){
- var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
-
- requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
- if(data && data.metadata){
- extraNetworksShowMetadata(data.metadata)
- } else{
- showError()
- }
- }, showError)
-
- event.stopPropagation()
-}
+function setupExtraNetworksForTab(tabname) {
+ gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
+
+ var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
+ var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
+ var sort = gradioApp().getElementById(tabname + '_extra_sort');
+ var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
+ var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
+
+ search.classList.add('search');
+ sort.classList.add('sort');
+ sortOrder.classList.add('sortorder');
+ sort.dataset.sortkey = 'sortDefault';
+ tabs.appendChild(search);
+ tabs.appendChild(sort);
+ tabs.appendChild(sortOrder);
+ tabs.appendChild(refresh);
+
+ var applyFilter = function() {
+ var searchTerm = search.value.toLowerCase();
+
+ gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
+ var searchOnly = elem.querySelector('.search_only');
+ var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase();
+
+ var visible = text.indexOf(searchTerm) != -1;
+
+ if (searchOnly && searchTerm.length < 4) {
+ visible = false;
+ }
+
+ elem.style.display = visible ? "" : "none";
+ });
+ };
+
+ var applySort = function() {
+ var reverse = sortOrder.classList.contains("sortReverse");
+ var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
+ sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
+ var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
+ if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
+ return;
+ }
+
+ sort.dataset.sortkey = sortKeyStore;
+
+ var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
+ cards.forEach(function(card) {
+ card.originalParentElement = card.parentElement;
+ });
+ var sortedCards = Array.from(cards);
+ sortedCards.sort(function(cardA, cardB) {
+ var a = cardA.dataset[sortKey];
+ var b = cardB.dataset[sortKey];
+ if (!isNaN(a) && !isNaN(b)) {
+ return parseInt(a) - parseInt(b);
+ }
+
+ return (a < b ? -1 : (a > b ? 1 : 0));
+ });
+ if (reverse) {
+ sortedCards.reverse();
+ }
+ cards.forEach(function(card) {
+ card.remove();
+ });
+ sortedCards.forEach(function(card) {
+ card.originalParentElement.appendChild(card);
+ });
+ };
+
+ search.addEventListener("input", applyFilter);
+ applyFilter();
+ ["change", "blur", "click"].forEach(function(evt) {
+ sort.querySelector("input").addEventListener(evt, applySort);
+ });
+ sortOrder.addEventListener("click", function() {
+ sortOrder.classList.toggle("sortReverse");
+ applySort();
+ });
+
+ extraNetworksApplyFilter[tabname] = applyFilter;
+}
+
+function applyExtraNetworkFilter(tabname) {
+ setTimeout(extraNetworksApplyFilter[tabname], 1);
+}
+
+var extraNetworksApplyFilter = {};
+var activePromptTextarea = {};
+
+function setupExtraNetworks() {
+ setupExtraNetworksForTab('txt2img');
+ setupExtraNetworksForTab('img2img');
+
+ function registerPrompt(tabname, id) {
+ var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
+
+ if (!activePromptTextarea[tabname]) {
+ activePromptTextarea[tabname] = textarea;
+ }
+
+ textarea.addEventListener("focus", function() {
+ activePromptTextarea[tabname] = textarea;
+ });
+ }
+
+ registerPrompt('txt2img', 'txt2img_prompt');
+ registerPrompt('txt2img', 'txt2img_neg_prompt');
+ registerPrompt('img2img', 'img2img_prompt');
+ registerPrompt('img2img', 'img2img_neg_prompt');
+}
+
+onUiLoaded(setupExtraNetworks);
+
+var re_extranet = /<([^:]+:[^:]+):[\d.]+>/;
+var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
+
+function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
+ var m = text.match(re_extranet);
+ var replaced = false;
+ var newTextareaText;
+ if (m) {
+ var partToSearch = m[1];
+ newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) {
+ m = found.match(re_extranet);
+ if (m[1] == partToSearch) {
+ replaced = true;
+ return "";
+ }
+ return found;
+ });
+ } else {
+ newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
+ if (found == text) {
+ replaced = true;
+ return "";
+ }
+ return found;
+ });
+ }
+
+ if (replaced) {
+ textarea.value = newTextareaText;
+ return true;
+ }
+
+ return false;
+}
+
+function cardClicked(tabname, textToAdd, allowNegativePrompt) {
+ var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
+
+ if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
+ textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
+ }
+
+ updateInput(textarea);
+}
+
+function saveCardPreview(event, tabname, filename) {
+ var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea');
+ var button = gradioApp().getElementById(tabname + '_save_preview');
+
+ textarea.value = filename;
+ updateInput(textarea);
+
+ button.click();
+
+ event.stopPropagation();
+ event.preventDefault();
+}
+
+function extraNetworksSearchButton(tabs_id, event) {
+ var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
+ var button = event.target;
+ var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
+
+ searchTextarea.value = text;
+ updateInput(searchTextarea);
+}
+
+var globalPopup = null;
+var globalPopupInner = null;
+function popup(contents) {
+ if (!globalPopup) {
+ globalPopup = document.createElement('div');
+ globalPopup.onclick = function() {
+ globalPopup.style.display = "none";
+ };
+ globalPopup.classList.add('global-popup');
+
+ var close = document.createElement('div');
+ close.classList.add('global-popup-close');
+ close.onclick = function() {
+ globalPopup.style.display = "none";
+ };
+ close.title = "Close";
+ globalPopup.appendChild(close);
+
+ globalPopupInner = document.createElement('div');
+ globalPopupInner.onclick = function(event) {
+ event.stopPropagation(); return false;
+ };
+ globalPopupInner.classList.add('global-popup-inner');
+ globalPopup.appendChild(globalPopupInner);
+
+ gradioApp().appendChild(globalPopup);
+ }
+
+ globalPopupInner.innerHTML = '';
+ globalPopupInner.appendChild(contents);
+
+ globalPopup.style.display = "flex";
+}
+
+function extraNetworksShowMetadata(text) {
+ var elem = document.createElement('pre');
+ elem.classList.add('popup-metadata');
+ elem.textContent = text;
+
+ popup(elem);
+}
+
+function requestGet(url, data, handler, errorHandler) {
+ var xhr = new XMLHttpRequest();
+ var args = Object.keys(data).map(function(k) {
+ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]);
+ }).join('&');
+ xhr.open("GET", url + "?" + args, true);
+
+ xhr.onreadystatechange = function() {
+ if (xhr.readyState === 4) {
+ if (xhr.status === 200) {
+ try {
+ var js = JSON.parse(xhr.responseText);
+ handler(js);
+ } catch (error) {
+ console.error(error);
+ errorHandler();
+ }
+ } else {
+ errorHandler();
+ }
+ }
+ };
+ var js = JSON.stringify(data);
+ xhr.send(js);
+}
+
+function extraNetworksRequestMetadata(event, extraPage, cardName) {
+ var showError = function() {
+ extraNetworksShowMetadata("there was an error getting metadata");
+ };
+
+ requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) {
+ if (data && data.metadata) {
+ extraNetworksShowMetadata(data.metadata);
+ } else {
+ showError();
+ }
+ }, showError);
+
+ event.stopPropagation();
+}
diff --git a/javascript/generationParams.js b/javascript/generationParams.js
index ef64ee2e..7c0fd221 100644
--- a/javascript/generationParams.js
+++ b/javascript/generationParams.js
@@ -1,33 +1,35 @@
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
-onUiUpdate(function(){
- if (!txt2img_gallery) {
- txt2img_gallery = attachGalleryListeners("txt2img")
- }
- if (!img2img_gallery) {
- img2img_gallery = attachGalleryListeners("img2img")
- }
- if (!modal) {
- modal = gradioApp().getElementById('lightboxModal')
- modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
- }
+onAfterUiUpdate(function() {
+ if (!txt2img_gallery) {
+ txt2img_gallery = attachGalleryListeners("txt2img");
+ }
+ if (!img2img_gallery) {
+ img2img_gallery = attachGalleryListeners("img2img");
+ }
+ if (!modal) {
+ modal = gradioApp().getElementById('lightboxModal');
+ modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
+ }
});
let modalObserver = new MutationObserver(function(mutations) {
- mutations.forEach(function(mutationRecord) {
- let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
- if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
- gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
- });
+ mutations.forEach(function(mutationRecord) {
+ let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
+ if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
+ gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
+ }
+ });
});
function attachGalleryListeners(tab_name) {
- var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
- gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
- gallery?.addEventListener('keydown', (e) => {
- if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
- gradioApp().getElementById(tab_name+"_generation_info_button").click()
- });
- return gallery;
+ var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
+ gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
+ gallery?.addEventListener('keydown', (e) => {
+ if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
+ gradioApp().getElementById(tab_name + "_generation_info_button").click();
+ }
+ });
+ return gallery;
}
diff --git a/javascript/hints.js b/javascript/hints.js
index 3746df99..dc75ce31 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -1,20 +1,21 @@
// mouseover tooltips for various UI elements
-titles = {
+var titles = {
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
"Sampling method": "Which algorithm to use to produce the image",
- "GFPGAN": "Restore low quality faces using GFPGAN neural network",
- "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
- "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
- "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
- "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
-
- "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
- "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
+ "GFPGAN": "Restore low quality faces using GFPGAN neural network",
+ "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
+ "DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
+ "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
+ "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
+
+ "\u{1F4D0}": "Auto detect size from img2img",
+ "Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
+ "Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
- "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
+ "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomized",
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style",
@@ -40,7 +41,7 @@ titles = {
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
-
+
"Skip": "Stop processing current image and continue processing.",
"Interrupt": "Stop processing images and return any results accumulated so far.",
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
@@ -66,8 +67,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
- "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
- "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
+ "Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
+ "Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@@ -96,7 +97,7 @@ titles = {
"Add difference": "Result = A + (B - C) * M",
"No interpolation": "Result = A",
- "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
+ "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
@@ -111,40 +112,84 @@ titles = {
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
- "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
+ "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.",
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
+};
+
+function updateTooltip(element) {
+ if (element.title) return; // already has a title
+
+ let text = element.textContent;
+ let tooltip = localization[titles[text]] || titles[text];
+
+ if (!tooltip) {
+ let value = element.value;
+ if (value) tooltip = localization[titles[value]] || titles[value];
+ }
+
+ if (!tooltip) {
+ // Gradio dropdown options have `data-value`.
+ let dataValue = element.dataset.value;
+ if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];
+ }
+
+ if (!tooltip) {
+ for (const c of element.classList) {
+ if (c in titles) {
+ tooltip = localization[titles[c]] || titles[c];
+ break;
+ }
+ }
+ }
+
+ if (tooltip) {
+ element.title = tooltip;
+ }
}
+// Nodes to check for adding tooltips.
+const tooltipCheckNodes = new Set();
+// Timer for debouncing tooltip check.
+let tooltipCheckTimer = null;
-onUiUpdate(function(){
- gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
- if (span.title) return; // already has a title
-
- let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
-
- if(!tooltip){
- tooltip = localization[titles[span.value]] || titles[span.value];
- }
-
- if(!tooltip){
- for (const c of span.classList) {
- if (c in titles) {
- tooltip = localization[titles[c]] || titles[c];
- break;
- }
- }
- }
-
- if(tooltip){
- span.title = tooltip;
- }
- })
-
- gradioApp().querySelectorAll('select').forEach(function(select){
- if (select.onchange != null) return;
+function processTooltipCheckNodes() {
+ for (const node of tooltipCheckNodes) {
+ updateTooltip(node);
+ }
+ tooltipCheckNodes.clear();
+}
- select.onchange = function(){
- select.title = localization[titles[select.value]] || titles[select.value] || "";
- }
- })
-})
+onUiUpdate(function(mutationRecords) {
+ for (const record of mutationRecords) {
+ if (record.type === "childList" && record.target.classList.contains("options")) {
+ // This smells like a Gradio dropdown menu having changed,
+ // so let's enqueue an update for the input element that shows the current value.
+ let wrap = record.target.parentNode;
+ let input = wrap?.querySelector("input");
+ if (input) {
+ input.title = ""; // So we'll even have a chance to update it.
+ tooltipCheckNodes.add(input);
+ }
+ }
+ for (const node of record.addedNodes) {
+ if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) {
+ if (!node.title) {
+ if (
+ node.tagName === "SPAN" ||
+ node.tagName === "BUTTON" ||
+ node.tagName === "P" ||
+ node.tagName === "INPUT" ||
+ (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item
+ ) {
+ tooltipCheckNodes.add(node);
+ }
+ }
+ node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));
+ }
+ }
+ }
+ if (tooltipCheckNodes.size) {
+ clearTimeout(tooltipCheckTimer);
+ tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
+ }
+});
diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js
index 48196be4..0d04ab3b 100644
--- a/javascript/hires_fix.js
+++ b/javascript/hires_fix.js
@@ -1,18 +1,18 @@
-
-function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
- function setInactive(elem, inactive){
- elem.classList.toggle('inactive', !!inactive)
- }
-
- var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
- var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
- var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
-
- gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
-
- setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
- setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
- setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
-
- return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
-}
+
+function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
+ function setInactive(elem, inactive) {
+ elem.classList.toggle('inactive', !!inactive);
+ }
+
+ var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
+ var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
+ var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');
+
+ gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";
+
+ setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
+ setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
+ setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);
+
+ return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
+}
diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js
index a612705d..900c56f3 100644
--- a/javascript/imageMaskFix.js
+++ b/javascript/imageMaskFix.js
@@ -4,17 +4,16 @@
*/
function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
- if ( ! canvases.length ) {
- canvases_fixed = false; // TODO: this is unused..?
- window.removeEventListener( 'resize', imageMaskResize );
- return;
+ if (!canvases.length) {
+ window.removeEventListener('resize', imageMaskResize);
+ return;
}
const wrapper = canvases[0].closest('.touch-none');
const previewImage = wrapper.previousElementSibling;
- if ( ! previewImage.complete ) {
- previewImage.addEventListener( 'load', imageMaskResize);
+ if (!previewImage.complete) {
+ previewImage.addEventListener('load', imageMaskResize);
return;
}
@@ -24,15 +23,15 @@ function imageMaskResize() {
const nh = previewImage.naturalHeight;
const portrait = nh > nw;
- const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
- const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
+ const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
+ const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
wrapper.style.width = `${wW}px`;
wrapper.style.height = `${wH}px`;
wrapper.style.left = `0px`;
wrapper.style.top = `0px`;
- canvases.forEach( c => {
+ canvases.forEach(c => {
c.style.width = c.style.height = '';
c.style.maxWidth = '100%';
c.style.maxHeight = '100%';
@@ -40,5 +39,5 @@ function imageMaskResize() {
});
}
-onUiUpdate(imageMaskResize);
-window.addEventListener( 'resize', imageMaskResize);
+onAfterUiUpdate(imageMaskResize);
+window.addEventListener('resize', imageMaskResize);
diff --git a/javascript/imageParams.js b/javascript/imageParams.js
deleted file mode 100644
index 64aee93b..00000000
--- a/javascript/imageParams.js
+++ /dev/null
@@ -1,18 +0,0 @@
-window.onload = (function(){
- window.addEventListener('drop', e => {
- const target = e.composedPath()[0];
- if (target.placeholder.indexOf("Prompt") == -1) return;
-
- let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
-
- e.stopPropagation();
- e.preventDefault();
- const imgParent = gradioApp().getElementById(prompt_target);
- const files = e.dataTransfer.files;
- const fileInput = imgParent.querySelector('input[type="file"]');
- if ( fileInput ) {
- fileInput.files = files;
- fileInput.dispatchEvent(new Event('change'));
- }
- });
-});
diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js
index 32066ab8..677e95c1 100644
--- a/javascript/imageviewer.js
+++ b/javascript/imageviewer.js
@@ -5,24 +5,24 @@ function closeModal() {
function showModal(event) {
const source = event.target || event.srcElement;
- const modalImage = gradioApp().getElementById("modalImage")
- const lb = gradioApp().getElementById("lightboxModal")
- modalImage.src = source.src
+ const modalImage = gradioApp().getElementById("modalImage");
+ const lb = gradioApp().getElementById("lightboxModal");
+ modalImage.src = source.src;
if (modalImage.style.display === 'none') {
lb.style.setProperty('background-image', 'url(' + source.src + ')');
}
lb.style.display = "flex";
- lb.focus()
+ lb.focus();
- const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
- const tabImg2Img = gradioApp().getElementById("tab_img2img")
+ const tabTxt2Img = gradioApp().getElementById("tab_txt2img");
+ const tabImg2Img = gradioApp().getElementById("tab_img2img");
// show the save button in modal only on txt2img or img2img tabs
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
- gradioApp().getElementById("modal_save").style.display = "inline"
+ gradioApp().getElementById("modal_save").style.display = "inline";
} else {
- gradioApp().getElementById("modal_save").style.display = "none"
+ gradioApp().getElementById("modal_save").style.display = "none";
}
- event.stopPropagation()
+ event.stopPropagation();
}
function negmod(n, m) {
@@ -30,14 +30,15 @@ function negmod(n, m) {
}
function updateOnBackgroundChange() {
- const modalImage = gradioApp().getElementById("modalImage")
+ const modalImage = gradioApp().getElementById("modalImage");
if (modalImage && modalImage.offsetParent) {
let currentButton = selected_gallery_button();
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
if (modalImage.style.display === 'none') {
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
+ const modal = gradioApp().getElementById("lightboxModal");
+ modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
}
}
@@ -49,108 +50,109 @@ function modalImageSwitch(offset) {
if (galleryButtons.length > 1) {
var currentButton = selected_gallery_button();
- var result = -1
+ var result = -1;
galleryButtons.forEach(function(v, i) {
if (v == currentButton) {
- result = i
+ result = i;
}
- })
+ });
if (result != -1) {
- var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
- nextButton.click()
+ var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)];
+ nextButton.click();
const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal");
modalImage.src = nextButton.children[0].src;
if (modalImage.style.display === 'none') {
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
+ modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
setTimeout(function() {
- modal.focus()
- }, 10)
+ modal.focus();
+ }, 10);
}
}
}
-function saveImage(){
- const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
- const tabImg2Img = gradioApp().getElementById("tab_img2img")
- const saveTxt2Img = "save_txt2img"
- const saveImg2Img = "save_img2img"
+function saveImage() {
+ const tabTxt2Img = gradioApp().getElementById("tab_txt2img");
+ const tabImg2Img = gradioApp().getElementById("tab_img2img");
+ const saveTxt2Img = "save_txt2img";
+ const saveImg2Img = "save_img2img";
if (tabTxt2Img.style.display != "none") {
- gradioApp().getElementById(saveTxt2Img).click()
+ gradioApp().getElementById(saveTxt2Img).click();
} else if (tabImg2Img.style.display != "none") {
- gradioApp().getElementById(saveImg2Img).click()
+ gradioApp().getElementById(saveImg2Img).click();
} else {
- console.error("missing implementation for saving modal of this type")
+ console.error("missing implementation for saving modal of this type");
}
}
function modalSaveImage(event) {
- saveImage()
- event.stopPropagation()
+ saveImage();
+ event.stopPropagation();
}
function modalNextImage(event) {
- modalImageSwitch(1)
- event.stopPropagation()
+ modalImageSwitch(1);
+ event.stopPropagation();
}
function modalPrevImage(event) {
- modalImageSwitch(-1)
- event.stopPropagation()
+ modalImageSwitch(-1);
+ event.stopPropagation();
}
function modalKeyHandler(event) {
switch (event.key) {
- case "s":
- saveImage()
- break;
- case "ArrowLeft":
- modalPrevImage(event)
- break;
- case "ArrowRight":
- modalNextImage(event)
- break;
- case "Escape":
- closeModal();
- break;
+ case "s":
+ saveImage();
+ break;
+ case "ArrowLeft":
+ modalPrevImage(event);
+ break;
+ case "ArrowRight":
+ modalNextImage(event);
+ break;
+ case "Escape":
+ closeModal();
+ break;
}
}
function setupImageForLightbox(e) {
- if (e.dataset.modded)
- return;
+ if (e.dataset.modded) {
+ return;
+ }
- e.dataset.modded = true;
- e.style.cursor='pointer'
- e.style.userSelect='none'
+ e.dataset.modded = true;
+ e.style.cursor = 'pointer';
+ e.style.userSelect = 'none';
- var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
+ var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1;
- // For Firefox, listening on click first switched to next image then shows the lightbox.
- // If you know how to fix this without switching to mousedown event, please.
- // For other browsers the event is click to make it possiblr to drag picture.
- var event = isFirefox ? 'mousedown' : 'click'
+ // For Firefox, listening on click first switched to next image then shows the lightbox.
+ // If you know how to fix this without switching to mousedown event, please.
+ // For other browsers the event is click to make it possiblr to drag picture.
+ var event = isFirefox ? 'mousedown' : 'click';
- e.addEventListener(event, function (evt) {
- if(!opts.js_modal_lightbox || evt.button != 0) return;
+ e.addEventListener(event, function(evt) {
+ if (!opts.js_modal_lightbox || evt.button != 0) return;
- modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
- evt.preventDefault()
- showModal(evt)
- }, true);
+ modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
+ evt.preventDefault();
+ showModal(evt);
+ }, true);
}
function modalZoomSet(modalImage, enable) {
- if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
+ if (modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
}
function modalZoomToggle(event) {
var modalImage = gradioApp().getElementById("modalImage");
- modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
- event.stopPropagation()
+ modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'));
+ event.stopPropagation();
}
function modalTileImageToggle(event) {
@@ -159,99 +161,93 @@ function modalTileImageToggle(event) {
const isTiling = modalImage.style.display === 'none';
if (isTiling) {
modalImage.style.display = 'block';
- modal.style.setProperty('background-image', 'none')
+ modal.style.setProperty('background-image', 'none');
} else {
modalImage.style.display = 'none';
- modal.style.setProperty('background-image', `url(${modalImage.src})`)
+ modal.style.setProperty('background-image', `url(${modalImage.src})`);
}
- event.stopPropagation()
+ event.stopPropagation();
}
-function galleryImageHandler(e) {
- //if (e && e.parentElement.tagName == 'BUTTON') {
- e.onclick = showGalleryImage;
- //}
-}
-
-onUiUpdate(function() {
- var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
+onAfterUiUpdate(function() {
+ var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox);
}
updateOnBackgroundChange();
-})
+});
document.addEventListener("DOMContentLoaded", function() {
//const modalFragment = document.createDocumentFragment();
- const modal = document.createElement('div')
+ const modal = document.createElement('div');
modal.onclick = closeModal;
modal.id = "lightboxModal";
- modal.tabIndex = 0
- modal.addEventListener('keydown', modalKeyHandler, true)
+ modal.tabIndex = 0;
+ modal.addEventListener('keydown', modalKeyHandler, true);
- const modalControls = document.createElement('div')
+ const modalControls = document.createElement('div');
modalControls.className = 'modalControls gradio-container';
modal.append(modalControls);
- const modalZoom = document.createElement('span')
+ const modalZoom = document.createElement('span');
modalZoom.className = 'modalZoom cursor';
- modalZoom.innerHTML = '&#10529;'
- modalZoom.addEventListener('click', modalZoomToggle, true)
+ modalZoom.innerHTML = '&#10529;';
+ modalZoom.addEventListener('click', modalZoomToggle, true);
modalZoom.title = "Toggle zoomed view";
- modalControls.appendChild(modalZoom)
+ modalControls.appendChild(modalZoom);
- const modalTileImage = document.createElement('span')
+ const modalTileImage = document.createElement('span');
modalTileImage.className = 'modalTileImage cursor';
- modalTileImage.innerHTML = '&#8862;'
- modalTileImage.addEventListener('click', modalTileImageToggle, true)
+ modalTileImage.innerHTML = '&#8862;';
+ modalTileImage.addEventListener('click', modalTileImageToggle, true);
modalTileImage.title = "Preview tiling";
- modalControls.appendChild(modalTileImage)
+ modalControls.appendChild(modalTileImage);
- const modalSave = document.createElement("span")
- modalSave.className = "modalSave cursor"
- modalSave.id = "modal_save"
- modalSave.innerHTML = "&#x1F5AB;"
- modalSave.addEventListener("click", modalSaveImage, true)
- modalSave.title = "Save Image(s)"
- modalControls.appendChild(modalSave)
+ const modalSave = document.createElement("span");
+ modalSave.className = "modalSave cursor";
+ modalSave.id = "modal_save";
+ modalSave.innerHTML = "&#x1F5AB;";
+ modalSave.addEventListener("click", modalSaveImage, true);
+ modalSave.title = "Save Image(s)";
+ modalControls.appendChild(modalSave);
- const modalClose = document.createElement('span')
+ const modalClose = document.createElement('span');
modalClose.className = 'modalClose cursor';
- modalClose.innerHTML = '&times;'
+ modalClose.innerHTML = '&times;';
modalClose.onclick = closeModal;
modalClose.title = "Close image viewer";
- modalControls.appendChild(modalClose)
+ modalControls.appendChild(modalClose);
- const modalImage = document.createElement('img')
+ const modalImage = document.createElement('img');
modalImage.id = 'modalImage';
modalImage.onclick = closeModal;
- modalImage.tabIndex = 0
- modalImage.addEventListener('keydown', modalKeyHandler, true)
- modal.appendChild(modalImage)
+ modalImage.tabIndex = 0;
+ modalImage.addEventListener('keydown', modalKeyHandler, true);
+ modal.appendChild(modalImage);
- const modalPrev = document.createElement('a')
+ const modalPrev = document.createElement('a');
modalPrev.className = 'modalPrev';
- modalPrev.innerHTML = '&#10094;'
- modalPrev.tabIndex = 0
+ modalPrev.innerHTML = '&#10094;';
+ modalPrev.tabIndex = 0;
modalPrev.addEventListener('click', modalPrevImage, true);
- modalPrev.addEventListener('keydown', modalKeyHandler, true)
- modal.appendChild(modalPrev)
+ modalPrev.addEventListener('keydown', modalKeyHandler, true);
+ modal.appendChild(modalPrev);
- const modalNext = document.createElement('a')
+ const modalNext = document.createElement('a');
modalNext.className = 'modalNext';
- modalNext.innerHTML = '&#10095;'
- modalNext.tabIndex = 0
+ modalNext.innerHTML = '&#10095;';
+ modalNext.tabIndex = 0;
modalNext.addEventListener('click', modalNextImage, true);
- modalNext.addEventListener('keydown', modalKeyHandler, true)
+ modalNext.addEventListener('keydown', modalKeyHandler, true);
- modal.appendChild(modalNext)
+ modal.appendChild(modalNext);
try {
- gradioApp().appendChild(modal);
- } catch (e) {
- gradioApp().body.appendChild(modal);
- }
+ gradioApp().appendChild(modal);
+ } catch (e) {
+ gradioApp().body.appendChild(modal);
+ }
document.body.appendChild(modal);
diff --git a/javascript/imageviewerGamepad.js b/javascript/imageviewerGamepad.js
index 6297a12b..a22c7e6e 100644
--- a/javascript/imageviewerGamepad.js
+++ b/javascript/imageviewerGamepad.js
@@ -1,7 +1,9 @@
+let gamepads = [];
+
window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index;
let isWaiting = false;
- setInterval(async () => {
+ gamepads[index] = setInterval(async() => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0];
@@ -14,7 +16,7 @@ window.addEventListener('gamepadconnected', (e) => {
}
if (isWaiting) {
await sleepUntil(() => {
- const xValue = navigator.getGamepads()[index].axes[0]
+ const xValue = navigator.getGamepads()[index].axes[0];
if (xValue < 0.3 && xValue > -0.3) {
return true;
}
@@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => {
}, 10);
});
+window.addEventListener('gamepaddisconnected', (e) => {
+ clearInterval(gamepads[e.gamepad.index]);
+});
+
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
diff --git a/javascript/localization.js b/javascript/localization.js
index 86e5ca67..eb22b8a7 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -1,177 +1,176 @@
-
-// localization = {} -- the dict with translations is created by the backend
-
-ignore_ids_for_localization={
- setting_sd_hypernetwork: 'OPTION',
- setting_sd_model_checkpoint: 'OPTION',
- setting_realesrgan_enabled_models: 'OPTION',
- modelmerger_primary_model_name: 'OPTION',
- modelmerger_secondary_model_name: 'OPTION',
- modelmerger_tertiary_model_name: 'OPTION',
- train_embedding: 'OPTION',
- train_hypernetwork: 'OPTION',
- txt2img_styles: 'OPTION',
- img2img_styles: 'OPTION',
- setting_random_artist_categories: 'SPAN',
- setting_face_restoration_model: 'SPAN',
- setting_realesrgan_enabled_models: 'SPAN',
- extras_upscaler_1: 'SPAN',
- extras_upscaler_2: 'SPAN',
-}
-
-re_num = /^[\.\d]+$/
-re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
-
-original_lines = {}
-translated_lines = {}
-
-function hasLocalization() {
- return window.localization && Object.keys(window.localization).length > 0;
-}
-
-function textNodesUnder(el){
- var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
- while(n=walk.nextNode()) a.push(n);
- return a;
-}
-
-function canBeTranslated(node, text){
- if(! text) return false;
- if(! node.parentElement) return false;
-
- var parentType = node.parentElement.nodeName
- if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
-
- if (parentType=='OPTION' || parentType=='SPAN'){
- var pnode = node
- for(var level=0; level<4; level++){
- pnode = pnode.parentElement
- if(! pnode) break;
-
- if(ignore_ids_for_localization[pnode.id] == parentType) return false;
- }
- }
-
- if(re_num.test(text)) return false;
- if(re_emoji.test(text)) return false;
- return true
-}
-
-function getTranslation(text){
- if(! text) return undefined
-
- if(translated_lines[text] === undefined){
- original_lines[text] = 1
- }
-
- tl = localization[text]
- if(tl !== undefined){
- translated_lines[tl] = 1
- }
-
- return tl
-}
-
-function processTextNode(node){
- var text = node.textContent.trim()
-
- if(! canBeTranslated(node, text)) return
-
- tl = getTranslation(text)
- if(tl !== undefined){
- node.textContent = tl
- }
-}
-
-function processNode(node){
- if(node.nodeType == 3){
- processTextNode(node)
- return
- }
-
- if(node.title){
- tl = getTranslation(node.title)
- if(tl !== undefined){
- node.title = tl
- }
- }
-
- if(node.placeholder){
- tl = getTranslation(node.placeholder)
- if(tl !== undefined){
- node.placeholder = tl
- }
- }
-
- textNodesUnder(node).forEach(function(node){
- processTextNode(node)
- })
-}
-
-function dumpTranslations(){
- if(!hasLocalization()) {
- // If we don't have any localization,
- // we will not have traversed the app to find
- // original_lines, so do that now.
- processNode(gradioApp());
- }
- var dumped = {}
- if (localization.rtl) {
- dumped.rtl = true;
- }
-
- for (const text in original_lines) {
- if(dumped[text] !== undefined) continue;
- dumped[text] = localization[text] || text;
- }
-
- return dumped;
-}
-
-function download_localization() {
- var text = JSON.stringify(dumpTranslations(), null, 4)
-
- var element = document.createElement('a');
- element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
- element.setAttribute('download', "localization.json");
- element.style.display = 'none';
- document.body.appendChild(element);
-
- element.click();
-
- document.body.removeChild(element);
-}
-
-document.addEventListener("DOMContentLoaded", function () {
- if (!hasLocalization()) {
- return;
- }
-
- onUiUpdate(function (m) {
- m.forEach(function (mutation) {
- mutation.addedNodes.forEach(function (node) {
- processNode(node)
- })
- });
- })
-
- processNode(gradioApp())
-
- if (localization.rtl) { // if the language is from right to left,
- (new MutationObserver((mutations, observer) => { // wait for the style to load
- mutations.forEach(mutation => {
- mutation.addedNodes.forEach(node => {
- if (node.tagName === 'STYLE') {
- observer.disconnect();
-
- for (const x of node.sheet.rules) { // find all rtl media rules
- if (Array.from(x.media || []).includes('rtl')) {
- x.media.appendMedium('all'); // enable them
- }
- }
- }
- })
- });
- })).observe(gradioApp(), { childList: true });
- }
-})
+
+// localization = {} -- the dict with translations is created by the backend
+
+var ignore_ids_for_localization = {
+ setting_sd_hypernetwork: 'OPTION',
+ setting_sd_model_checkpoint: 'OPTION',
+ modelmerger_primary_model_name: 'OPTION',
+ modelmerger_secondary_model_name: 'OPTION',
+ modelmerger_tertiary_model_name: 'OPTION',
+ train_embedding: 'OPTION',
+ train_hypernetwork: 'OPTION',
+ txt2img_styles: 'OPTION',
+ img2img_styles: 'OPTION',
+ setting_random_artist_categories: 'SPAN',
+ setting_face_restoration_model: 'SPAN',
+ setting_realesrgan_enabled_models: 'SPAN',
+ extras_upscaler_1: 'SPAN',
+ extras_upscaler_2: 'SPAN',
+};
+
+var re_num = /^[.\d]+$/;
+var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;
+
+var original_lines = {};
+var translated_lines = {};
+
+function hasLocalization() {
+ return window.localization && Object.keys(window.localization).length > 0;
+}
+
+function textNodesUnder(el) {
+ var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
+ while ((n = walk.nextNode())) a.push(n);
+ return a;
+}
+
+function canBeTranslated(node, text) {
+ if (!text) return false;
+ if (!node.parentElement) return false;
+
+ var parentType = node.parentElement.nodeName;
+ if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;
+
+ if (parentType == 'OPTION' || parentType == 'SPAN') {
+ var pnode = node;
+ for (var level = 0; level < 4; level++) {
+ pnode = pnode.parentElement;
+ if (!pnode) break;
+
+ if (ignore_ids_for_localization[pnode.id] == parentType) return false;
+ }
+ }
+
+ if (re_num.test(text)) return false;
+ if (re_emoji.test(text)) return false;
+ return true;
+}
+
+function getTranslation(text) {
+ if (!text) return undefined;
+
+ if (translated_lines[text] === undefined) {
+ original_lines[text] = 1;
+ }
+
+ var tl = localization[text];
+ if (tl !== undefined) {
+ translated_lines[tl] = 1;
+ }
+
+ return tl;
+}
+
+function processTextNode(node) {
+ var text = node.textContent.trim();
+
+ if (!canBeTranslated(node, text)) return;
+
+ var tl = getTranslation(text);
+ if (tl !== undefined) {
+ node.textContent = tl;
+ }
+}
+
+function processNode(node) {
+ if (node.nodeType == 3) {
+ processTextNode(node);
+ return;
+ }
+
+ if (node.title) {
+ let tl = getTranslation(node.title);
+ if (tl !== undefined) {
+ node.title = tl;
+ }
+ }
+
+ if (node.placeholder) {
+ let tl = getTranslation(node.placeholder);
+ if (tl !== undefined) {
+ node.placeholder = tl;
+ }
+ }
+
+ textNodesUnder(node).forEach(function(node) {
+ processTextNode(node);
+ });
+}
+
+function dumpTranslations() {
+ if (!hasLocalization()) {
+ // If we don't have any localization,
+ // we will not have traversed the app to find
+ // original_lines, so do that now.
+ processNode(gradioApp());
+ }
+ var dumped = {};
+ if (localization.rtl) {
+ dumped.rtl = true;
+ }
+
+ for (const text in original_lines) {
+ if (dumped[text] !== undefined) continue;
+ dumped[text] = localization[text] || text;
+ }
+
+ return dumped;
+}
+
+function download_localization() {
+ var text = JSON.stringify(dumpTranslations(), null, 4);
+
+ var element = document.createElement('a');
+ element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
+ element.setAttribute('download', "localization.json");
+ element.style.display = 'none';
+ document.body.appendChild(element);
+
+ element.click();
+
+ document.body.removeChild(element);
+}
+
+document.addEventListener("DOMContentLoaded", function() {
+ if (!hasLocalization()) {
+ return;
+ }
+
+ onUiUpdate(function(m) {
+ m.forEach(function(mutation) {
+ mutation.addedNodes.forEach(function(node) {
+ processNode(node);
+ });
+ });
+ });
+
+ processNode(gradioApp());
+
+ if (localization.rtl) { // if the language is from right to left,
+ (new MutationObserver((mutations, observer) => { // wait for the style to load
+ mutations.forEach(mutation => {
+ mutation.addedNodes.forEach(node => {
+ if (node.tagName === 'STYLE') {
+ observer.disconnect();
+
+ for (const x of node.sheet.rules) { // find all rtl media rules
+ if (Array.from(x.media || []).includes('rtl')) {
+ x.media.appendMedium('all'); // enable them
+ }
+ }
+ }
+ });
+ });
+ })).observe(gradioApp(), {childList: true});
+ }
+});
diff --git a/javascript/notification.js b/javascript/notification.js
index 83fce1f8..76c5715d 100644
--- a/javascript/notification.js
+++ b/javascript/notification.js
@@ -4,14 +4,14 @@ let lastHeadImg = null;
let notificationButton = null;
-onUiUpdate(function(){
- if(notificationButton == null){
- notificationButton = gradioApp().getElementById('request_notifications')
+onAfterUiUpdate(function() {
+ if (notificationButton == null) {
+ notificationButton = gradioApp().getElementById('request_notifications');
- if(notificationButton != null){
+ if (notificationButton != null) {
notificationButton.addEventListener('click', () => {
void Notification.requestPermission();
- },true);
+ }, true);
}
}
@@ -42,7 +42,7 @@ onUiUpdate(function(){
}
);
- notification.onclick = function(_){
+ notification.onclick = function(_) {
parent.focus();
this.close();
};
diff --git a/javascript/profilerVisualization.js b/javascript/profilerVisualization.js
new file mode 100644
index 00000000..9d8e5f42
--- /dev/null
+++ b/javascript/profilerVisualization.js
@@ -0,0 +1,153 @@
+
+function createRow(table, cellName, items) {
+ var tr = document.createElement('tr');
+ var res = [];
+
+ items.forEach(function(x, i) {
+ if (x === undefined) {
+ res.push(null);
+ return;
+ }
+
+ var td = document.createElement(cellName);
+ td.textContent = x;
+ tr.appendChild(td);
+ res.push(td);
+
+ var colspan = 1;
+ for (var n = i + 1; n < items.length; n++) {
+ if (items[n] !== undefined) {
+ break;
+ }
+
+ colspan += 1;
+ }
+
+ if (colspan > 1) {
+ td.colSpan = colspan;
+ }
+ });
+
+ table.appendChild(tr);
+
+ return res;
+}
+
+function showProfile(path, cutoff = 0.05) {
+ requestGet(path, {}, function(data) {
+ var table = document.createElement('table');
+ table.className = 'popup-table';
+
+ data.records['total'] = data.total;
+ var keys = Object.keys(data.records).sort(function(a, b) {
+ return data.records[b] - data.records[a];
+ });
+ var items = keys.map(function(x) {
+ return {key: x, parts: x.split('/'), time: data.records[x]};
+ });
+ var maxLength = items.reduce(function(a, b) {
+ return Math.max(a, b.parts.length);
+ }, 0);
+
+ var cols = createRow(table, 'th', ['record', 'seconds']);
+ cols[0].colSpan = maxLength;
+
+ function arraysEqual(a, b) {
+ return !(a < b || b < a);
+ }
+
+ var addLevel = function(level, parent, hide) {
+ var matching = items.filter(function(x) {
+ return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
+ });
+ var sorted = matching.sort(function(a, b) {
+ return b.time - a.time;
+ });
+ var othersTime = 0;
+ var othersList = [];
+ var othersRows = [];
+ var childrenRows = [];
+ sorted.forEach(function(x) {
+ var visible = x.time >= cutoff && !hide;
+
+ var cells = [];
+ for (var i = 0; i < maxLength; i++) {
+ cells.push(x.parts[i]);
+ }
+ cells.push(x.time.toFixed(3));
+ var cols = createRow(table, 'td', cells);
+ for (i = 0; i < level; i++) {
+ cols[i].className = 'muted';
+ }
+
+ var tr = cols[0].parentNode;
+ if (!visible) {
+ tr.classList.add("hidden");
+ }
+
+ if (x.time >= cutoff) {
+ childrenRows.push(tr);
+ } else {
+ othersTime += x.time;
+ othersList.push(x.parts[level]);
+ othersRows.push(tr);
+ }
+
+ var children = addLevel(level + 1, parent.concat([x.parts[level]]), true);
+ if (children.length > 0) {
+ var cell = cols[level];
+ var onclick = function() {
+ cell.classList.remove("link");
+ cell.removeEventListener("click", onclick);
+ children.forEach(function(x) {
+ x.classList.remove("hidden");
+ });
+ };
+ cell.classList.add("link");
+ cell.addEventListener("click", onclick);
+ }
+ });
+
+ if (othersTime > 0) {
+ var cells = [];
+ for (var i = 0; i < maxLength; i++) {
+ cells.push(parent[i]);
+ }
+ cells.push(othersTime.toFixed(3));
+ cells[level] = 'others';
+ var cols = createRow(table, 'td', cells);
+ for (i = 0; i < level; i++) {
+ cols[i].className = 'muted';
+ }
+
+ var cell = cols[level];
+ var tr = cell.parentNode;
+ var onclick = function() {
+ tr.classList.add("hidden");
+ cell.classList.remove("link");
+ cell.removeEventListener("click", onclick);
+ othersRows.forEach(function(x) {
+ x.classList.remove("hidden");
+ });
+ };
+
+ cell.title = othersList.join(", ");
+ cell.classList.add("link");
+ cell.addEventListener("click", onclick);
+
+ if (hide) {
+ tr.classList.add("hidden");
+ }
+
+ childrenRows.push(tr);
+ }
+
+ return childrenRows;
+ };
+
+ addLevel(0, []);
+
+ popup(table);
+ });
+}
+
diff --git a/javascript/progressbar.js b/javascript/progressbar.js
index 8d2c3492..29299787 100644
--- a/javascript/progressbar.js
+++ b/javascript/progressbar.js
@@ -1,29 +1,29 @@
// code related to showing and updating progressbar shown as the image is being made
-function rememberGallerySelection(){
+function rememberGallerySelection() {
}
-function getGallerySelectedIndex(){
+function getGallerySelectedIndex() {
}
-function request(url, data, handler, errorHandler){
+function request(url, data, handler, errorHandler) {
var xhr = new XMLHttpRequest();
xhr.open("POST", url, true);
xhr.setRequestHeader("Content-Type", "application/json");
- xhr.onreadystatechange = function () {
+ xhr.onreadystatechange = function() {
if (xhr.readyState === 4) {
if (xhr.status === 200) {
try {
var js = JSON.parse(xhr.responseText);
- handler(js)
+ handler(js);
} catch (error) {
console.error(error);
- errorHandler()
+ errorHandler();
}
- } else{
- errorHandler()
+ } else {
+ errorHandler();
}
}
};
@@ -31,147 +31,147 @@ function request(url, data, handler, errorHandler){
xhr.send(js);
}
-function pad2(x){
- return x<10 ? '0'+x : x
+function pad2(x) {
+ return x < 10 ? '0' + x : x;
}
-function formatTime(secs){
- if(secs > 3600){
- return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
- } else if(secs > 60){
- return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
- } else{
- return Math.floor(secs) + "s"
+function formatTime(secs) {
+ if (secs > 3600) {
+ return pad2(Math.floor(secs / 60 / 60)) + ":" + pad2(Math.floor(secs / 60) % 60) + ":" + pad2(Math.floor(secs) % 60);
+ } else if (secs > 60) {
+ return pad2(Math.floor(secs / 60)) + ":" + pad2(Math.floor(secs) % 60);
+ } else {
+ return Math.floor(secs) + "s";
}
}
-function setTitle(progress){
- var title = 'Stable Diffusion'
+function setTitle(progress) {
+ var title = 'Stable Diffusion';
- if(opts.show_progress_in_title && progress){
+ if (opts.show_progress_in_title && progress) {
title = '[' + progress.trim() + '] ' + title;
}
- if(document.title != title){
- document.title = title;
+ if (document.title != title) {
+ document.title = title;
}
}
-function randomId(){
- return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
+function randomId() {
+ return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + ")";
}
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
// calls onProgress every time there is a progress update
-function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
- var dateStart = new Date()
- var wasEverActive = false
- var parentProgressbar = progressbarContainer.parentNode
- var parentGallery = gallery ? gallery.parentNode : null
-
- var divProgress = document.createElement('div')
- divProgress.className='progressDiv'
- divProgress.style.display = opts.show_progressbar ? "block" : "none"
- var divInner = document.createElement('div')
- divInner.className='progress'
-
- divProgress.appendChild(divInner)
- parentProgressbar.insertBefore(divProgress, progressbarContainer)
-
- if(parentGallery){
- var livePreview = document.createElement('div')
- livePreview.className='livePreview'
- parentGallery.insertBefore(livePreview, gallery)
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {
+ var dateStart = new Date();
+ var wasEverActive = false;
+ var parentProgressbar = progressbarContainer.parentNode;
+ var parentGallery = gallery ? gallery.parentNode : null;
+
+ var divProgress = document.createElement('div');
+ divProgress.className = 'progressDiv';
+ divProgress.style.display = opts.show_progressbar ? "block" : "none";
+ var divInner = document.createElement('div');
+ divInner.className = 'progress';
+
+ divProgress.appendChild(divInner);
+ parentProgressbar.insertBefore(divProgress, progressbarContainer);
+
+ if (parentGallery) {
+ var livePreview = document.createElement('div');
+ livePreview.className = 'livePreview';
+ parentGallery.insertBefore(livePreview, gallery);
}
- var removeProgressBar = function(){
- setTitle("")
- parentProgressbar.removeChild(divProgress)
- if(parentGallery) parentGallery.removeChild(livePreview)
- atEnd()
- }
+ var removeProgressBar = function() {
+ setTitle("");
+ parentProgressbar.removeChild(divProgress);
+ if (parentGallery) parentGallery.removeChild(livePreview);
+ atEnd();
+ };
- var fun = function(id_task, id_live_preview){
- request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
- if(res.completed){
- removeProgressBar()
- return
+ var fun = function(id_task, id_live_preview) {
+ request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
+ if (res.completed) {
+ removeProgressBar();
+ return;
}
- var rect = progressbarContainer.getBoundingClientRect()
+ var rect = progressbarContainer.getBoundingClientRect();
- if(rect.width){
+ if (rect.width) {
divProgress.style.width = rect.width + "px";
}
- let progressText = ""
+ let progressText = "";
- divInner.style.width = ((res.progress || 0) * 100.0) + '%'
- divInner.style.background = res.progress ? "" : "transparent"
+ divInner.style.width = ((res.progress || 0) * 100.0) + '%';
+ divInner.style.background = res.progress ? "" : "transparent";
- if(res.progress > 0){
- progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
+ if (res.progress > 0) {
+ progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';
}
- if(res.eta){
- progressText += " ETA: " + formatTime(res.eta)
+ if (res.eta) {
+ progressText += " ETA: " + formatTime(res.eta);
}
- setTitle(progressText)
+ setTitle(progressText);
- if(res.textinfo && res.textinfo.indexOf("\n") == -1){
- progressText = res.textinfo + " " + progressText
+ if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
+ progressText = res.textinfo + " " + progressText;
}
- divInner.textContent = progressText
+ divInner.textContent = progressText;
- var elapsedFromStart = (new Date() - dateStart) / 1000
+ var elapsedFromStart = (new Date() - dateStart) / 1000;
- if(res.active) wasEverActive = true;
+ if (res.active) wasEverActive = true;
- if(! res.active && wasEverActive){
- removeProgressBar()
- return
+ if (!res.active && wasEverActive) {
+ removeProgressBar();
+ return;
}
- if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
- removeProgressBar()
- return
+ if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {
+ removeProgressBar();
+ return;
}
- if(res.live_preview && gallery){
- var rect = gallery.getBoundingClientRect()
- if(rect.width){
- livePreview.style.width = rect.width + "px"
- livePreview.style.height = rect.height + "px"
+ if (res.live_preview && gallery) {
+ rect = gallery.getBoundingClientRect();
+ if (rect.width) {
+ livePreview.style.width = rect.width + "px";
+ livePreview.style.height = rect.height + "px";
}
var img = new Image();
img.onload = function() {
- livePreview.appendChild(img)
- if(livePreview.childElementCount > 2){
- livePreview.removeChild(livePreview.firstElementChild)
+ livePreview.appendChild(img);
+ if (livePreview.childElementCount > 2) {
+ livePreview.removeChild(livePreview.firstElementChild);
}
- }
+ };
img.src = res.live_preview;
}
- if(onProgress){
- onProgress(res)
+ if (onProgress) {
+ onProgress(res);
}
setTimeout(() => {
fun(id_task, res.id_live_preview);
- }, opts.live_preview_refresh_period || 500)
- }, function(){
- removeProgressBar()
- })
- }
+ }, opts.live_preview_refresh_period || 500);
+ }, function() {
+ removeProgressBar();
+ });
+ };
- fun(id_task, 0)
+ fun(id_task, 0);
}
diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js
index 0354b860..20443fcc 100644
--- a/javascript/textualInversion.js
+++ b/javascript/textualInversion.js
@@ -1,17 +1,17 @@
-
-
-
-function start_training_textual_inversion(){
- gradioApp().querySelector('#ti_error').innerHTML=''
-
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
- gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
- })
-
- var res = args_to_array(arguments)
-
- res[0] = id
-
- return res
-}
+
+
+
+function start_training_textual_inversion() {
+ gradioApp().querySelector('#ti_error').innerHTML = '';
+
+ var id = randomId();
+ requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
+ gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
+ });
+
+ var res = Array.from(arguments);
+
+ res[0] = id;
+
+ return res;
+}
diff --git a/javascript/token-counters.js b/javascript/token-counters.js
new file mode 100644
index 00000000..9d81a723
--- /dev/null
+++ b/javascript/token-counters.js
@@ -0,0 +1,83 @@
+let promptTokenCountDebounceTime = 800;
+let promptTokenCountTimeouts = {};
+var promptTokenCountUpdateFunctions = {};
+
+function update_txt2img_tokens(...args) {
+ // Called from Gradio
+ update_token_counter("txt2img_token_button");
+ if (args.length == 2) {
+ return args[0];
+ }
+ return args;
+}
+
+function update_img2img_tokens(...args) {
+ // Called from Gradio
+ update_token_counter("img2img_token_button");
+ if (args.length == 2) {
+ return args[0];
+ }
+ return args;
+}
+
+function update_token_counter(button_id) {
+ if (opts.disable_token_counters) {
+ return;
+ }
+ if (promptTokenCountTimeouts[button_id]) {
+ clearTimeout(promptTokenCountTimeouts[button_id]);
+ }
+ promptTokenCountTimeouts[button_id] = setTimeout(
+ () => gradioApp().getElementById(button_id)?.click(),
+ promptTokenCountDebounceTime,
+ );
+}
+
+
+function recalculatePromptTokens(name) {
+ promptTokenCountUpdateFunctions[name]?.();
+}
+
+function recalculate_prompts_txt2img() {
+ // Called from Gradio
+ recalculatePromptTokens('txt2img_prompt');
+ recalculatePromptTokens('txt2img_neg_prompt');
+ return Array.from(arguments);
+}
+
+function recalculate_prompts_img2img() {
+ // Called from Gradio
+ recalculatePromptTokens('img2img_prompt');
+ recalculatePromptTokens('img2img_neg_prompt');
+ return Array.from(arguments);
+}
+
+function setupTokenCounting(id, id_counter, id_button) {
+ var prompt = gradioApp().getElementById(id);
+ var counter = gradioApp().getElementById(id_counter);
+ var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
+
+ if (opts.disable_token_counters) {
+ counter.style.display = "none";
+ return;
+ }
+
+ if (counter.parentElement == prompt.parentElement) {
+ return;
+ }
+
+ prompt.parentElement.insertBefore(counter, prompt);
+ prompt.parentElement.style.position = "relative";
+
+ promptTokenCountUpdateFunctions[id] = function() {
+ update_token_counter(id_button);
+ };
+ textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
+}
+
+function setupTokenCounters() {
+ setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
+ setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
+ setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
+ setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
+}
diff --git a/javascript/ui.js b/javascript/ui.js
index ed9673d6..d70a681b 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -1,9 +1,9 @@
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
-function set_theme(theme){
- var gradioURL = window.location.href
+function set_theme(theme) {
+ var gradioURL = window.location.href;
if (!gradioURL.includes('?__theme=')) {
- window.location.replace(gradioURL + '?__theme=' + theme);
+ window.location.replace(gradioURL + '?__theme=' + theme);
}
}
@@ -14,7 +14,7 @@ function all_gallery_buttons() {
if (elem.parentElement.offsetParent) {
visibleGalleryButtons.push(elem);
}
- })
+ });
return visibleGalleryButtons;
}
@@ -25,31 +25,35 @@ function selected_gallery_button() {
if (elem.parentElement.offsetParent) {
visibleCurrentButton = elem;
}
- })
+ });
return visibleCurrentButton;
}
-function selected_gallery_index(){
+function selected_gallery_index() {
var buttons = all_gallery_buttons();
var button = selected_gallery_button();
- var result = -1
- buttons.forEach(function(v, i){ if(v==button) { result = i } })
+ var result = -1;
+ buttons.forEach(function(v, i) {
+ if (v == button) {
+ result = i;
+ }
+ });
- return result
+ return result;
}
-function extract_image_from_gallery(gallery){
- if (gallery.length == 0){
+function extract_image_from_gallery(gallery) {
+ if (gallery.length == 0) {
return [null];
}
- if (gallery.length == 1){
+ if (gallery.length == 1) {
return [gallery[0]];
}
- var index = selected_gallery_index()
+ var index = selected_gallery_index();
- if (index < 0 || index >= gallery.length){
+ if (index < 0 || index >= gallery.length) {
// Use the first image in the gallery as the default
index = 0;
}
@@ -57,249 +61,205 @@ function extract_image_from_gallery(gallery){
return [gallery[index]];
}
-function args_to_array(args){
- var res = []
- for(var i=0;i<args.length;i++){
- res.push(args[i])
- }
- return res
-}
+window.args_to_array = Array.from; // Compatibility with e.g. extensions that may expect this to be around
-function switch_to_txt2img(){
+function switch_to_txt2img() {
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
- return args_to_array(arguments);
+ return Array.from(arguments);
}
-function switch_to_img2img_tab(no){
+function switch_to_img2img_tab(no) {
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
}
-function switch_to_img2img(){
+function switch_to_img2img() {
switch_to_img2img_tab(0);
- return args_to_array(arguments);
+ return Array.from(arguments);
}
-function switch_to_sketch(){
+function switch_to_sketch() {
switch_to_img2img_tab(1);
- return args_to_array(arguments);
+ return Array.from(arguments);
}
-function switch_to_inpaint(){
+function switch_to_inpaint() {
switch_to_img2img_tab(2);
- return args_to_array(arguments);
+ return Array.from(arguments);
}
-function switch_to_inpaint_sketch(){
+function switch_to_inpaint_sketch() {
switch_to_img2img_tab(3);
- return args_to_array(arguments);
-}
-
-function switch_to_inpaint(){
- gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
- gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
-
- return args_to_array(arguments);
+ return Array.from(arguments);
}
-function switch_to_extras(){
+function switch_to_extras() {
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
- return args_to_array(arguments);
-}
-
-function get_tab_index(tabId){
- var res = 0
-
- gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
- if(button.className.indexOf('selected') != -1)
- res = i
- })
-
- return res
+ return Array.from(arguments);
}
-function create_tab_index_args(tabId, args){
- var res = []
- for(var i=0; i<args.length; i++){
- res.push(args[i])
+function get_tab_index(tabId) {
+ let buttons = gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button');
+ for (let i = 0; i < buttons.length; i++) {
+ if (buttons[i].classList.contains('selected')) {
+ return i;
+ }
}
+ return 0;
+}
- res[0] = get_tab_index(tabId)
-
- return res
+function create_tab_index_args(tabId, args) {
+ var res = Array.from(args);
+ res[0] = get_tab_index(tabId);
+ return res;
}
function get_img2img_tab_index() {
- let res = args_to_array(arguments)
- res.splice(-2)
- res[0] = get_tab_index('mode_img2img')
- return res
+ let res = Array.from(arguments);
+ res.splice(-2);
+ res[0] = get_tab_index('mode_img2img');
+ return res;
}
-function create_submit_args(args){
- var res = []
- for(var i=0;i<args.length;i++){
- res.push(args[i])
- }
+function create_submit_args(args) {
+ var res = Array.from(args);
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
// If gradio at some point stops sending outputs, this may break something
- if(Array.isArray(res[res.length - 3])){
- res[res.length - 3] = null
+ if (Array.isArray(res[res.length - 3])) {
+ res[res.length - 3] = null;
}
- return res
+ return res;
}
-function showSubmitButtons(tabname, show){
- gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
- gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
+function showSubmitButtons(tabname, show) {
+ gradioApp().getElementById(tabname + '_interrupt').style.display = show ? "none" : "block";
+ gradioApp().getElementById(tabname + '_skip').style.display = show ? "none" : "block";
}
-function showRestoreProgressButton(tabname, show){
- var button = gradioApp().getElementById(tabname + "_restore_progress")
- if(! button) return
+function showRestoreProgressButton(tabname, show) {
+ var button = gradioApp().getElementById(tabname + "_restore_progress");
+ if (!button) return;
- button.style.display = show ? "flex" : "none"
+ button.style.display = show ? "flex" : "none";
}
-function submit(){
- rememberGallerySelection('txt2img_gallery')
- showSubmitButtons('txt2img', false)
+function submit() {
+ showSubmitButtons('txt2img', false);
- var id = randomId()
+ var id = randomId();
localStorage.setItem("txt2img_task_id", id);
- requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
- showSubmitButtons('txt2img', true)
- localStorage.removeItem("txt2img_task_id")
- showRestoreProgressButton('txt2img', false)
- })
+ requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
+ showSubmitButtons('txt2img', true);
+ localStorage.removeItem("txt2img_task_id");
+ showRestoreProgressButton('txt2img', false);
+ });
- var res = create_submit_args(arguments)
+ var res = create_submit_args(arguments);
- res[0] = id
+ res[0] = id;
- return res
+ return res;
}
-function submit_img2img(){
- rememberGallerySelection('img2img_gallery')
- showSubmitButtons('img2img', false)
+function submit_img2img() {
+ showSubmitButtons('img2img', false);
- var id = randomId()
+ var id = randomId();
localStorage.setItem("img2img_task_id", id);
- requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
- showSubmitButtons('img2img', true)
- localStorage.removeItem("img2img_task_id")
- showRestoreProgressButton('img2img', false)
- })
+ requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
+ showSubmitButtons('img2img', true);
+ localStorage.removeItem("img2img_task_id");
+ showRestoreProgressButton('img2img', false);
+ });
- var res = create_submit_args(arguments)
+ var res = create_submit_args(arguments);
- res[0] = id
- res[1] = get_tab_index('mode_img2img')
+ res[0] = id;
+ res[1] = get_tab_index('mode_img2img');
- return res
+ return res;
}
-function restoreProgressTxt2img(){
- showRestoreProgressButton("txt2img", false)
- var id = localStorage.getItem("txt2img_task_id")
+function restoreProgressTxt2img() {
+ showRestoreProgressButton("txt2img", false);
+ var id = localStorage.getItem("txt2img_task_id");
- id = localStorage.getItem("txt2img_task_id")
+ id = localStorage.getItem("txt2img_task_id");
- if(id) {
- requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
- showSubmitButtons('txt2img', true)
- }, null, 0)
+ if (id) {
+ requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
+ showSubmitButtons('txt2img', true);
+ }, null, 0);
}
- return id
+ return id;
}
-function restoreProgressImg2img(){
- showRestoreProgressButton("img2img", false)
-
- var id = localStorage.getItem("img2img_task_id")
+function restoreProgressImg2img() {
+ showRestoreProgressButton("img2img", false);
+
+ var id = localStorage.getItem("img2img_task_id");
- if(id) {
- requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
- showSubmitButtons('img2img', true)
- }, null, 0)
+ if (id) {
+ requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
+ showSubmitButtons('img2img', true);
+ }, null, 0);
}
- return id
+ return id;
}
-onUiLoaded(function () {
- showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"))
- showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"))
+onUiLoaded(function() {
+ showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"));
+ showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"));
});
-function modelmerger(){
- var id = randomId()
- requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
+function modelmerger() {
+ var id = randomId();
+ requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function() {});
- var res = create_submit_args(arguments)
- res[0] = id
- return res
+ var res = create_submit_args(arguments);
+ res[0] = id;
+ return res;
}
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
- var name_ = prompt('Style name:')
- return [name_, prompt_text, negative_prompt_text]
+ var name_ = prompt('Style name:');
+ return [name_, prompt_text, negative_prompt_text];
}
function confirm_clear_prompt(prompt, negative_prompt) {
- if(confirm("Delete prompt?")) {
- prompt = ""
- negative_prompt = ""
+ if (confirm("Delete prompt?")) {
+ prompt = "";
+ negative_prompt = "";
}
- return [prompt, negative_prompt]
+ return [prompt, negative_prompt];
}
-promptTokecountUpdateFuncs = {}
-
-function recalculatePromptTokens(name){
- if(promptTokecountUpdateFuncs[name]){
- promptTokecountUpdateFuncs[name]()
- }
-}
-
-function recalculate_prompts_txt2img(){
- recalculatePromptTokens('txt2img_prompt')
- recalculatePromptTokens('txt2img_neg_prompt')
- return args_to_array(arguments);
-}
-
-function recalculate_prompts_img2img(){
- recalculatePromptTokens('img2img_prompt')
- recalculatePromptTokens('img2img_neg_prompt')
- return args_to_array(arguments);
-}
+var opts = {};
+onAfterUiUpdate(function() {
+ if (Object.keys(opts).length != 0) return;
+ var json_elem = gradioApp().getElementById('settings_json');
+ if (json_elem == null) return;
-var opts = {}
-onUiUpdate(function(){
- if(Object.keys(opts).length != 0) return;
+ var textarea = json_elem.querySelector('textarea');
+ var jsdata = textarea.value;
+ opts = JSON.parse(jsdata);
- var json_elem = gradioApp().getElementById('settings_json')
- if(json_elem == null) return;
-
- var textarea = json_elem.querySelector('textarea')
- var jsdata = textarea.value
- opts = JSON.parse(jsdata)
- executeCallbacks(optionsChangedCallbacks);
+ executeCallbacks(optionsChangedCallbacks); /*global optionsChangedCallbacks*/
Object.defineProperty(textarea, 'value', {
set: function(newValue) {
@@ -308,7 +268,7 @@ onUiUpdate(function(){
valueProp.set.call(textarea, newValue);
if (oldValue != newValue) {
- opts = JSON.parse(textarea.value)
+ opts = JSON.parse(textarea.value);
}
executeCallbacks(optionsChangedCallbacks);
@@ -319,123 +279,109 @@ onUiUpdate(function(){
}
});
- json_elem.parentElement.style.display="none"
-
- function registerTextarea(id, id_counter, id_button){
- var prompt = gradioApp().getElementById(id)
- var counter = gradioApp().getElementById(id_counter)
- var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
-
- if(counter.parentElement == prompt.parentElement){
- return
- }
+ json_elem.parentElement.style.display = "none";
- prompt.parentElement.insertBefore(counter, prompt)
- prompt.parentElement.style.position = "relative"
+ setupTokenCounters();
- promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
- textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
- }
-
- registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
- registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
- registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
- registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
-
- var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
- var settings_tabs = gradioApp().querySelector('#settings div')
- if(show_all_pages && settings_tabs){
- settings_tabs.appendChild(show_all_pages)
- show_all_pages.onclick = function(){
- gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
- if(elem.id == "settings_tab_licenses")
+ var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
+ var settings_tabs = gradioApp().querySelector('#settings div');
+ if (show_all_pages && settings_tabs) {
+ settings_tabs.appendChild(show_all_pages);
+ show_all_pages.onclick = function() {
+ gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
+ if (elem.id == "settings_tab_licenses") {
return;
+ }
elem.style.display = "block";
- })
- }
+ });
+ };
}
-})
+});
-onOptionsChanged(function(){
- var elem = gradioApp().getElementById('sd_checkpoint_hash')
- var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
- var shorthash = sd_checkpoint_hash.substring(0,10)
+onOptionsChanged(function() {
+ var elem = gradioApp().getElementById('sd_checkpoint_hash');
+ var sd_checkpoint_hash = opts.sd_checkpoint_hash || "";
+ var shorthash = sd_checkpoint_hash.substring(0, 10);
- if(elem && elem.textContent != shorthash){
- elem.textContent = shorthash
- elem.title = sd_checkpoint_hash
- elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
- }
-})
+ if (elem && elem.textContent != shorthash) {
+ elem.textContent = shorthash;
+ elem.title = sd_checkpoint_hash;
+ elem.href = "https://google.com/search?q=" + sd_checkpoint_hash;
+ }
+});
let txt2img_textarea, img2img_textarea = undefined;
-let wait_time = 800
-let token_timeouts = {};
-function update_txt2img_tokens(...args) {
- update_token_counter("txt2img_token_button")
- if (args.length == 2)
- return args[0]
- return args;
-}
+function restart_reload() {
+ document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
-function update_img2img_tokens(...args) {
- update_token_counter("img2img_token_button")
- if (args.length == 2)
- return args[0]
- return args;
-}
-
-function update_token_counter(button_id) {
- if (token_timeouts[button_id])
- clearTimeout(token_timeouts[button_id]);
- token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
-}
-
-function restart_reload(){
- document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
-
- var requestPing = function(){
- requestGet("./internal/ping", {}, function(data){
+ var requestPing = function() {
+ requestGet("./internal/ping", {}, function(data) {
location.reload();
- }, function(){
+ }, function() {
setTimeout(requestPing, 500);
- })
- }
+ });
+ };
setTimeout(requestPing, 2000);
- return []
+ return [];
}
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
// will only visible on web page and not sent to python.
-function updateInput(target){
- let e = new Event("input", { bubbles: true })
- Object.defineProperty(e, "target", {value: target})
- target.dispatchEvent(e);
+function updateInput(target) {
+ let e = new Event("input", {bubbles: true});
+ Object.defineProperty(e, "target", {value: target});
+ target.dispatchEvent(e);
}
var desiredCheckpointName = null;
-function selectCheckpoint(name){
+function selectCheckpoint(name) {
desiredCheckpointName = name;
- gradioApp().getElementById('change_checkpoint').click()
+ gradioApp().getElementById('change_checkpoint').click();
}
-function currentImg2imgSourceResolution(_, _, scaleBy){
- var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
- return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
+function currentImg2imgSourceResolution(w, h, scaleBy) {
+ var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img');
+ return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy];
}
-function updateImg2imgResizeToTextAfterChangingImage(){
+function updateImg2imgResizeToTextAfterChangingImage() {
// At the time this is called from gradio, the image has no yet been replaced.
// There may be a better solution, but this is simple and straightforward so I'm going with it.
setTimeout(function() {
- gradioApp().getElementById('img2img_update_resize_to').click()
+ gradioApp().getElementById('img2img_update_resize_to').click();
}, 500);
- return []
+ return [];
+
+}
+
+
+
+function setRandomSeed(elem_id) {
+ var input = gradioApp().querySelector("#" + elem_id + " input");
+ if (!input) return [];
+
+ input.value = "-1";
+ updateInput(input);
+ return [];
+}
+
+function switchWidthHeight(tabname) {
+ var width = gradioApp().querySelector("#" + tabname + "_width input[type=number]");
+ var height = gradioApp().querySelector("#" + tabname + "_height input[type=number]");
+ if (!width || !height) return [];
+
+ var tmp = width.value;
+ width.value = height.value;
+ height.value = tmp;
+
+ updateInput(width);
+ updateInput(height);
+ return [];
}
diff --git a/javascript/ui_settings_hints.js b/javascript/ui_settings_hints.js
index 87a289d3..d088f949 100644
--- a/javascript/ui_settings_hints.js
+++ b/javascript/ui_settings_hints.js
@@ -1,41 +1,62 @@
-// various hints and extra info for the settings tab
-
-onUiLoaded(function(){
- createLink = function(elem_id, text, href){
- var a = document.createElement('A')
- a.textContent = text
- a.target = '_blank';
-
- elem = gradioApp().querySelector('#'+elem_id)
- elem.insertBefore(a, elem.querySelector('label'))
-
- return a
- }
-
- createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
- createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
-
- createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
- requestGet("./internal/quicksettings-hint", {}, function(data){
- var table = document.createElement('table')
- table.className = 'settings-value-table'
-
- data.forEach(function(obj){
- var tr = document.createElement('tr')
- var td = document.createElement('td')
- td.textContent = obj.name
- tr.appendChild(td)
-
- var td = document.createElement('td')
- td.textContent = obj.label
- tr.appendChild(td)
-
- table.appendChild(tr)
- })
-
- popup(table);
- })
- });
-})
-
-
+// various hints and extra info for the settings tab
+
+var settingsHintsSetup = false;
+
+onOptionsChanged(function() {
+ if (settingsHintsSetup) return;
+ settingsHintsSetup = true;
+
+ gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
+ var name = div.id.substr(8);
+ var commentBefore = opts._comments_before[name];
+ var commentAfter = opts._comments_after[name];
+
+ if (!commentBefore && !commentAfter) return;
+
+ var span = null;
+ if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
+ else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
+ else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
+ else span = div.querySelector('label span').firstChild;
+
+ if (!span) return;
+
+ if (commentBefore) {
+ var comment = document.createElement('DIV');
+ comment.className = 'settings-comment';
+ comment.innerHTML = commentBefore;
+ span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
+ span.parentElement.insertBefore(comment, span);
+ span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
+ }
+ if (commentAfter) {
+ comment = document.createElement('DIV');
+ comment.className = 'settings-comment';
+ comment.innerHTML = commentAfter;
+ span.parentElement.insertBefore(comment, span.nextSibling);
+ span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
+ }
+ });
+});
+
+function settingsHintsShowQuicksettings() {
+ requestGet("./internal/quicksettings-hint", {}, function(data) {
+ var table = document.createElement('table');
+ table.className = 'popup-table';
+
+ data.forEach(function(obj) {
+ var tr = document.createElement('tr');
+ var td = document.createElement('td');
+ td.textContent = obj.name;
+ tr.appendChild(td);
+
+ td = document.createElement('td');
+ td.textContent = obj.label;
+ tr.appendChild(td);
+
+ table.appendChild(tr);
+ });
+
+ popup(table);
+ });
+}
diff --git a/launch.py b/launch.py
index cfc0cffa..b103c8f3 100644
--- a/launch.py
+++ b/launch.py
@@ -1,370 +1,38 @@
-# this scripts installs necessary requirements and launches main program in webui.py
-import subprocess
-import os
-import sys
-import importlib.util
-import shlex
-import platform
-import json
+from modules import launch_utils
-from modules import cmd_args
-from modules.paths_internal import script_path, extensions_dir
-commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
-sys.argv += shlex.split(commandline_args)
+args = launch_utils.args
+python = launch_utils.python
+git = launch_utils.git
+index_url = launch_utils.index_url
+dir_repos = launch_utils.dir_repos
-args, _ = cmd_args.parser.parse_known_args()
+commit_hash = launch_utils.commit_hash
+git_tag = launch_utils.git_tag
-python = sys.executable
-git = os.environ.get('GIT', "git")
-index_url = os.environ.get('INDEX_URL', "")
-stored_commit_hash = None
-stored_git_tag = None
-dir_repos = "repositories"
+run = launch_utils.run
+is_installed = launch_utils.is_installed
+repo_dir = launch_utils.repo_dir
-if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
- os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+run_pip = launch_utils.run_pip
+check_run_python = launch_utils.check_run_python
+git_clone = launch_utils.git_clone
+git_pull_recursive = launch_utils.git_pull_recursive
+run_extension_installer = launch_utils.run_extension_installer
+prepare_environment = launch_utils.prepare_environment
+configure_for_tests = launch_utils.configure_for_tests
+start = launch_utils.start
-def check_python_version():
- is_windows = platform.system() == "Windows"
- major = sys.version_info.major
- minor = sys.version_info.minor
- micro = sys.version_info.micro
+def main():
+ if not args.skip_prepare_environment:
+ prepare_environment()
- if is_windows:
- supported_minors = [10]
- else:
- supported_minors = [7, 8, 9, 10, 11]
+ if args.test_server:
+ configure_for_tests()
- if not (major == 3 and minor in supported_minors):
- import modules.errors
-
- modules.errors.print_error_explanation(f"""
-INCOMPATIBLE PYTHON VERSION
-
-This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
-If you encounter an error with "RuntimeError: Couldn't install torch." message,
-or any other error regarding unsuccessful package (library) installation,
-please downgrade (or upgrade) to the latest version of 3.10 Python
-and delete current Python and "venv" folder in WebUI's directory.
-
-You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
-
-{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
-
-Use --skip-python-version-check to suppress this warning.
-""")
-
-
-def commit_hash():
- global stored_commit_hash
-
- if stored_commit_hash is not None:
- return stored_commit_hash
-
- try:
- stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
- except Exception:
- stored_commit_hash = "<none>"
-
- return stored_commit_hash
-
-
-def git_tag():
- global stored_git_tag
-
- if stored_git_tag is not None:
- return stored_git_tag
-
- try:
- stored_git_tag = run(f"{git} describe --tags").strip()
- except Exception:
- stored_git_tag = "<none>"
-
- return stored_git_tag
-
-
-def run(command, desc=None, errdesc=None, custom_env=None, live=False):
- if desc is not None:
- print(desc)
-
- if live:
- result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
- if result.returncode != 0:
- raise RuntimeError(f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}""")
-
- return ""
-
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
-
- if result.returncode != 0:
-
- message = f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}
-stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
-stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
-"""
- raise RuntimeError(message)
-
- return result.stdout.decode(encoding="utf8", errors="ignore")
-
-
-def check_run(command):
- result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
- return result.returncode == 0
-
-
-def is_installed(package):
- try:
- spec = importlib.util.find_spec(package)
- except ModuleNotFoundError:
- return False
-
- return spec is not None
-
-
-def repo_dir(name):
- return os.path.join(script_path, dir_repos, name)
-
-
-def run_python(code, desc=None, errdesc=None):
- return run(f'"{python}" -c "{code}"', desc, errdesc)
-
-
-def run_pip(command, desc=None, live=False):
- if args.skip_install:
- return
-
- index_url_line = f' --index-url {index_url}' if index_url != '' else ''
- return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
-
-
-def check_run_python(code):
- return check_run(f'"{python}" -c "{code}"')
-
-
-def git_clone(url, dir, name, commithash=None):
- # TODO clone into temporary dir and move if successful
-
- if os.path.exists(dir):
- if commithash is None:
- return
-
- current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
- if current_hash == commithash:
- return
-
- run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
- run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
- return
-
- run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
-
- if commithash is not None:
- run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
-
-
-def git_pull_recursive(dir):
- for subdir, _, _ in os.walk(dir):
- if os.path.exists(os.path.join(subdir, '.git')):
- try:
- output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
- print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
- except subprocess.CalledProcessError as e:
- print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
-
-
-def version_check(commit):
- try:
- import requests
- commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
- if commit != "<none>" and commits['commit']['sha'] != commit:
- print("--------------------------------------------------------")
- print("| You are not up to date with the most recent release. |")
- print("| Consider running `git pull` to update. |")
- print("--------------------------------------------------------")
- elif commits['commit']['sha'] == commit:
- print("You are up to date with the most recent release.")
- else:
- print("Not a git clone, can't perform version check.")
- except Exception as e:
- print("version check failed", e)
-
-
-def run_extension_installer(extension_dir):
- path_installer = os.path.join(extension_dir, "install.py")
- if not os.path.isfile(path_installer):
- return
-
- try:
- env = os.environ.copy()
- env['PYTHONPATH'] = os.path.abspath(".")
-
- print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
- except Exception as e:
- print(e, file=sys.stderr)
-
-
-def list_extensions(settings_file):
- settings = {}
-
- try:
- if os.path.isfile(settings_file):
- with open(settings_file, "r", encoding="utf8") as file:
- settings = json.load(file)
- except Exception as e:
- print(e, file=sys.stderr)
-
- disabled_extensions = set(settings.get('disabled_extensions', []))
- disable_all_extensions = settings.get('disable_all_extensions', 'none')
-
- if disable_all_extensions != 'none':
- return []
-
- return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
-
-
-def run_extensions_installers(settings_file):
- if not os.path.isdir(extensions_dir):
- return
-
- for dirname_extension in list_extensions(settings_file):
- run_extension_installer(os.path.join(extensions_dir, dirname_extension))
-
-
-def prepare_environment():
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
- requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
-
- xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
- gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
- clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
- openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
-
- stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
- taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
- k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
- codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
- blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
-
- stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
- taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
- k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
- codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
- blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
-
- if not args.skip_python_version_check:
- check_python_version()
-
- commit = commit_hash()
- tag = git_tag()
-
- print(f"Python {sys.version}")
- print(f"Version: {tag}")
- print(f"Commit hash: {commit}")
-
- if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
- run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
-
- if not args.skip_torch_cuda_test:
- run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
-
- if not is_installed("gfpgan"):
- run_pip(f"install {gfpgan_package}", "gfpgan")
-
- if not is_installed("clip"):
- run_pip(f"install {clip_package}", "clip")
-
- if not is_installed("open_clip"):
- run_pip(f"install {openclip_package}", "open_clip")
-
- if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
- if platform.system() == "Windows":
- if platform.python_version().startswith("3.10"):
- run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
- else:
- print("Installation of xformers is not supported in this version of Python.")
- print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
- if not is_installed("xformers"):
- exit(0)
- elif platform.system() == "Linux":
- run_pip(f"install {xformers_package}", "xformers")
-
- if not is_installed("pyngrok") and args.ngrok:
- run_pip("install pyngrok", "ngrok")
-
- os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
-
- git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
- git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
- git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
- git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
- git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
-
- if not is_installed("lpips"):
- run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
-
- if not os.path.isfile(requirements_file):
- requirements_file = os.path.join(script_path, requirements_file)
- run_pip(f"install -r \"{requirements_file}\"", "requirements")
-
- run_extensions_installers(settings_file=args.ui_settings_file)
-
- if args.update_check:
- version_check(commit)
-
- if args.update_all_extensions:
- git_pull_recursive(extensions_dir)
-
- if "--exit" in sys.argv:
- print("Exiting because of --exit argument")
- exit(0)
-
- if args.tests and not args.no_tests:
- exitcode = tests(args.tests)
- exit(exitcode)
-
-
-def tests(test_dir):
- if "--api" not in sys.argv:
- sys.argv.append("--api")
- if "--ckpt" not in sys.argv:
- sys.argv.append("--ckpt")
- sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
- if "--skip-torch-cuda-test" not in sys.argv:
- sys.argv.append("--skip-torch-cuda-test")
- if "--disable-nan-check" not in sys.argv:
- sys.argv.append("--disable-nan-check")
- if "--no-tests" not in sys.argv:
- sys.argv.append("--no-tests")
-
- print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
-
- os.environ['COMMANDLINE_ARGS'] = ""
- with open(os.path.join(script_path, 'test/stdout.txt'), "w", encoding="utf8") as stdout, open(os.path.join(script_path, 'test/stderr.txt'), "w", encoding="utf8") as stderr:
- proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
-
- import test.server_poll
- exitcode = test.server_poll.run_tests(proc, test_dir)
-
- print(f"Stopping Web UI process with id {proc.pid}")
- proc.kill()
- return exitcode
-
-
-def start():
- print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
- import webui
- if '--nowebui' in sys.argv:
- webui.api_only()
- else:
- webui.webui()
+ start()
if __name__ == "__main__":
- prepare_environment()
- start()
+ main()
diff --git a/modules/Roboto-Regular.ttf b/modules/Roboto-Regular.ttf
new file mode 100644
index 00000000..500b1045
--- /dev/null
+++ b/modules/Roboto-Regular.ttf
Binary files differ
diff --git a/modules/api/api.py b/modules/api/api.py
index 9bb95dfd..224bbfc6 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -14,32 +14,31 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
+from modules.api import models
+from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
-from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
+from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_alisases
+from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
-from typing import List
+from typing import Dict, List, Any
import piexif
import piexif.helper
+from contextlib import closing
-def upscaler_to_index(name: str):
- try:
- return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
- except:
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
def script_name_to_index(name, scripts):
try:
return [script.title().lower() for script in scripts].index(name.lower())
- except:
- raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+ except Exception as e:
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
+
def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None)
@@ -48,20 +47,23 @@ def validate_sampler_name(name):
return name
+
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
+
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
try:
image = Image.open(BytesIO(base64.b64decode(encoding)))
return image
- except Exception as err:
- raise HTTPException(status_code=500, detail="Invalid encoded image")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
@@ -76,6 +78,8 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+ if image.mode == "RGBA":
+ image = image.convert("RGB")
parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@@ -92,6 +96,7 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data)
+
def api_middleware(app: FastAPI):
rich_available = True
try:
@@ -99,8 +104,7 @@ def api_middleware(app: FastAPI):
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
- except:
- import traceback
+ except Exception:
rich_available = False
@app.middleware("http")
@@ -131,11 +135,12 @@ def api_middleware(app: FastAPI):
"errors": str(e),
}
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
- print(f"API error: {request.method}: {request.url} {err}")
+ message = f"API error: {request.method}: {request.url} {err}"
if rich_available:
+ print(message)
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
- traceback.print_exc()
+ errors.report(message, exc_info=True)
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
@@ -157,7 +162,7 @@ def api_middleware(app: FastAPI):
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth:
- self.credentials = dict()
+ self.credentials = {}
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
self.credentials[user] = password
@@ -166,36 +171,44 @@ class Api:
self.app = app
self.queue_lock = queue_lock
api_middleware(self.app)
- self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
- self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
- self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
- self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
- self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
- self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
- self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
- self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
- self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+ self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+ self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
- self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
- self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
- self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
- self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
+ self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
- self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
+
+ if shared.cmd_opts.api_server_stop:
+ self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
+ self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
@@ -219,17 +232,25 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx]
return script, script_idx
-
+
def get_scripts_list(self):
- t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
- i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
+ t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
+ i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
+
+ return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
+
+ def get_script_info(self):
+ res = []
- return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+ for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
+ res += [script.api_info for script in script_list if script.api_info is not None]
+
+ return res
def get_script(self, script_name, script_runner):
if script_name is None or script_name == "":
return None, None
-
+
script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx]
@@ -261,14 +282,14 @@ class Api:
script_args[0] = selectable_idx + 1
# Now check for always on scripts
- if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
+ if request.alwayson_scripts:
for alwayson_script_name in request.alwayson_scripts.keys():
alwayson_script = self.get_script(alwayson_script_name, script_runner)
- if alwayson_script == None:
+ if alwayson_script is None:
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
# Selectable script in always on script param check
- if alwayson_script.alwayson == False:
- raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+ if alwayson_script.alwayson is False:
+ raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]:
# min between arg length in scriptrunner and arg length in the request
@@ -276,7 +297,7 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
- def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
@@ -304,25 +325,25 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_txt2img_grids
- p.outpath_samples = opts.outdir_txt2img_samples
-
- shared.state.begin()
- if selectable_scripts != None:
- p.script_args = script_args
- processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_txt2img_grids
+ p.outpath_samples = opts.outdir_txt2img_samples
+
+ shared.state.begin(job="scripts_txt2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
- return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+ return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
- def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
@@ -360,20 +381,20 @@ class Api:
args.pop('save_images', None)
with self.queue_lock:
- p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
- p.init_images = [decode_base64_to_image(x) for x in init_images]
- p.scripts = script_runner
- p.outpath_grids = opts.outdir_img2img_grids
- p.outpath_samples = opts.outdir_img2img_samples
-
- shared.state.begin()
- if selectable_scripts != None:
- p.script_args = script_args
- processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
- else:
- p.script_args = tuple(script_args) # Need to pass args as tuple here
- processed = process_images(p)
- shared.state.end()
+ with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
+ p.init_images = [decode_base64_to_image(x) for x in init_images]
+ p.scripts = script_runner
+ p.outpath_grids = opts.outdir_img2img_grids
+ p.outpath_samples = opts.outdir_img2img_samples
+
+ shared.state.begin(job="scripts_img2img")
+ if selectable_scripts is not None:
+ p.script_args = script_args
+ processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
+ else:
+ p.script_args = tuple(script_args) # Need to pass args as tuple here
+ processed = process_images(p)
+ shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -381,9 +402,9 @@ class Api:
img2imgreq.init_images = None
img2imgreq.mask = None
- return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+ return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
- def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
reqDict['image'] = decode_base64_to_image(reqDict['image'])
@@ -391,9 +412,9 @@ class Api:
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
- return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+ return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
- def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
image_list = reqDict.pop('imageList', [])
@@ -402,15 +423,15 @@ class Api:
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+ return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
- def pnginfoapi(self, req: PNGInfoRequest):
+ def pnginfoapi(self, req: models.PNGInfoRequest):
if(not req.image.strip()):
- return PNGInfoResponse(info="")
+ return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip())
if image is None:
- return PNGInfoResponse(info="")
+ return models.PNGInfoResponse(info="")
geninfo, items = images.read_info_from_image(image)
if geninfo is None:
@@ -418,13 +439,13 @@ class Api:
items = {**{'parameters': geninfo}, **items}
- return PNGInfoResponse(info=geninfo, items=items)
+ return models.PNGInfoResponse(info=geninfo, items=items)
- def progressapi(self, req: ProgressRequest = Depends()):
+ def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
# avoid dividing zero
progress = 0.01
@@ -446,9 +467,9 @@ class Api:
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
- def interrogateapi(self, interrogatereq: InterrogateRequest):
+ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
@@ -465,7 +486,7 @@ class Api:
else:
raise HTTPException(status_code=404, detail="Model not found")
- return InterrogateResponse(caption=processed)
+ return models.InterrogateResponse(caption=processed)
def interruptapi(self):
shared.state.interrupt()
@@ -497,6 +518,10 @@ class Api:
return options
def set_config(self, req: Dict[str, Any]):
+ checkpoint_name = req.get("sd_model_checkpoint", None)
+ if checkpoint_name is not None and checkpoint_name not in checkpoint_alisases:
+ raise RuntimeError(f"model {checkpoint_name!r} not found")
+
for k, v in req.items():
shared.opts.set(k, v)
@@ -521,9 +546,20 @@ class Api:
for upscaler in shared.sd_upscalers
]
+ def get_latent_upscale_modes(self):
+ return [
+ {
+ "name": upscale_mode,
+ }
+ for upscale_mode in [*(shared.latent_upscale_modes or {})]
+ ]
+
def get_sd_models(self):
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
+ def get_sd_vaes(self):
+ return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
+
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
@@ -566,44 +602,42 @@ class Api:
def create_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
- shared.state.end()
- return CreateResponse(info=f"create embedding filename: {filename}")
+ return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
+ return models.TrainResponse(info=f"create embedding error: {e}")
+ finally:
shared.state.end()
- return TrainResponse(info=f"create embedding error: {e}")
+
def create_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
- shared.state.end()
- return CreateResponse(info=f"create hypernetwork filename: {filename}")
+ return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
+ return models.TrainResponse(info=f"create hypernetwork error: {e}")
+ finally:
shared.state.end()
- return TrainResponse(info=f"create hypernetwork error: {e}")
def preprocess(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
- return PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
+ return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
+ except Exception as e:
+ return models.PreprocessResponse(info=f"preprocess error: {e}")
+ finally:
shared.state.end()
- return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
- except AssertionError as e:
- shared.state.end()
- return PreprocessResponse(info=f"preprocess error: {e}")
- except FileNotFoundError as e:
- shared.state.end()
- return PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -616,15 +650,15 @@ class Api:
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
- shared.state.end()
- return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+ except Exception as msg:
+ return models.TrainResponse(info=f"train embedding error: {msg}")
+ finally:
shared.state.end()
- return TrainResponse(info=f"train embedding error: {msg}")
def train_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
@@ -641,14 +675,16 @@ class Api:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
- return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
+ except Exception as exc:
+ return models.TrainResponse(info=f"train embedding error: {exc}")
+ finally:
shared.state.end()
- return TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
- import os, psutil
+ import os
+ import psutil
process = psutil.Process(os.getpid())
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
@@ -675,11 +711,23 @@ class Api:
'events': warnings,
}
else:
- cuda = { 'error': 'unavailable' }
+ cuda = {'error': 'unavailable'}
except Exception as err:
- cuda = { 'error': f'{err}' }
- return MemoryResponse(ram = ram, cuda = cuda)
+ cuda = {'error': f'{err}'}
+ return models.MemoryResponse(ram=ram, cuda=cuda)
def launch(self, server_name, port):
self.app.include_router(self.router)
- uvicorn.run(self.app, host=server_name, port=port)
+ uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0)
+
+ def kill_webui(self):
+ restart.stop_program()
+
+ def restart_webui(self):
+ if restart.is_restartable():
+ restart.restart_program()
+ return Response(status_code=501)
+
+ def stop_webui(request):
+ shared.state.server_command = "stop"
+ return Response("Stopping.")
diff --git a/modules/api/models.py b/modules/api/models.py
index 4a70f440..b5683071 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -223,8 +223,9 @@ for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
_type = str
- if _options[key].default is not None: _type = type(_options[key].default)
- flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+ if _options[key].default is not None:
+ _type = type(_options[key].default)
+ flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
@@ -240,6 +241,9 @@ class UpscalerItem(BaseModel):
model_url: Optional[str] = Field(title="URL")
scale: Optional[float] = Field(title="Scale")
+class LatentUpscalerModeItem(BaseModel):
+ name: str = Field(title="Name")
+
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
@@ -248,6 +252,10 @@ class SDModelItem(BaseModel):
filename: str = Field(title="Filename")
config: Optional[str] = Field(title="Config file")
+class SDVaeItem(BaseModel):
+ model_name: str = Field(title="Model Name")
+ filename: str = Field(title="Filename")
+
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
@@ -266,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
-class ArtistItem(BaseModel):
- name: str = Field(title="Name")
- score: float = Field(title="Score")
- category: str = Field(title="Category")
class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
@@ -286,6 +290,23 @@ class MemoryResponse(BaseModel):
ram: dict = Field(title="RAM", description="System memory stats")
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
+
class ScriptsList(BaseModel):
- txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
- img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") \ No newline at end of file
+ txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
+ img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
+
+
+class ScriptArg(BaseModel):
+ label: str = Field(default=None, title="Label", description="Name of the argument in UI")
+ value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
+ minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
+ maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
+ step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
+ choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+
+
+class ScriptInfo(BaseModel):
+ name: str = Field(default=None, title="Name", description="Script name")
+ is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
+ is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
+ args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 447bb764..3b94f8a4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -1,10 +1,9 @@
+from functools import wraps
import html
-import sys
import threading
-import traceback
import time
-from modules import shared, progress
+from modules import shared, progress, errors
queue_lock = threading.Lock()
@@ -20,17 +19,18 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
+ @wraps(func)
def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id
- if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+ if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
id_task = args[0]
progress.add_task_to_queue(id_task)
else:
id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
@@ -47,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
+ @wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
@@ -56,16 +57,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
try:
res = list(func(*args, **kwargs))
except Exception as e:
- # When printing out our debug argument list, do not print out more than a MB of text
- max_debug_str_len = 131072 # (1024*1024)/8
-
- print("Error completing request", file=sys.stderr)
- argStr = f"Arguments: {args} {kwargs}"
- print(argStr[:max_debug_str_len], file=sys.stderr)
- if len(argStr) > max_debug_str_len:
- print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
-
- print(traceback.format_exc(), file=sys.stderr)
+ # When printing out our debug argument list,
+ # do not print out more than a 100 KB of text
+ max_debug_str_len = 131072
+ message = "Error completing request"
+ arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len]
+ if len(arg_str) > max_debug_str_len:
+ arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)"
+ errors.report(f"{message}\n{arg_str}", exc_info=True)
shared.state.job = ""
shared.state.job_count = 0
@@ -108,4 +107,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
return tuple(res)
return f
-
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index d906a571..278a605e 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -1,6 +1,7 @@
import argparse
+import json
import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
parser = argparse.ArgumentParser()
@@ -10,9 +11,9 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
-parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
-parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
-parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
+parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
+parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
@@ -39,7 +40,8 @@ parser.add_argument("--precision", type=str, help="evaluate at this precision",
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
-parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
+parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
+parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
@@ -51,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
-parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
-parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
+parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
-parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
-parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
-parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
-parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
-parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
+parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
+parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
+parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
+parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
+parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
@@ -75,6 +77,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
+parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
@@ -102,4 +105,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
-parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') \ No newline at end of file
+parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
+parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py
index 11dcc3ee..12db6814 100644
--- a/modules/codeformer/codeformer_arch.py
+++ b/modules/codeformer/codeformer_arch.py
@@ -1,14 +1,12 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math
-import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional
-from modules.codeformer.vqgan_arch import *
-from basicsr.utils import get_root_logger
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5):
@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
-
+
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
- def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
- connect_list=['32', '64', '128', '256'],
- fix_modules=['quantize','generator']):
+ connect_list=('32', '64', '128', '256'),
+ fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None:
@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
- self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
-
+
self.channels = {
'16': 512,
'32': 256,
@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
- x = block(x)
+ x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
- x = block(x)
+ x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
- return out, logits, lq_feat \ No newline at end of file
+ return out, logits, lq_feat
diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py
index e7293683..09ee6660 100644
--- a/modules/codeformer/vqgan_arch.py
+++ b/modules/codeformer/vqgan_arch.py
@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
+
@torch.jit.script
def swish(x):
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
- q = q.permute(0, 2, 1)
+ q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
- w_ = torch.bmm(q, k)
+ w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
- w_ = w_.permute(0, 2, 1)
+ w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
@@ -272,18 +270,18 @@ class Encoder(nn.Module):
def forward(self, x):
for block in self.blocks:
x = block(x)
-
+
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
- self.nf = nf
- self.ch_mult = ch_mult
+ self.nf = nf
+ self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
- self.resolution = img_size
+ self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
@@ -317,29 +315,29 @@ class Generator(nn.Module):
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
-
+
def forward(self, x):
for block in self.blocks:
x = block(x)
-
+
return x
-
+
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
- def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
- self.in_channels = 3
- self.nf = nf
- self.n_blocks = res_blocks
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
- self.attn_resolutions = attn_resolutions
+ self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
@@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
self.kl_weight
)
self.generator = Generator(
- self.nf,
+ self.nf,
self.embed_dim,
- self.ch_mult,
- self.n_blocks,
- self.resolution,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
self.attn_resolutions
)
@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
raise ValueError('Wrong params!')
def forward(self, x):
- return self.main(x) \ No newline at end of file
+ return self.main(x)
diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py
index 8d84bbc9..f293acf5 100644
--- a/modules/codeformer_model.py
+++ b/modules/codeformer_model.py
@@ -1,13 +1,11 @@
import os
-import sys
-import traceback
import cv2
import torch
import modules.face_restoration
import modules.shared
-from modules import shared, devices, modelloader
+from modules import shared, devices, modelloader, errors
from modules.paths import models_path
# codeformer people made a choice to include modified basicsr library to their project which makes
@@ -17,14 +15,11 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-have_codeformer = False
codeformer = None
def setup_model(dirname):
- global model_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
+ os.makedirs(model_path, exist_ok=True)
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
@@ -33,11 +28,9 @@ def setup_model(dirname):
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
- from basicsr.utils.download_util import load_file_from_url
- from basicsr.utils import imwrite, img2tensor, tensor2img
+ from basicsr.utils import img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
- from modules.shared import cmd_opts
net_class = CodeFormer
@@ -96,7 +89,7 @@ def setup_model(dirname):
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
- for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
@@ -107,8 +100,8 @@ def setup_model(dirname):
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
- except Exception as error:
- print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
+ except Exception:
+ errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
@@ -129,15 +122,11 @@ def setup_model(dirname):
return restored_img
- global have_codeformer
- have_codeformer = True
-
global codeformer
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
- print("Error setting up CodeFormer:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error setting up CodeFormer", exc_info=True)
# sys.path = stored_sys_path
diff --git a/modules/config_states.py b/modules/config_states.py
index 2ea00929..6f1ab53f 100644
--- a/modules/config_states.py
+++ b/modules/config_states.py
@@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c
"""
import os
-import sys
-import traceback
import json
import time
import tqdm
@@ -13,8 +11,8 @@ from datetime import datetime
from collections import OrderedDict
import git
-from modules import shared, extensions
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+from modules import shared, extensions, errors
+from modules.paths_internal import script_path, config_states_dir
all_config_states = OrderedDict()
@@ -35,7 +33,7 @@ def list_config_states():
j["filepath"] = path
config_states.append(j)
- config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+ config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"]))
@@ -53,8 +51,7 @@ def get_webui_config():
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
webui_remote = None
webui_commit_hash = None
@@ -83,6 +80,8 @@ def get_extension_config():
ext_config = {}
for ext in extensions.extensions:
+ ext.read_info_from_repo()
+
entry = {
"name": ext.name,
"path": ext.path,
@@ -132,8 +131,7 @@ def restore_webui_config(config):
if os.path.exists(os.path.join(script_path, ".git")):
webui_repo = git.Repo(script_path)
except Exception:
- print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
return
try:
@@ -141,8 +139,7 @@ def restore_webui_config(config):
webui_repo.git.reset(webui_commit_hash, hard=True)
print(f"* Restored webui to commit {webui_commit_hash}.")
except Exception:
- print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error restoring webui to commit{webui_commit_hash}")
def restore_extension_config(config):
diff --git a/modules/deepbooru.py b/modules/deepbooru.py
index 122fce7f..547e1b4c 100644
--- a/modules/deepbooru.py
+++ b/modules/deepbooru.py
@@ -2,7 +2,6 @@ import os
import re
import torch
-from PIL import Image
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:
res = []
- filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+ filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
diff --git a/modules/devices.py b/modules/devices.py
index c705a3cb..620ed1a6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,5 +1,7 @@
import sys
import contextlib
+from functools import lru_cache
+
import torch
from modules import errors
@@ -13,13 +15,6 @@ def has_mps() -> bool:
else:
return mac_specific.has_mps
-def extract_device_id(args, name):
- for x in range(len(args)):
- if name in args[x]:
- return args[x + 1]
-
- return None
-
def get_cuda_device_string():
from modules import shared
@@ -65,7 +60,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
- if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+ if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -154,3 +149,19 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+ """
+ just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
+ spends about 2.7 seconds doing that, at least wih NVidia.
+ """
+
+ x = torch.zeros((1, 1)).to(device, dtype)
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
+ linear(x)
+
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+ conv2d(x)
diff --git a/modules/errors.py b/modules/errors.py
index f6b80dbb..5271a9fe 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -1,8 +1,42 @@
import sys
+import textwrap
import traceback
+exception_records = []
+
+
+def record_exception():
+ _, e, tb = sys.exc_info()
+ if e is None:
+ return
+
+ if exception_records and exception_records[-1] == e:
+ return
+
+ exception_records.append((e, tb))
+
+ if len(exception_records) > 5:
+ exception_records.pop(0)
+
+
+def report(message: str, *, exc_info: bool = False) -> None:
+ """
+ Print an error message to stderr, with optional traceback.
+ """
+
+ record_exception()
+
+ for line in message.splitlines():
+ print("***", line, file=sys.stderr)
+ if exc_info:
+ print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr)
+ print("---", file=sys.stderr)
+
+
def print_error_explanation(message):
+ record_exception()
+
lines = message.strip().split("\n")
max_len = max([len(x) for x in lines])
@@ -12,9 +46,15 @@ def print_error_explanation(message):
print('=' * max_len, file=sys.stderr)
-def display(e: Exception, task):
+def display(e: Exception, task, *, full_traceback=False):
+ record_exception()
+
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ te = traceback.TracebackException.from_exception(e)
+ if full_traceback:
+ # include frames leading up to the try-catch block
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
+ print(*te.format(), sep="", file=sys.stderr)
message = str(e)
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
@@ -28,6 +68,8 @@ already_displayed = {}
def display_once(e: Exception, task):
+ record_exception()
+
if task in already_displayed:
return
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index f4369257..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,24 +1,20 @@
-import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
+from modules import modelloader, images, devices
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -52,9 +48,7 @@ def resrgan2normal(state_dict, nb=23):
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {}
- items = []
- for k, v in state_dict.items():
- items.append(k)
+ items = list(state_dict)
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
@@ -138,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -147,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
- model_dir=self.model_path,
+ model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py
index 6071fea7..2b9888ba 100644
--- a/modules/esrgan_model_arch.py
+++ b/modules/esrgan_model_arch.py
@@ -2,7 +2,6 @@
from collections import OrderedDict
import math
-import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
"""
@@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
x = x + sampled_noise
- return x
+ return x
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -438,9 +437,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
padding = padding if pad_type == 'zero' else 0
if convtype=='PartialConv2D':
+ from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='DeformConv2D':
+ from torchvision.ops import DeformConv2d # not tested
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='Conv3D':
diff --git a/modules/extensions.py b/modules/extensions.py
index 34d9d654..abc6e2b1 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,18 +1,13 @@
import os
-import sys
-import traceback
+import threading
-import time
-from datetime import datetime
-import git
-
-from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
+from modules import shared, errors
+from modules.gitpython_hack import Repo
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = []
-if not os.path.exists(extensions_dir):
- os.makedirs(extensions_dir)
+os.makedirs(extensions_dir, exist_ok=True)
def active():
@@ -25,6 +20,8 @@ def active():
class Extension:
+ lock = threading.Lock()
+
def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name
self.path = path
@@ -43,15 +40,19 @@ class Extension:
if self.is_builtin or self.have_info_from_repo:
return
- self.have_info_from_repo = True
+ with self.lock:
+ if self.have_info_from_repo:
+ return
+ self.do_read_info_from_repo()
+
+ def do_read_info_from_repo(self):
repo = None
try:
if os.path.exists(os.path.join(self.path, ".git")):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
except Exception:
- print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error reading github repository info from {self.path}", exc_info=True)
if repo is None or repo.bare:
self.remote = None
@@ -59,18 +60,19 @@ class Extension:
try:
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
- head = repo.head.commit
- self.commit_date = repo.head.commit.committed_date
- ts = time.asctime(time.gmtime(self.commit_date))
+ commit = repo.head.commit
+ self.commit_date = commit.committed_date
if repo.active_branch:
self.branch = repo.active_branch.name
- self.commit_hash = head.hexsha
- self.version = f'{self.commit_hash[:8]} ({ts})'
+ self.commit_hash = commit.hexsha
+ self.version = self.commit_hash[:8]
- except Exception as ex:
- print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
+ except Exception:
+ errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True)
self.remote = None
+ self.have_info_from_repo = True
+
def list_files(self, subdir, extension):
from modules import scripts
@@ -87,7 +89,7 @@ class Extension:
return res
def check_updates(self):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
for fetch in repo.remote().fetch(dry_run=True):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
@@ -109,7 +111,7 @@ class Extension:
self.status = "latest"
def fetch_and_reset_hard(self, commit='origin'):
- repo = git.Repo(self.path)
+ repo = Repo(self.path)
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch(all=True)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 1978673d..41799b0a 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -14,9 +14,26 @@ def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network
+def register_default_extra_networks():
+ from modules.extra_networks_hypernet import ExtraNetworkHypernet
+ register_extra_network(ExtraNetworkHypernet())
+
+
class ExtraNetworkParams:
def __init__(self, items=None):
self.items = items or []
+ self.positional = []
+ self.named = {}
+
+ for item in self.items:
+ parts = item.split('=', 2) if isinstance(item, str) else [item]
+ if len(parts) == 2:
+ self.named[parts[0]] = parts[1]
+ else:
+ self.positional.append(item)
+
+ def __eq__(self, other):
+ return self.items == other.items
class ExtraNetwork:
@@ -86,12 +103,15 @@ def activate(p, extra_network_data):
except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}")
+ if p.scripts is not None:
+ p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
+
def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks"""
- for extra_network_name, extra_network_args in extra_network_data.items():
+ for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None:
continue
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
index 04f27c9f..b6a6dc0e 100644
--- a/modules/extra_networks_hypernet.py
+++ b/modules/extra_networks_hypernet.py
@@ -1,4 +1,4 @@
-from modules import extra_networks, shared, extra_networks
+from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork
@@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
- if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
+ if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
names = []
multipliers = []
for params in params_list:
- assert len(params.items) > 0
+ assert params.items
names.append(params.items[0])
multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
diff --git a/modules/extras.py b/modules/extras.py
index ff4e9c4e..e9c0263e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
@@ -136,14 +135,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_instruct_pix2pix_model = False
if theta_func2:
- shared.state.textinfo = f"Loading B"
+ shared.state.textinfo = "Loading B"
print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
else:
theta_1 = None
if theta_func1:
- shared.state.textinfo = f"Loading C"
+ shared.state.textinfo = "Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
@@ -199,7 +198,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)
-
+
theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1
@@ -242,9 +241,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
+ metadata = None
if save_metadata:
+ metadata = {"format": "pt"}
+
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -262,15 +263,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
}
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
+ sd_merge_models = {}
+
def add_model_metadata(checkpoint_info):
checkpoint_info.calculate_shorthash()
- metadata["sd_merge_models"][checkpoint_info.sha256] = {
+ sd_merge_models[checkpoint_info.sha256] = {
"name": checkpoint_info.name,
"legacy_hash": checkpoint_info.hash,
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
}
- metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
+ sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
add_model_metadata(primary_model_info)
if secondary_model_info:
@@ -278,7 +281,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info:
add_model_metadata(tertiary_model_info)
- metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
+ metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index fe8b18b2..a3448be9 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -1,15 +1,12 @@
import base64
-import html
import io
-import math
+import json
import os
import re
-from pathlib import Path
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
-import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -23,14 +20,14 @@ registered_param_bindings = []
class ParamBinding:
- def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+ def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
self.paste_button = paste_button
self.tabname = tabname
self.source_text_component = source_text_component
self.source_image_component = source_image_component
self.source_tabname = source_tabname
self.override_settings_component = override_settings_component
- self.paste_field_names = paste_field_names
+ self.paste_field_names = paste_field_names or []
def reset():
@@ -38,20 +35,27 @@ def reset():
def quote(text):
- if ',' not in str(text):
+ if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
return text
- text = str(text)
- text = text.replace('\\', '\\\\')
- text = text.replace('"', '\\"')
- return f'"{text}"'
+ return json.dumps(text, ensure_ascii=False)
+
+
+def unquote(text):
+ if len(text) == 0 or text[0] != '"' or text[-1] != '"':
+ return text
+
+ try:
+ return json.loads(text)
+ except Exception:
+ return text
def image_from_url_text(filedata):
if filedata is None:
return None
- if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
+ if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
filedata = filedata[0]
if type(filedata) == dict and filedata.get("is_file", False):
@@ -170,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h
-
-def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
- """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
-
- Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
- parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
-
- If the infotext has no hash, then a hypernet with the same name will be selected instead.
- """
- hypernet_name = hypernet_name.lower()
- if hypernet_hash is not None:
- # Try to match the hash in the name
- for hypernet_key in shared.hypernetworks.keys():
- result = re_hypernet_hash.search(hypernet_key)
- if result is not None and result[1] == hypernet_hash:
- return hypernet_key
- else:
- # Fall back to a hypernet with the same name
- for hypernet_key in shared.hypernetworks.keys():
- if hypernet_key.lower().startswith(hypernet_name):
- return hypernet_key
-
- return None
-
-
def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale"""
@@ -251,28 +230,40 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
lines.append(lastline)
lastline = ''
- for i, line in enumerate(lines):
+ for line in lines:
line = line.strip()
if line.startswith("Negative prompt:"):
done_with_prompt = True
line = line[16:].strip()
-
if done_with_prompt:
negative_prompt += ("" if negative_prompt == "" else "\n") + line
else:
prompt += ("" if prompt == "" else "\n") + line
+ if shared.opts.infotext_styles != "Ignore":
+ found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
+
+ if shared.opts.infotext_styles == "Apply":
+ res["Styles array"] = found_styles
+ elif shared.opts.infotext_styles == "Apply if any" and found_styles:
+ res["Styles array"] = found_styles
+
res["Prompt"] = prompt
res["Negative prompt"] = negative_prompt
for k, v in re_param.findall(lastline):
- v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
- m = re_imagesize.match(v)
- if m is not None:
- res[f"{k}-1"] = m.group(1)
- res[f"{k}-2"] = m.group(2)
- else:
- res[k] = v
+ try:
+ if v[0] == '"' and v[-1] == '"':
+ v = unquote(v)
+
+ m = re_imagesize.match(v)
+ if m is not None:
+ res[f"{k}-1"] = m.group(1)
+ res[f"{k}-2"] = m.group(2)
+ else:
+ res[k] = v
+ except Exception:
+ print(f"Error parsing \"{k}: {v}\"")
# Missing CLIP skip means it was set to 1 (the default)
if "Clip skip" not in res:
@@ -286,17 +277,34 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
res["Hires resize-1"] = 0
res["Hires resize-2"] = 0
+ if "Hires sampler" not in res:
+ res["Hires sampler"] = "Use same sampler"
+
+ if "Hires prompt" not in res:
+ res["Hires prompt"] = ""
+
+ if "Hires negative prompt" not in res:
+ res["Hires negative prompt"] = ""
+
restore_old_hires_fix_params(res)
# Missing RNG means the default was set, which is GPU RNG
if "RNG" not in res:
res["RNG"] = "GPU"
- return res
+ if "Schedule type" not in res:
+ res["Schedule type"] = "Automatic"
+
+ if "Schedule max sigma" not in res:
+ res["Schedule max sigma"] = 0
+ if "Schedule min sigma" not in res:
+ res["Schedule min sigma"] = 0
-settings_map = {}
+ if "Schedule rho" not in res:
+ res["Schedule rho"] = 0
+ return res
infotext_to_setting_name_mapping = [
@@ -304,6 +312,10 @@ infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
+ ('Schedule type', 'k_sched_type'),
+ ('Schedule max sigma', 'sigma_max'),
+ ('Schedule min sigma', 'sigma_min'),
+ ('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
@@ -312,8 +324,11 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
+ ('Token merging ratio', 'token_merging_ratio'),
+ ('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'),
('NGMS', 's_min_uncond'),
+ ('Pad conds', 'pad_cond_uncond'),
]
@@ -405,7 +420,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
- return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0)
+ return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
paste_fields = paste_fields + [(override_settings_component, paste_settings)]
@@ -422,5 +437,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
outputs=[],
show_progress=False,
)
-
-
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index fbe6215a..8e0f13bd 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,12 +1,10 @@
import os
-import sys
-import traceback
import facexlib
import gfpgan
import modules.face_restoration
-from modules import paths, shared, devices, modelloader
+from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN"
user_path = None
@@ -27,7 +25,7 @@ def gfpgann():
return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
@@ -72,13 +70,10 @@ gfpgan_constructor = None
def setup_model(dirname):
- global model_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
try:
+ os.makedirs(model_path, exist_ok=True)
from gfpgan import GFPGANer
- from facexlib import detection, parsing
+ from facexlib import detection, parsing # noqa: F401
global user_path
global have_gfpgan
global gfpgan_constructor
@@ -112,5 +107,4 @@ def setup_model(dirname):
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
- print("Error setting up GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error setting up GFPGAN", exc_info=True)
diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py
new file mode 100644
index 00000000..e537c1df
--- /dev/null
+++ b/modules/gitpython_hack.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+import io
+import subprocess
+
+import git
+
+
+class Git(git.Git):
+ """
+ Git subclassed to never use persistent processes.
+ """
+
+ def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):
+ raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})")
+
+ def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:
+ ret = subprocess.check_output(
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"],
+ input=self._prepare_ref(ref),
+ cwd=self._working_dir,
+ timeout=2,
+ )
+ return self._parse_object_header(ret)
+
+ def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+ # Not really streaming, per se; this buffers the entire object in memory.
+ # Shouldn't be a problem for our use case, since we're only using this for
+ # object headers (commit objects).
+ ret = subprocess.check_output(
+ [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"],
+ input=self._prepare_ref(ref),
+ cwd=self._working_dir,
+ timeout=30,
+ )
+ bio = io.BytesIO(ret)
+ hexsha, typename, size = self._parse_object_header(bio.readline())
+ return (hexsha, typename, size, self.CatFileContentStream(size, bio))
+
+
+class Repo(git.Repo):
+ GitCommandWrapperType = Git
diff --git a/modules/hashes.py b/modules/hashes.py
index 032120f4..8b7ea0ac 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -46,8 +46,8 @@ def calculate_sha256(filename):
return hash_sha256.hexdigest()
-def sha256_from_cache(filename, title):
- hashes = cache("hashes")
+def sha256_from_cache(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
ondisk_mtime = os.path.getmtime(filename)
if title not in hashes:
@@ -62,10 +62,10 @@ def sha256_from_cache(filename, title):
return cached_sha256
-def sha256(filename, title):
- hashes = cache("hashes")
+def sha256(filename, title, use_addnet_hash=False):
+ hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
- sha256_value = sha256_from_cache(filename, title)
+ sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
if sha256_value is not None:
return sha256_value
@@ -73,7 +73,11 @@ def sha256(filename, title):
return None
print(f"Calculating sha256 for {filename}: ", end='')
- sha256_value = calculate_sha256(filename)
+ if use_addnet_hash:
+ with open(filename, "rb") as file:
+ sha256_value = addnet_hash_safetensors(file)
+ else:
+ sha256_value = calculate_sha256(filename)
print(f"{sha256_value}")
hashes[title] = {
@@ -86,6 +90,19 @@ def sha256(filename, title):
return sha256_value
+def addnet_hash_safetensors(b):
+ """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py"""
+ hash_sha256 = hashlib.sha256()
+ blksize = 1024 * 1024
+ b.seek(0)
+ header = b.read(8)
+ n = int.from_bytes(header, "little")
+ offset = n + 8
+ b.seek(offset)
+ for chunk in iter(lambda: b.read(blksize), b""):
+ hash_sha256.update(chunk)
+
+ return hash_sha256.hexdigest()
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 1fc49537..51941c11 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -1,10 +1,7 @@
-import csv
import datetime
import glob
import html
import os
-import sys
-import traceback
import inspect
import modules.textual_inversion.dataset
@@ -12,13 +9,13 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
+from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
-from collections import defaultdict, deque
+from collections import deque
from statistics import stdev, mean
@@ -178,34 +175,34 @@ class Hypernetwork:
def weights(self):
res = []
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
res += layer.parameters()
return res
def train(self, mode=True):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.train(mode=mode)
for param in layer.parameters():
param.requires_grad = mode
def to(self, device):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.multiplier = multiplier
return self
def eval(self):
- for k, layers in self.layers.items():
+ for layers in self.layers.values():
for layer in layers:
layer.eval()
for param in layer.parameters():
@@ -326,17 +323,14 @@ def load_hypernetwork(name):
if path is None:
return None
- hypernetwork = Hypernetwork()
-
try:
+ hypernetwork = Hypernetwork()
hypernetwork.load(path)
+ return hypernetwork
except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading hypernetwork {path}", exc_info=True)
return None
- return hypernetwork
-
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
@@ -359,17 +353,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork)
-def find_closest_hypernetwork_name(search: str):
- if not search:
- return None
- search = search.lower()
- applicable = [name for name in shared.hypernetworks if search in name.lower()]
- if not applicable:
- return None
- applicable = sorted(applicable, key=lambda name: len(name))
- return applicable[0]
-
-
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@@ -404,7 +387,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
k = self.to_k(context_k)
v = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@@ -452,18 +435,6 @@ def statistics(data):
return total_information, recent_information
-def report_statistics(loss_info:dict):
- keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
- for key in keys:
- try:
- print("Loss statistics for file " + key)
- info, recent = statistics(list(loss_info[key]))
- print(info)
- print(recent)
- except Exception as e:
- print(e)
-
-
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@@ -541,7 +512,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
+
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@@ -594,7 +565,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
print(e)
scaler = torch.cuda.amp.GradScaler()
-
+
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
@@ -620,7 +591,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -637,7 +608,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if clip_grad:
clip_grad_sched.step(hypernetwork.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -658,14 +629,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step += loss.item()
scaler.scale(loss).backward()
-
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
loss_logging.append(_loss_step)
if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate)
-
+
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
@@ -675,7 +646,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0
steps_done = hypernetwork.step + 1
-
+
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
@@ -771,12 +742,11 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Exception in training hypernetwork", exc_info=True)
finally:
pbar.leave = False
pbar.close()
hypernetwork.eval()
- #report_statistics(loss_dict)
sd_hijack_checkpoint.remove()
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py
index 76599f5a..8b6255e2 100644
--- a/modules/hypernetworks/ui.py
+++ b/modules/hypernetworks/ui.py
@@ -1,19 +1,17 @@
import html
-import os
-import re
import gradio as gr
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+ return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py
index a41965ab..b5412548 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -1,6 +1,6 @@
+from __future__ import annotations
+
import datetime
-import sys
-import traceback
import pytz
import io
@@ -12,18 +12,27 @@ import re
import numpy as np
import piexif
import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from fonts.ttf import Roboto
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string
import json
import hashlib
from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.paths_internal import roboto_ttf_file
+from modules.shared import opts
+
+import modules.sd_vae as sd_vae
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+def get_font(fontsize: int):
+ try:
+ return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
+ except Exception:
+ return ImageFont.truetype(roboto_ttf_file, fontsize)
+
+
def image_grid(imgs, batch_size=1, rows=None):
if rows is None:
if opts.n_rows > 0:
@@ -132,6 +141,11 @@ class GridAnnotation:
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
+
+ color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
+ color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
+ color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
+
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
@@ -142,14 +156,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
lines.append(word)
return lines
- def get_font(fontsize):
- try:
- return ImageFont.truetype(opts.font or Roboto, fontsize)
- except Exception:
- return ImageFont.truetype(Roboto, fontsize)
-
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
- for i, line in enumerate(lines):
+ for line in lines:
fnt = initial_fnt
fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@@ -167,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
fnt = get_font(fontsize)
- color_active = (0, 0, 0)
- color_inactive = (153, 153, 153)
-
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
cols = im.width // width
@@ -178,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
- calc_img = Image.new("RGB", (1, 1), "white")
+ calc_img = Image.new("RGB", (1, 1), color_background)
calc_d = ImageDraw.Draw(calc_img)
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@@ -199,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
- result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white")
+ result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
for row in range(rows):
for col in range(cols):
@@ -335,8 +340,20 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
+ def get_vae_filename(self): #get the name of the VAE file.
+ if sd_vae.loaded_vae_file is None:
+ return "NoneType"
+ file_name = os.path.basename(sd_vae.loaded_vae_file)
+ split_file_name = file_name.split('.')
+ if len(split_file_name) > 1 and split_file_name[0] == '':
+ return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
+ else:
+ return split_file_name[0]
+
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
+ 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
+ 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1],
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.image.width,
@@ -353,20 +370,24 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
- 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
- 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
+ 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,
+ 'batch_size': lambda self: self.p.batch_size,
+ 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
+ 'user': lambda self: self.p.user,
+ 'vae_filename': lambda self: self.get_vae_filename(),
}
default_time_format = '%Y%m%d%H%M%S'
- def __init__(self, p, seed, prompt, image):
+ def __init__(self, p, seed, prompt, image, zip=False):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
-
+ self.zip = zip
+
def hasprompt(self, *args):
lower = self.prompt.lower()
if self.p is None or self.prompt is None:
@@ -389,7 +410,7 @@ class FilenameGenerator:
prompt_no_style = self.prompt
for style in shared.prompt_styles.get_style_prompts(self.p.styles):
- if len(style) > 0:
+ if style:
for part in style.split("{prompt}"):
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
@@ -398,7 +419,7 @@ class FilenameGenerator:
return sanitize_filename_part(prompt_no_style, replace_spaces=False)
def prompt_words(self):
- words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
+ words = [x for x in re_nonletters.split(self.prompt or "") if x]
if len(words) == 0:
words = ["empty"]
return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)
@@ -406,16 +427,16 @@ class FilenameGenerator:
def datetime(self, *args):
time_datetime = datetime.datetime.now()
- time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
+ time_format = args[0] if (args and args[0] != "") else self.default_time_format
try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
- except pytz.exceptions.UnknownTimeZoneError as _:
+ except pytz.exceptions.UnknownTimeZoneError:
time_zone = None
time_zone_time = time_datetime.astimezone(time_zone)
try:
formatted_time = time_zone_time.strftime(time_format)
- except (ValueError, TypeError) as _:
+ except (ValueError, TypeError):
formatted_time = time_zone_time.strftime(self.default_time_format)
return sanitize_filename_part(formatted_time, replace_spaces=False)
@@ -445,8 +466,7 @@ class FilenameGenerator:
replacement = fun(self, *pattern_args)
except Exception:
replacement = None
- print(f"Error adding [{pattern}] to filename", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error adding [{pattern}] to filename", exc_info=True)
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
continue
@@ -472,15 +492,61 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
- result = max(int(l[0]), result)
+ result = max(int(parts[0]), result)
except ValueError:
pass
return result + 1
+def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
+ """
+ Saves image to filename, including geninfo as text information for generation info.
+ For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
+ For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
+ """
+
+ if extension is None:
+ extension = os.path.splitext(filename)[1]
+
+ image_format = Image.registered_extensions()[extension]
+
+ if extension.lower() == '.png':
+ existing_pnginfo = existing_pnginfo or {}
+ if opts.enable_pnginfo:
+ existing_pnginfo[pnginfo_section_name] = geninfo
+
+ if opts.enable_pnginfo:
+ pnginfo_data = PngImagePlugin.PngInfo()
+ for k, v in (existing_pnginfo or {}).items():
+ pnginfo_data.add_text(k, str(v))
+ else:
+ pnginfo_data = None
+
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
+
+ elif extension.lower() in (".jpg", ".jpeg", ".webp"):
+ if image.mode == 'RGBA':
+ image = image.convert("RGB")
+ elif image.mode == 'I;16':
+ image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
+
+ image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
+
+ if opts.enable_pnginfo and geninfo is not None:
+ exif_bytes = piexif.dump({
+ "Exif": {
+ piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
+ },
+ })
+
+ piexif.insert(exif_bytes, filename)
+ else:
+ image.save(filename, format=image_format, quality=opts.jpeg_quality)
+
+
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
"""Save an image.
@@ -565,38 +631,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
info = params.pnginfo.get(pnginfo_section_name, None)
def _atomically_save_image(image_to_save, filename_without_extension, extension):
- # save image with .tmp extension to avoid race condition when another process detects new image in the directory
+ """
+ save image with .tmp extension to avoid race condition when another process detects new image in the directory
+ """
temp_file_path = f"{filename_without_extension}.tmp"
- image_format = Image.registered_extensions()[extension]
- if extension.lower() == '.png':
- pnginfo_data = PngImagePlugin.PngInfo()
- if opts.enable_pnginfo:
- for k, v in params.pnginfo.items():
- pnginfo_data.add_text(k, str(v))
+ save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
-
- elif extension.lower() in (".jpg", ".jpeg", ".webp"):
- if image_to_save.mode == 'RGBA':
- image_to_save = image_to_save.convert("RGB")
- elif image_to_save.mode == 'I;16':
- image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
-
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
-
- if opts.enable_pnginfo and info is not None:
- exif_bytes = piexif.dump({
- "Exif": {
- piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
- },
- })
-
- piexif.insert(exif_bytes, temp_file_path)
- else:
- image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
-
- # atomically rename the file with correct extension
os.replace(temp_file_path, filename_without_extension + extension)
fullfn_without_extension, extension = os.path.splitext(params.filename)
@@ -612,12 +653,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height
-
+ resize_to = None
if oversize and ratio > 1:
- image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS)
+ resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize:
- image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS)
+ resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
+ if resize_to is not None:
+ try:
+ # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
+ image = image.resize(resize_to, LANCZOS)
+ except Exception:
+ image = image.resize(resize_to)
try:
_atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e:
@@ -635,8 +682,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn
-def read_info_from_image(image):
- items = image.info or {}
+IGNORED_INFO_KEYS = {
+ 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
+ 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
+ 'icc_profile', 'chromaticity', 'photoshop',
+}
+
+
+def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
+ items = (image.info or {}).copy()
geninfo = items.pop('parameters', None)
@@ -652,9 +706,8 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment
geninfo = exif_comment
- for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
- 'loop', 'background', 'timestamp', 'duration']:
- items.pop(field, None)
+ for field in IGNORED_INFO_KEYS:
+ items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
@@ -665,8 +718,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error parsing NovelAI image generation parameters", exc_info=True)
return geninfo, items
diff --git a/modules/img2img.py b/modules/img2img.py
index d1872bed..3b83814b 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,23 +1,21 @@
-import math
import os
-import sys
-import traceback
+from pathlib import Path
import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+import gradio as gr
-from modules import devices, sd_samplers
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import sd_samplers, images as imgutil
+from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
-import modules.images as images
import modules.scripts
-def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
+def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p)
images = []
@@ -31,9 +29,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
is_inpaint_batch = False
if inpaint_mask_dir:
inpaint_masks = shared.listfiles(inpaint_mask_dir)
- is_inpaint_batch = len(inpaint_masks) > 0
- if is_inpaint_batch:
- print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
+ is_inpaint_batch = bool(inpaint_masks)
+
+ if is_inpaint_batch:
+ print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
@@ -44,6 +43,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
state.job_count = len(images) * p.n_iter
+ # extract "default" params to use in case getting png info fails
+ prompt = p.prompt
+ negative_prompt = p.negative_prompt
+ seed = p.seed
+ cfg_scale = p.cfg_scale
+ sampler_name = p.sampler_name
+ steps = p.steps
+
for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}"
if state.skipped:
@@ -59,23 +66,59 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
+
+ if to_scale:
+ p.width = int(img.width * scale_by)
+ p.height = int(img.height * scale_by)
+
p.init_images = [img] * p.batch_size
+ image_path = Path(image)
if is_inpaint_batch:
# try to find corresponding mask for an image using simple filename matching
- mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
- # if not found use first one ("same mask for all images" use-case)
- if not mask_image_path in inpaint_masks:
+ if len(inpaint_masks) == 1:
mask_image_path = inpaint_masks[0]
+ else:
+ # try to find corresponding mask for an image using simple filename matching
+ mask_image_dir = Path(inpaint_mask_dir)
+ masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*"))
+
+ if len(masks_found) == 0:
+ print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.")
+ continue
+
+ # it should contain only 1 matching mask
+ # otherwise user has many masks with the same name but different extensions
+ mask_image_path = masks_found[0]
+
mask_image = Image.open(mask_image_path)
p.image_mask = mask_image
+ if use_png_info:
+ try:
+ info_img = img
+ if png_info_dir:
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+ info_img = Image.open(info_img_path)
+ geninfo, _ = imgutil.read_info_from_image(info_img)
+ parsed_parameters = parse_generation_parameters(geninfo)
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+ except Exception:
+ parsed_parameters = {}
+
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+ p.seed = int(parsed_parameters.get("Seed", seed))
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+ p.steps = int(parsed_parameters.get("Steps", steps))
+
proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None:
proc = process_images(p)
for n, processed_image in enumerate(proc.images):
- filename = os.path.basename(image)
+ filename = image_path.name
relpath = os.path.dirname(os.path.relpath(image, input_dir))
if n > 0:
@@ -89,7 +132,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
processed_image.save(os.path.join(output_dir, relpath, filename))
-def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -103,7 +146,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
+ mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, mask).convert('L')
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
@@ -125,7 +169,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if image is not None:
image = ImageOps.exif_transpose(image)
- if selected_scale_tab == 1:
+ if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected"
width = int(image.width * scale_by)
@@ -171,6 +215,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img
p.script_args = args
+ p.user = request.username
+
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
@@ -180,7 +226,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
- process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args)
+ process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
processed = Processed(p, [], p.seed, "")
else:
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9f7d657f..a3ae1dd5 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -1,6 +1,5 @@
import os
import sys
-import traceback
from collections import namedtuple
from pathlib import Path
import re
@@ -11,7 +10,6 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384
@@ -160,7 +158,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array))
- text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+ text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True)
@@ -186,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
@@ -208,8 +205,8 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True)
- for name, topn, items in self.categories():
- matches = self.rank(image_features, items, top_count=topn)
+ for cat in self.categories():
+ matches = self.rank(image_features, cat.items, top_count=cat.topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})"
@@ -217,8 +214,7 @@ class InterrogateModels:
res += f", {match}"
except Exception:
- print("Error interrogating", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error interrogating", exc_info=True)
res += "<error>"
self.unload()
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
new file mode 100644
index 00000000..0e0dbca4
--- /dev/null
+++ b/modules/launch_utils.py
@@ -0,0 +1,344 @@
+# this scripts installs necessary requirements and launches main program in webui.py
+import subprocess
+import os
+import sys
+import importlib.util
+import platform
+import json
+from functools import lru_cache
+
+from modules import cmd_args, errors
+from modules.paths_internal import script_path, extensions_dir
+
+args, _ = cmd_args.parser.parse_known_args()
+
+python = sys.executable
+git = os.environ.get('GIT', "git")
+index_url = os.environ.get('INDEX_URL', "")
+dir_repos = "repositories"
+
+# Whether to default to printing command output
+default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
+
+if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
+
+def check_python_version():
+ is_windows = platform.system() == "Windows"
+ major = sys.version_info.major
+ minor = sys.version_info.minor
+ micro = sys.version_info.micro
+
+ if is_windows:
+ supported_minors = [10]
+ else:
+ supported_minors = [7, 8, 9, 10, 11]
+
+ if not (major == 3 and minor in supported_minors):
+ import modules.errors
+
+ modules.errors.print_error_explanation(f"""
+INCOMPATIBLE PYTHON VERSION
+
+This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
+If you encounter an error with "RuntimeError: Couldn't install torch." message,
+or any other error regarding unsuccessful package (library) installation,
+please downgrade (or upgrade) to the latest version of 3.10 Python
+and delete current Python and "venv" folder in WebUI's directory.
+
+You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
+
+{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
+
+Use --skip-python-version-check to suppress this warning.
+""")
+
+
+@lru_cache()
+def commit_hash():
+ try:
+ return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
+ except Exception:
+ return "<none>"
+
+
+@lru_cache()
+def git_tag():
+ try:
+ return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
+ except Exception:
+ try:
+ from pathlib import Path
+ changelog_md = Path(__file__).parent.parent / "CHANGELOG.md"
+ with changelog_md.open(encoding="utf-8") as file:
+ return next((line.strip() for line in file if line.strip()), "<none>")
+ except Exception:
+ return "<none>"
+
+
+def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
+ if desc is not None:
+ print(desc)
+
+ run_kwargs = {
+ "args": command,
+ "shell": True,
+ "env": os.environ if custom_env is None else custom_env,
+ "encoding": 'utf8',
+ "errors": 'ignore',
+ }
+
+ if not live:
+ run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
+
+ result = subprocess.run(**run_kwargs)
+
+ if result.returncode != 0:
+ error_bits = [
+ f"{errdesc or 'Error running command'}.",
+ f"Command: {command}",
+ f"Error code: {result.returncode}",
+ ]
+ if result.stdout:
+ error_bits.append(f"stdout: {result.stdout}")
+ if result.stderr:
+ error_bits.append(f"stderr: {result.stderr}")
+ raise RuntimeError("\n".join(error_bits))
+
+ return (result.stdout or "")
+
+
+def is_installed(package):
+ try:
+ spec = importlib.util.find_spec(package)
+ except ModuleNotFoundError:
+ return False
+
+ return spec is not None
+
+
+def repo_dir(name):
+ return os.path.join(script_path, dir_repos, name)
+
+
+def run_pip(command, desc=None, live=default_command_live):
+ if args.skip_install:
+ return
+
+ index_url_line = f' --index-url {index_url}' if index_url != '' else ''
+ return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
+
+
+def check_run_python(code: str) -> bool:
+ result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
+ return result.returncode == 0
+
+
+def git_clone(url, dir, name, commithash=None):
+ # TODO clone into temporary dir and move if successful
+
+ if os.path.exists(dir):
+ if commithash is None:
+ return
+
+ current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
+ if current_hash == commithash:
+ return
+
+ run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+ run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+ return
+
+ run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
+
+ if commithash is not None:
+ run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
+
+
+def git_pull_recursive(dir):
+ for subdir, _, _ in os.walk(dir):
+ if os.path.exists(os.path.join(subdir, '.git')):
+ try:
+ output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
+ print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
+ except subprocess.CalledProcessError as e:
+ print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
+
+
+def version_check(commit):
+ try:
+ import requests
+ commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
+ if commit != "<none>" and commits['commit']['sha'] != commit:
+ print("--------------------------------------------------------")
+ print("| You are not up to date with the most recent release. |")
+ print("| Consider running `git pull` to update. |")
+ print("--------------------------------------------------------")
+ elif commits['commit']['sha'] == commit:
+ print("You are up to date with the most recent release.")
+ else:
+ print("Not a git clone, can't perform version check.")
+ except Exception as e:
+ print("version check failed", e)
+
+
+def run_extension_installer(extension_dir):
+ path_installer = os.path.join(extension_dir, "install.py")
+ if not os.path.isfile(path_installer):
+ return
+
+ try:
+ env = os.environ.copy()
+ env['PYTHONPATH'] = os.path.abspath(".")
+
+ print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
+ except Exception as e:
+ errors.report(str(e))
+
+
+def list_extensions(settings_file):
+ settings = {}
+
+ try:
+ if os.path.isfile(settings_file):
+ with open(settings_file, "r", encoding="utf8") as file:
+ settings = json.load(file)
+ except Exception:
+ errors.report("Could not load settings", exc_info=True)
+
+ disabled_extensions = set(settings.get('disabled_extensions', []))
+ disable_all_extensions = settings.get('disable_all_extensions', 'none')
+
+ if disable_all_extensions != 'none':
+ return []
+
+ return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
+
+
+def run_extensions_installers(settings_file):
+ if not os.path.isdir(extensions_dir):
+ return
+
+ for dirname_extension in list_extensions(settings_file):
+ run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+
+
+def prepare_environment():
+ torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
+ torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
+
+ xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
+ clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
+ openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
+
+ stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+ k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
+ codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
+ blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
+
+ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+ k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
+ codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
+ blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
+
+ try:
+ # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
+ os.remove(os.path.join(script_path, "tmp", "restart"))
+ os.environ.setdefault('SD_WEBUI_RESTARTING ', '1')
+ except OSError:
+ pass
+
+ if not args.skip_python_version_check:
+ check_python_version()
+
+ commit = commit_hash()
+ tag = git_tag()
+
+ print(f"Python {sys.version}")
+ print(f"Version: {tag}")
+ print(f"Commit hash: {commit}")
+
+ if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
+ run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
+
+ if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
+ raise RuntimeError(
+ 'Torch is not able to use GPU; '
+ 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
+ )
+
+ if not is_installed("gfpgan"):
+ run_pip(f"install {gfpgan_package}", "gfpgan")
+
+ if not is_installed("clip"):
+ run_pip(f"install {clip_package}", "clip")
+
+ if not is_installed("open_clip"):
+ run_pip(f"install {openclip_package}", "open_clip")
+
+ if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
+ if platform.system() == "Windows":
+ if platform.python_version().startswith("3.10"):
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
+ else:
+ print("Installation of xformers is not supported in this version of Python.")
+ print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
+ if not is_installed("xformers"):
+ exit(0)
+ elif platform.system() == "Linux":
+ run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+
+ if not is_installed("ngrok") and args.ngrok:
+ run_pip("install ngrok", "ngrok")
+
+ os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
+
+ git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+ git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
+ git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
+ git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
+
+ if not is_installed("lpips"):
+ run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
+
+ if not os.path.isfile(requirements_file):
+ requirements_file = os.path.join(script_path, requirements_file)
+ run_pip(f"install -r \"{requirements_file}\"", "requirements")
+
+ run_extensions_installers(settings_file=args.ui_settings_file)
+
+ if args.update_check:
+ version_check(commit)
+
+ if args.update_all_extensions:
+ git_pull_recursive(extensions_dir)
+
+ if "--exit" in sys.argv:
+ print("Exiting because of --exit argument")
+ exit(0)
+
+
+def configure_for_tests():
+ if "--api" not in sys.argv:
+ sys.argv.append("--api")
+ if "--ckpt" not in sys.argv:
+ sys.argv.append("--ckpt")
+ sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
+ if "--skip-torch-cuda-test" not in sys.argv:
+ sys.argv.append("--skip-torch-cuda-test")
+ if "--disable-nan-check" not in sys.argv:
+ sys.argv.append("--disable-nan-check")
+
+ os.environ['COMMANDLINE_ARGS'] = ""
+
+
+def start():
+ print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
+ import webui
+ if '--nowebui' in sys.argv:
+ webui.api_only()
+ else:
+ webui.webui()
diff --git a/modules/localization.py b/modules/localization.py
index ee9c65e7..e8f585da 100644
--- a/modules/localization.py
+++ b/modules/localization.py
@@ -1,8 +1,7 @@
import json
import os
-import sys
-import traceback
+from modules import errors
localizations = {}
@@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str:
with open(fn, "r", encoding="utf8") as file:
data = json.load(file)
except Exception:
- print(f"Error loading localization from {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading localization from {fn}", exc_info=True)
return f"window.localization = {json.dumps(data)}"
diff --git a/modules/lowvram.py b/modules/lowvram.py
index e254cc13..d95bcfbf 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,8 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram):
+ sd_model.lowvram = True
+
parents = {}
def send_me_to_gpu(module, _):
@@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram):
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
+
+
+def is_enabled(sd_model):
+ return getattr(sd_model, 'lowvram', False)
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 40ce2101..735847f5 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,20 +1,24 @@
import torch
import platform
-from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
- try:
- torch.zeros(1).to(torch.device("mps"))
- return True
- except Exception:
- return False
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+ else:
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps()
@@ -43,7 +47,7 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
- # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@@ -61,4 +65,4 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
- CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') \ No newline at end of file
+ CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
diff --git a/modules/masking.py b/modules/masking.py
index a5c4d2da..be9f84c7 100644
--- a/modules/masking.py
+++ b/modules/masking.py
@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
+
h, w = mask.shape
crop_left = 0
diff --git a/modules/modelloader.py b/modules/modelloader.py
index a70aa0e3..098bcb79 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,4 +1,5 @@
-import glob
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -9,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -40,16 +64,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if os.path.islink(full_path) and not os.path.exists(full_path):
print(f"Skipping broken symlink: {full_path}")
continue
- if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
+ if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
continue
if full_path not in output:
output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, model_path, True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
@@ -60,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str):
- if "http" in file:
+ if file.startswith("http"):
file = urlparse(file).path
file = os.path.basename(file)
@@ -96,8 +118,7 @@ def cleanup_models():
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try:
- if not os.path.exists(dest_path):
- os.makedirs(dest_path)
+ os.makedirs(dest_path, exist_ok=True)
if os.path.exists(src_path):
for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file)
@@ -108,12 +129,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
- except:
+ except Exception:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
- except:
+ except Exception:
pass
@@ -127,7 +148,7 @@ def load_upscalers():
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
- except:
+ except Exception:
pass
datas = []
@@ -145,7 +166,10 @@ def load_upscalers():
for cls in reversed(used_classes.values()):
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
- scaler = cls(commandline_options.get(cmd_name, None))
+ commandline_model_path = commandline_options.get(cmd_name, None)
+ scaler = cls(commandline_model_path)
+ scaler.user_path = commandline_model_path
+ scaler.model_download_path = commandline_model_path or scaler.model_path
datas += scaler.scalers
shared.sd_upscalers = sorted(
diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py
index f880bc3c..b892d5fc 100644
--- a/modules/models/diffusion/ddpm_edit.py
+++ b/modules/models/diffusion/ddpm_edit.py
@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
# If initialing from EMA-only checkpoint, create EMA model after loading.
if self.use_ema and not load_ema:
@@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
if context is not None:
print(f"{context}: Restored training weights")
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
+ ignore_keys = ignore_keys or []
+
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
@@ -228,9 +230,9 @@ class DDPM(pl.LightningModule):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
+ if missing:
print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
+ if unexpected:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
@@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
+ log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
log["inputs"] = x
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
x_start = x[:n_row]
for t in range(self.num_timesteps):
@@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
+ super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
+ except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
return img, intermediates
@torch.no_grad()
@@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
if return_intermediates:
return img, intermediates
@@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ [x[:batch_size] for x in cond[key]] for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
@@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
use_ddim = False
- log = dict()
+ log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
@@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
if plot_diffusion_rows:
# get diffusion row
- diffusion_row = list()
+ diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
if inpaint:
# make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+ super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+ logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py
index e1265e3f..dbb35964 100644
--- a/modules/models/diffusion/uni_pc/__init__.py
+++ b/modules/models/diffusion/uni_pc/__init__.py
@@ -1 +1 @@
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler # noqa: F401
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
index a241c8a7..0a9defa1 100644
--- a/modules/models/diffusion/uni_pc/sampler.py
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -54,7 +54,8 @@ class UniPCSampler(object):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
- while isinstance(ctmp, list): ctmp = ctmp[0]
+ while isinstance(ctmp, list):
+ ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py
index 11b330bc..d257a728 100644
--- a/modules/models/diffusion/uni_pc/uni_pc.py
+++ b/modules/models/diffusion/uni_pc/uni_pc.py
@@ -1,7 +1,6 @@
import torch
-import torch.nn.functional as F
import math
-from tqdm.auto import trange
+import tqdm
class NoiseScheduleVP:
@@ -179,13 +178,13 @@ def model_wrapper(
model,
noise_schedule,
model_type="noise",
- model_kwargs={},
+ model_kwargs=None,
guidance_type="uncond",
#condition=None,
#unconditional_condition=None,
guidance_scale=1.,
classifier_fn=None,
- classifier_kwargs={},
+ classifier_kwargs=None,
):
"""Create a wrapper function for the noise prediction model.
@@ -276,6 +275,9 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
+ model_kwargs = model_kwargs or {}
+ classifier_kwargs = classifier_kwargs or {}
+
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
@@ -342,7 +344,7 @@ def model_wrapper(
t_in = torch.cat([t_continuous] * 2)
if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict)
- c_in = dict()
+ c_in = {}
for k in condition:
if isinstance(condition[k], list):
c_in[k] = [torch.cat([
@@ -353,7 +355,7 @@ def model_wrapper(
unconditional_condition[k],
condition[k]])
elif isinstance(condition, list):
- c_in = list()
+ c_in = []
assert isinstance(unconditional_condition, list)
for i in range(len(condition)):
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
@@ -757,40 +759,44 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
- # Init the first `order` values by lower order multistep DPM-Solver.
- for init_order in range(1, order):
- vec_t = timesteps[init_order].expand(x.shape[0])
- x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
- if model_x is None:
- model_x = self.model_fn(x, vec_t)
- if self.after_update is not None:
- self.after_update(x, model_x)
- model_prev_list.append(model_x)
- t_prev_list.append(vec_t)
- for step in trange(order, steps + 1):
- vec_t = timesteps[step].expand(x.shape[0])
- if lower_order_final:
- step_order = min(order, steps + 1 - step)
- else:
- step_order = order
- #print('this step order:', step_order)
- if step == steps:
- #print('do not run corrector at the last step')
- use_corrector = False
- else:
- use_corrector = True
- x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
- if self.after_update is not None:
- self.after_update(x, model_x)
- for i in range(order - 1):
- t_prev_list[i] = t_prev_list[i + 1]
- model_prev_list[i] = model_prev_list[i + 1]
- t_prev_list[-1] = vec_t
- # We do not need to evaluate the final model value.
- if step < steps:
+ with tqdm.tqdm(total=steps) as pbar:
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in range(1, order):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
- model_prev_list[-1] = model_x
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ model_prev_list.append(model_x)
+ t_prev_list.append(vec_t)
+ pbar.update()
+
+ for step in range(order, steps + 1):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ #print('this step order:', step_order)
+ if step == steps:
+ #print('do not run corrector at the last step')
+ use_corrector = False
+ else:
+ use_corrector = True
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+ if self.after_update is not None:
+ self.after_update(x, model_x)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ if model_x is None:
+ model_x = self.model_fn(x, vec_t)
+ model_prev_list[-1] = model_x
+ pbar.update()
else:
raise NotImplementedError()
if denoise_to_zero:
diff --git a/modules/ngrok.py b/modules/ngrok.py
index 7a7b4b26..0c713e27 100644
--- a/modules/ngrok.py
+++ b/modules/ngrok.py
@@ -1,6 +1,7 @@
-from pyngrok import ngrok, conf, exception
+import ngrok
-def connect(token, port, region):
+# Connect to ngrok for ingress
+def connect(token, port, options):
account = None
if token is None:
token = 'None'
@@ -10,28 +11,19 @@ def connect(token, port, region):
token, username, password = token.split(':', 2)
account = f"{username}:{password}"
- config = conf.PyngrokConfig(
- auth_token=token, region=region
- )
-
- # Guard for existing tunnels
- existing = ngrok.get_tunnels(pyngrok_config=config)
- if existing:
- for established in existing:
- # Extra configuration in the case that the user is also using ngrok for other tunnels
- if established.config['addr'][-4:] == str(port):
- public_url = existing[0].public_url
- print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
- 'You can use this link after the launch is complete.')
- return
-
+ # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
+ if not options.get('authtoken_from_env'):
+ options['authtoken'] = token
+ if account:
+ options['basic_auth'] = account
+ if not options.get('session_metadata'):
+ options['session_metadata'] = 'stable-diffusion-webui'
+
+
try:
- if account is None:
- public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
- else:
- public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
- except exception.PyngrokNgrokError:
- print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
+ public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
+ except Exception as e:
+ print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
else:
print(f'ngrok connected to localhost:{port}! URL: {public_url}\n'
diff --git a/modules/paths.py b/modules/paths.py
index acf1894b..bada804e 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -1,8 +1,8 @@
import os
import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
-import modules.safe
+import modules.safe # noqa: F401
# data_path = cmd_opts_pre.data
@@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -39,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else:
sys.path.append(d)
paths[what] = d
-
-
-class Prioritize:
- def __init__(self, name):
- self.name = name
- self.path = None
-
- def __enter__(self):
- self.path = sys.path.copy()
- sys.path = [paths[self.name]] + sys.path
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- sys.path = self.path
- self.path = None
diff --git a/modules/paths_internal.py b/modules/paths_internal.py
index 6765bafe..005a9b0a 100644
--- a/modules/paths_internal.py
+++ b/modules/paths_internal.py
@@ -2,8 +2,14 @@
import argparse
import os
+import sys
+import shlex
-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
+sys.argv += shlex.split(commandline_args)
+
+modules_path = os.path.dirname(os.path.realpath(__file__))
+script_path = os.path.dirname(modules_path)
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
@@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
-parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
cmd_opts_pre = parser_pre.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
@@ -21,3 +27,5 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
+
+roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 736315e2..136e9c88 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
@@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
for image, name in zip(image_data, image_names):
shared.state.textinfo = name
- existing_pnginfo = image.info or {}
+ parameters, existing_pnginfo = images.read_info_from_image(image)
+ if parameters:
+ existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
diff --git a/modules/processing.py b/modules/processing.py
index 1a76e552..21d1492c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,20 +1,20 @@
import json
+import logging
import math
import os
import sys
-import warnings
import hashlib
import torch
import numpy as np
-from PIL import Image, ImageFilter, ImageOps
+from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -24,13 +24,13 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -106,6 +106,9 @@ class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -150,6 +153,8 @@ class StableDiffusionProcessing:
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
+ self.token_merging_ratio = 0
+ self.token_merging_ratio_hr = 0
if not seed_enable_extras:
self.subseed = -1
@@ -165,7 +170,21 @@ class StableDiffusionProcessing:
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
-
+ self.sampler = None
+
+ self.prompts = None
+ self.negative_prompts = None
+ self.extra_network_data = None
+ self.seeds = None
+ self.subseeds = None
+
+ self.step_multiplier = 1
+ self.cached_uc = StableDiffusionProcessing.cached_uc
+ self.cached_c = StableDiffusionProcessing.cached_c
+ self.uc = None
+ self.c = None
+
+ self.user = None
@property
def sd_model(self):
@@ -273,6 +292,64 @@ class StableDiffusionProcessing:
def close(self):
self.sampler = None
+ self.c = None
+ self.uc = None
+ if not opts.experimental_persistent_cond_cache:
+ StableDiffusionProcessing.cached_c = [None, None]
+ StableDiffusionProcessing.cached_uc = [None, None]
+
+ def get_token_merging_ratio(self, for_hr=False):
+ if for_hr:
+ return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+ return self.token_merging_ratio or opts.token_merging_ratio
+
+ def setup_prompts(self):
+ if type(self.prompt) == list:
+ self.all_prompts = self.prompt
+ else:
+ self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
+
+ if type(self.negative_prompt) == list:
+ self.all_negative_prompts = self.negative_prompt
+ else:
+ self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+
+ self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
+ self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+
+ caches is a list with items described above.
+ """
+ for cache in caches:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ return cache[1]
+
+ cache = caches[0]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
+ return cache[1]
+
+ def setup_conds(self):
+ sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
+ self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+
+ def parse_extra_network_prompts(self):
+ self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
class Processed:
@@ -303,6 +380,8 @@ class Processed:
self.styles = p.styles
self.job_timestamp = state.job_timestamp
self.clip_skip = opts.CLIP_stop_at_last_layers
+ self.token_merging_ratio = p.token_merging_ratio
+ self.token_merging_ratio_hr = p.token_merging_ratio_hr
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -310,6 +389,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -360,6 +440,9 @@ class Processed:
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
@@ -468,10 +551,17 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+ enable_hr = getattr(p, 'enable_hr', False)
+ token_merging_ratio = p.get_token_merging_ratio()
+ token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
+
+ uses_ensd = opts.eta_noise_seed_delta != 0
+ if uses_ensd:
+ uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
generation_params = {
"Steps": p.steps,
@@ -485,27 +575,33 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
- "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
+ "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+ "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
+ "User": p.user if opts.add_user_name_to_info else None,
}
- generation_params.update(p.extra_generation_params)
-
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
+ if p.scripts is not None:
+ p.scripts.before_process(p)
+
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
@@ -523,9 +619,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
+ sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+
res = process_images_inner(p)
finally:
+ sd_models.apply_token_merging(p.sd_model, 0)
+
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -555,15 +655,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
comments = {}
- if type(p.prompt) == list:
- p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
- else:
- p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
- else:
- p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
+ p.setup_prompts()
if type(seed) == list:
p.all_seeds = seed
@@ -575,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -587,29 +679,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
- cached_uc = [None, None]
- cached_c = [None, None]
-
- def get_conds_with_caching(function, required_prompts, steps, cache):
- """
- Returns the result of calling function(shared.sd_model, required_prompts, steps)
- using a cache to store the result if the same arguments have been used before.
-
- cache is an array containing two elements. The first element is a tuple
- representing the previously used arguments, or None if no arguments
- have been used before. The second element is where the previously
- computed result is stored.
- """
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
- return cache[1]
-
- with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
-
- cache[0] = (required_prompts, steps)
- return cache[1]
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -618,10 +687,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
- extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -631,25 +701,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if p.scripts is not None:
- p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
- if len(prompts) == 0:
+ if len(p.prompts) == 0:
break
- prompts, extra_network_data = extra_networks.parse_prompts(prompts)
+ p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
- extra_networks.activate(p, extra_network_data)
+ extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None:
- p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
@@ -660,14 +730,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- step_multiplier = 1
- if not shared.opts.dont_fix_second_order_samplers_schedule:
- try:
- step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
- except:
- pass
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
+ p.setup_conds()
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -677,7 +740,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
@@ -688,7 +751,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -704,7 +767,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -721,13 +784,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@@ -740,10 +803,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
@@ -765,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- text = infotext()
+ text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -773,10 +836,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks and extra_network_data:
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks and p.extra_network_data:
+ extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc()
@@ -785,7 +848,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images,
seed=p.all_seeds[0],
info=infotext(),
- comments="".join(f"\n\n{comment}" for comment in comments),
+ comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
@@ -811,8 +874,10 @@ def old_hires_fix_first_pass_dimensions(width, height):
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
+ cached_hr_uc = [None, None]
+ cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -823,6 +888,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_sampler_name = hr_sampler_name
+ self.hr_prompt = hr_prompt
+ self.hr_negative_prompt = hr_negative_prompt
+ self.all_hr_prompts = None
+ self.all_hr_negative_prompts = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -834,8 +904,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = 0
self.applied_old_hires_behavior_to = None
+ self.hr_prompts = None
+ self.hr_negative_prompts = None
+ self.hr_extra_network_data = None
+
+ self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
+ self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
+ self.hr_c = None
+ self.hr_uc = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
+ self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
+
+ if tuple(self.hr_prompt) != tuple(self.prompt):
+ self.extra_generation_params["Hires prompt"] = self.hr_prompt
+
+ if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
+ self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -901,7 +989,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
- assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
@@ -965,9 +1054,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.sampler_name
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'
+
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
@@ -978,17 +1069,101 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ if not self.disable_extra_networks:
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ with devices.autocast():
+ self.calculate_hr_conds()
+
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
+ if self.scripts is not None:
+ self.scripts.before_hr(self)
+
+ samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.is_hr_pass = False
return samples
+ def close(self):
+ super().close()
+ self.hr_c = None
+ self.hr_uc = None
+ if not opts.experimental_persistent_cond_cache:
+ StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
+ StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
+
+ def setup_prompts(self):
+ super().setup_prompts()
+
+ if not self.enable_hr:
+ return
+
+ if self.hr_prompt == '':
+ self.hr_prompt = self.prompt
+
+ if self.hr_negative_prompt == '':
+ self.hr_negative_prompt = self.negative_prompt
+
+ if type(self.hr_prompt) == list:
+ self.all_hr_prompts = self.hr_prompt
+ else:
+ self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+
+ if type(self.hr_negative_prompt) == list:
+ self.all_hr_negative_prompts = self.hr_negative_prompt
+ else:
+ self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+
+ self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
+ self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+
+ def calculate_hr_conds(self):
+ if self.hr_c is not None:
+ return
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+
+ def setup_conds(self):
+ super().setup_conds()
+
+ self.hr_uc = None
+ self.hr_c = None
+
+ if self.enable_hr:
+ if shared.opts.hires_fix_use_firstpass_conds:
+ self.calculate_hr_conds()
+
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ self.calculate_hr_conds()
+
+ with devices.autocast():
+ extra_networks.activate(self, self.extra_network_data)
+
+ def parse_extra_network_prompts(self):
+ res = super().parse_extra_network_prompts()
+
+ if self.enable_hr:
+ self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+ self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+
+ self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+
+ return res
+
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -999,7 +1174,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_mask = mask
self.latent_mask = None
self.mask_for_overlay = None
- self.mask_blur = mask_blur
+ if mask_blur is not None:
+ mask_blur_x = mask_blur
+ mask_blur_y = mask_blur
+ self.mask_blur_x = mask_blur_x
+ self.mask_blur_y = mask_blur_y
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
@@ -1021,8 +1200,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
- if self.mask_blur > 0:
- image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+ if self.mask_blur_x > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+ image_mask = Image.fromarray(np_mask)
+
+ if self.mask_blur_y > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+ image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
self.mask_for_overlay = image_mask
@@ -1141,3 +1329,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
devices.torch_gc()
return samples
+
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
diff --git a/modules/progress.py b/modules/progress.py
index 948e6f00..f405f07f 100644
--- a/modules/progress.py
+++ b/modules/progress.py
@@ -95,9 +95,20 @@ def progressapi(req: ProgressRequest):
image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
- image.save(buffered, format="png")
+
+ if opts.live_previews_image_format == "png":
+ # using optimize for large images takes an enormous amount of time
+ if max(*image.size) <= 256:
+ save_kwargs = {"optimize": True}
+ else:
+ save_kwargs = {"optimize": False, "compress_level": 1}
+
+ else:
+ save_kwargs = {}
+
+ image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
- live_preview = f"data:image/png;base64,{base64_image}"
+ live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
live_preview = None
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 69665372..0069d8b0 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
"""
def collect_steps(steps, tree):
- l = [steps]
+ res = [steps]
+
class CollectSteps(lark.Visitor):
def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1:
tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1]))
- l.append(tree.children[-1])
+ res.append(tree.children[-1])
+
def alternate(self, tree):
- l.extend(range(1, steps+1))
+ res.extend(range(1, steps+1))
+
CollectSteps().visit(tree)
- return sorted(set(l))
+ return sorted(set(res))
def at_step(step, tree):
class AtStep(lark.Transformer):
@@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def get_schedule(prompt):
try:
tree = schedule_parser.parse(prompt)
- except lark.exceptions.LarkError as e:
+ except lark.exceptions.LarkError:
if 0:
import traceback
traceback.print_exc()
@@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
conds = model.get_learned_conditioning(texts)
cond_schedule = []
- for i, (end_at_step, text) in enumerate(prompt_schedule):
+ for i, (end_at_step, _) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule
@@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c):
target_index = 0
- for current, (end_at, cond) in enumerate(cond_schedule):
- if current_step <= end_at:
+ for current, entry in enumerate(cond_schedule):
+ if current_step <= entry.end_at_step:
target_index = current
break
res[i] = cond_schedule[target_index].cond
@@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
tensors = []
conds_list = []
- for batch_no, composable_prompts in enumerate(c.batch):
+ for composable_prompts in c.batch:
conds_for_batch = []
- for cond_index, composable_prompt in enumerate(composable_prompts):
+ for composable_prompt in composable_prompts:
target_index = 0
- for current, (end_at, cond) in enumerate(composable_prompt.schedules):
- if current_step <= end_at:
+ for current, entry in enumerate(composable_prompt.schedules):
+ if current_step <= entry.end_at_step:
target_index = current
break
@@ -333,11 +336,11 @@ def parse_prompt_attention(text):
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
- elif weight is not None and len(round_brackets) > 0:
+ elif weight is not None and round_brackets:
multiply_range(round_brackets.pop(), float(weight))
- elif text == ')' and len(round_brackets) > 0:
+ elif text == ')' and round_brackets:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
- elif text == ']' and len(square_brackets) > 0:
+ elif text == ']' and square_brackets:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
parts = re.split(re_break, text)
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py
index efd7fca5..0700b853 100644
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -1,15 +1,13 @@
import os
-import sys
-import traceback
import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
-from modules import modelloader
+from modules import modelloader, errors
+
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
@@ -17,9 +15,9 @@ class UpscalerRealESRGAN(Upscaler):
self.user_path = path
super().__init__()
try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
+ from realesrgan import RealESRGANer # noqa: F401
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True
self.scalers = []
scalers = self.load_models(path)
@@ -36,8 +34,7 @@ class UpscalerRealESRGAN(Upscaler):
self.scalers.append(scaler)
except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error importing Real-ESRGAN", exc_info=True)
self.enable = False
self.scalers = []
@@ -45,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -65,21 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
-
- return info
- except Exception as e:
- print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
@@ -134,6 +128,5 @@ def get_realesrgan_models(scaler):
),
]
return models
- except Exception as e:
- print("Error making Real-ESRGAN models list:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ except Exception:
+ errors.report("Error making Real-ESRGAN models list", exc_info=True)
diff --git a/modules/restart.py b/modules/restart.py
new file mode 100644
index 00000000..18eacaf3
--- /dev/null
+++ b/modules/restart.py
@@ -0,0 +1,23 @@
+import os
+from pathlib import Path
+
+from modules.paths_internal import script_path
+
+
+def is_restartable() -> bool:
+ """
+ Return True if the webui is restartable (i.e. there is something watching to restart it with)
+ """
+ return bool(os.environ.get('SD_WEBUI_RESTART'))
+
+
+def restart_program() -> None:
+ """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
+
+ (Path(script_path) / "tmp" / "restart").touch()
+
+ stop_program()
+
+
+def stop_program() -> None:
+ os._exit(0)
diff --git a/modules/safe.py b/modules/safe.py
index e1a67f73..b1d08a79 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -2,8 +2,6 @@
import pickle
import collections
-import sys
-import traceback
import torch
import numpy
@@ -11,7 +9,10 @@ import _codecs
import zipfile
import re
+
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+from modules import errors
+
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
@@ -95,16 +96,16 @@ def check_pt(filename, extra_handler):
except zipfile.BadZipfile:
- # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+ # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
- for i in range(5):
+ for _ in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
- return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+ return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ errors.report(
+ f"Error verifying pickled file from {filename}\n"
+ "-----> !!!! The file is most likely corrupted !!!! <-----\n"
+ "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
+ exc_info=True,
+ )
return None
-
except Exception:
- print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ errors.report(
+ f"Error verifying pickled file from {filename}\n"
+ f"The file may be malicious, so the program is not going to read it.\n"
+ f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
+ exc_info=True,
+ )
return None
return unsafe_torch_load(filename, *args, **kwargs)
@@ -190,4 +194,3 @@ with safe.Extra(handler):
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None
-
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 17109732..77ee55ee 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -1,16 +1,16 @@
-import sys
-import traceback
-from collections import namedtuple
import inspect
+import os
+from collections import namedtuple
from typing import Optional, Dict, Any
from fastapi import FastAPI
from gradio import Blocks
+from modules import errors, timer
+
def report_exception(c, job):
- print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error executing callback {job} for {c.script}", exc_info=True)
class ImageSaveParams:
@@ -32,27 +32,42 @@ class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x
"""Latent image representation in the process of being denoised"""
-
+
self.image_cond = image_cond
"""Conditioning image"""
-
+
self.sigma = sigma
"""Current sigma noise step value"""
-
+
self.sampling_step = sampling_step
"""Current Sampling step number"""
-
+
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
-
+
self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt"""
-
+
self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams:
+ def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+ self.inner_model = inner_model
+ """Inner model reference used for denoising"""
+
+
+class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
@@ -87,6 +102,7 @@ callback_map = dict(
callbacks_image_saved=[],
callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[],
+ callbacks_cfg_after_cfg=[],
callbacks_before_component=[],
callbacks_after_component=[],
callbacks_image_grid=[],
@@ -94,6 +110,8 @@ callback_map = dict(
callbacks_script_unloaded=[],
callbacks_before_ui=[],
callbacks_on_reload=[],
+ callbacks_list_optimizers=[],
+ callbacks_list_unets=[],
)
@@ -106,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
+ timer.startup_timer.record(os.path.basename(c.script))
except Exception:
report_exception(c, 'app_started_callback')
@@ -186,6 +205,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback')
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
+ for c in callback_map['callbacks_cfg_after_cfg']:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_after_cfg_callback')
+
+
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
try:
@@ -234,16 +261,40 @@ def before_ui_callback():
report_exception(c, 'before_ui')
+def list_optimizers_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_optimizers']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_optimizers')
+
+ return res
+
+
+def list_unets_callback():
+ res = []
+
+ for c in callback_map['callbacks_list_unets']:
+ try:
+ c.callback(res)
+ except Exception:
+ report_exception(c, 'list_unets')
+
+ return res
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
- filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ filename = stack[0].filename if stack else 'unknown file'
callbacks.append(ScriptCallback(filename, fun))
-
+
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
- filename = stack[0].filename if len(stack) > 0 else 'unknown file'
+ filename = stack[0].filename if stack else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
@@ -332,6 +383,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback)
+def on_cfg_after_cfg(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
+ The callback is called with one argument:
+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
+ """
+ add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
+
+
def on_before_component(callback):
"""register a function to be called before a component is created.
The callback is called with arguments:
@@ -377,3 +436,18 @@ def on_before_ui(callback):
"""register a function to be called before the UI is created."""
add_callback(callback_map['callbacks_before_ui'], callback)
+
+
+def on_list_optimizers(callback):
+ """register a function to be called when UI is making a list of cross attention optimization options.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
+ to it."""
+
+ add_callback(callback_map['callbacks_list_optimizers'], callback)
+
+
+def on_list_unets(callback):
+ """register a function to be called when UI is making a list of alternative options for unet.
+ The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
+
+ add_callback(callback_map['callbacks_list_unets'], callback)
diff --git a/modules/script_loading.py b/modules/script_loading.py
index a7d2203f..306a1f35 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -1,8 +1,7 @@
import os
-import sys
-import traceback
import importlib.util
-from types import ModuleType
+
+from modules import errors
def load_module(path):
@@ -28,5 +27,4 @@ def preload_extensions(extensions_dir, parser):
module.preload(parser)
except Exception:
- print(f"Error running preload() for {preload_script}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running preload() for {preload_script}", exc_info=True)
diff --git a/modules/scripts.py b/modules/scripts.py
index d945b89f..7d9dd59f 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -1,12 +1,12 @@
import os
import re
import sys
-import traceback
+import inspect
from collections import namedtuple
import gradio as gr
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
AlwaysVisible = object()
@@ -17,6 +17,12 @@ class PostprocessImageArgs:
class Script:
+ name = None
+ """script's internal name derived from title"""
+
+ section = None
+ """name of UI section that the script's controls will be placed into"""
+
filename = None
args_from = None
args_to = None
@@ -25,8 +31,8 @@ class Script:
is_txt2img = False
is_img2img = False
- """A gr.Group component that has all script's UI inside it"""
group = None
+ """A gr.Group component that has all script's UI inside it"""
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@@ -38,6 +44,9 @@ class Script:
various "Send to <X>" buttons when clicked
"""
+ api_info = None
+ """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -76,6 +85,15 @@ class Script:
pass
+ def before_process(self, p, *args):
+ """
+ This function is called very early before processing begins for AlwaysVisible scripts.
+ You can modify the processing object (p) here, inject hooks, etc.
+ args contains all values returned by components from ui()
+ """
+
+ pass
+
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
@@ -99,6 +117,21 @@ class Script:
pass
+ def after_extra_networks_activate(self, p, *args, **kwargs):
+ """
+ Calledafter extra networks activation, before conds calculation
+ allow modification of the network after extra networks activation been applied
+ won't be call if p.disable_extra_networks
+
+ **kwargs will have those items:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
+ - seeds - list of seeds for current batch
+ - subseeds - list of subseeds for current batch
+ - extra_network_data - list of ExtraNetworkParams for current stage
+ """
+ pass
+
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
@@ -169,6 +202,11 @@ class Script:
return f'script_{tabname}{title}_{item_id}'
+ def before_hr(self, p, *args):
+ """
+ This function is called before hires fix start.
+ """
+ pass
current_basedir = paths.script_path
@@ -231,8 +269,8 @@ def load_scripts():
syspath = sys.path
def register_scripts_from_module(module):
- for key, script_class in module.__dict__.items():
- if type(script_class) != type:
+ for script_class in module.__dict__.values():
+ if not inspect.isclass(script_class):
continue
if issubclass(script_class, Script):
@@ -258,21 +296,25 @@ def load_scripts():
register_scripts_from_module(script_module)
except Exception:
- print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True)
finally:
sys.path = syspath
current_basedir = paths.script_path
+ timer.startup_timer.record(scriptfile.filename)
+
+ global scripts_txt2img, scripts_img2img, scripts_postproc
+
+ scripts_txt2img = ScriptRunner()
+ scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
- res = func(*args, **kwargs)
- return res
+ return func(*args, **kwargs)
except Exception:
- print(f"Error calling: {filename}/{funcname}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error calling: {filename}/{funcname}", exc_info=True)
return default
@@ -285,6 +327,7 @@ class ScriptRunner:
self.titles = []
self.infotext_fields = []
self.paste_field_names = []
+ self.inputs = [None]
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -295,9 +338,9 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
- for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
- script = script_class()
- script.filename = path
+ for script_data in auto_processing_scripts + scripts_data:
+ script = script_data.script_class()
+ script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
@@ -312,49 +355,74 @@ class ScriptRunner:
self.scripts.append(script)
self.selectable_scripts.append(script)
- def setup_ui(self):
- self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+ def create_script_ui(self, script):
+ import modules.api.models as api_models
- inputs = [None]
- inputs_alwayson = [True]
+ script.args_from = len(self.inputs)
+ script.args_to = len(self.inputs)
- def create_script_ui(script, inputs, inputs_alwayson):
- script.args_from = len(inputs)
- script.args_to = len(inputs)
+ controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
- controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+ if controls is None:
+ return
- if controls is None:
- return
+ script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+ api_args = []
- for control in controls:
- control.custom_script_source = os.path.basename(script.filename)
+ for control in controls:
+ control.custom_script_source = os.path.basename(script.filename)
- if script.infotext_fields is not None:
- self.infotext_fields += script.infotext_fields
+ arg_info = api_models.ScriptArg(label=control.label or "")
- if script.paste_field_names is not None:
- self.paste_field_names += script.paste_field_names
+ for field in ("value", "minimum", "maximum", "step", "choices"):
+ v = getattr(control, field, None)
+ if v is not None:
+ setattr(arg_info, field, v)
- inputs += controls
- inputs_alwayson += [script.alwayson for _ in controls]
- script.args_to = len(inputs)
+ api_args.append(arg_info)
- for script in self.alwayson_scripts:
- with gr.Group() as group:
- create_script_ui(script, inputs, inputs_alwayson)
+ script.api_info = api_models.ScriptInfo(
+ name=script.name,
+ is_img2img=script.is_img2img,
+ is_alwayson=script.alwayson,
+ args=api_args,
+ )
- script.group = group
+ if script.infotext_fields is not None:
+ self.infotext_fields += script.infotext_fields
- dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- inputs[0] = dropdown
+ if script.paste_field_names is not None:
+ self.paste_field_names += script.paste_field_names
+
+ self.inputs += controls
+ script.args_to = len(self.inputs)
+
+ def setup_ui_for_section(self, section, scriptlist=None):
+ if scriptlist is None:
+ scriptlist = self.alwayson_scripts
+
+ for script in scriptlist:
+ if script.alwayson and script.section != section:
+ continue
- for script in self.selectable_scripts:
- with gr.Group(visible=False) as group:
- create_script_ui(script, inputs, inputs_alwayson)
+ with gr.Group(visible=script.alwayson) as group:
+ self.create_script_ui(script)
script.group = group
+ def prepare_ui(self):
+ self.inputs = [None]
+
+ def setup_ui(self):
+ self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+
+ self.setup_ui_for_section(None)
+
+ dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
+ self.inputs[0] = dropdown
+
+ self.setup_ui_for_section(None, self.selectable_scripts)
+
def select_script(script_index):
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
@@ -378,6 +446,7 @@ class ScriptRunner:
)
self.script_load_ctr = 0
+
def onload_script_visibility(params):
title = params.get('Script', None)
if title:
@@ -388,10 +457,10 @@ class ScriptRunner:
else:
return gr.update(visible=False)
- self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
- self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
+ self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
+ self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
- return inputs
+ return self.inputs
def run(self, p, *args):
script_index = args[0]
@@ -411,14 +480,21 @@ class ScriptRunner:
return processed
+ def before_process(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_process(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_process: {script.filename}", exc_info=True)
+
def process(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
- print(f"Error running process: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running process: {script.filename}", exc_info=True)
def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -426,8 +502,15 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running before_process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
+
+ def after_extra_networks_activate(self, p, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.after_extra_networks_activate(p, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
@@ -435,8 +518,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
- print(f"Error running process_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
def postprocess(self, p, processed):
for script in self.alwayson_scripts:
@@ -444,8 +526,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
- print(f"Error running postprocess: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
@@ -453,8 +534,7 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
@@ -462,24 +542,21 @@ class ScriptRunner:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
except Exception:
- print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
script.before_component(component, **kwargs)
except Exception:
- print(f"Error running before_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
- print(f"Error running after_component: {script.filename}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
@@ -492,7 +569,7 @@ class ScriptRunner:
module = script_loading.load_module(script.filename)
cache[filename] = module
- for key, script_class in module.__dict__.items():
+ for script_class in module.__dict__.values():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
@@ -500,9 +577,18 @@ class ScriptRunner:
self.scripts[si].args_to = args_to
-scripts_txt2img = ScriptRunner()
-scripts_img2img = ScriptRunner()
-scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+ def before_hr(self, p):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.before_hr(p, *script_args)
+ except Exception:
+ errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
+
+
+scripts_txt2img: ScriptRunner = None
+scripts_img2img: ScriptRunner = None
+scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None
@@ -512,14 +598,7 @@ def reload_script_body_only():
scripts_img2img.reload_sources(cache)
-def reload_scripts():
- global scripts_txt2img, scripts_img2img, scripts_postproc
-
- load_scripts()
-
- scripts_txt2img = ScriptRunner()
- scripts_img2img = ScriptRunner()
- scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
+reload_scripts = load_scripts # compatibility alias
def add_classes_to_gradio_component(comp):
diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py
index 30d6d658..d63078de 100644
--- a/modules/scripts_auto_postprocessing.py
+++ b/modules/scripts_auto_postprocessing.py
@@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args):
- args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+ args_dict = dict(zip(self.postprocessing_controls, args))
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {}
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
index b11568c0..bac1335d 100644
--- a/modules/scripts_postprocessing.py
+++ b/modules/scripts_postprocessing.py
@@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
def initialize_scripts(self, scripts_data):
self.scripts = []
- for script_class, path, basedir, script_module in scripts_data:
- script: ScriptPostprocessing = script_class()
- script.filename = path
+ for script_data in scripts_data:
+ script: ScriptPostprocessing = script_data.script_class()
+ script.filename = script_data.path
if script.name == "Simple Upscale":
continue
@@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
script_args = args[script.args_from:script.args_to]
process_args = {}
- for (name, component), value in zip(script.controls.items(), script_args):
+ for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value
script.process(pp, **process_args)
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index c4a09d15..9fc89dc6 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -61,7 +61,7 @@ class DisableInitialization:
if res is None:
res = original(url, *args, local_files_only=False, **kwargs)
return res
- except Exception as e:
+ except Exception:
return original(url, *args, local_files_only=False, **kwargs)
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index f4bb0266..3b6f95ce 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -28,57 +28,65 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
+optimizers = []
+current_optimizer: sd_hijack_optimizations.SdOptimization = None
+
+
+def list_optimizers():
+ new_optimizers = script_callbacks.list_optimizers_callback()
+
+ new_optimizers = [x for x in new_optimizers if x.is_available()]
+
+ new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)
+
+ optimizers.clear()
+ optimizers.extend(new_optimizers)
+
+
+def apply_optimizations(option=None):
+ global current_optimizer
-def apply_optimizations():
undo_optimizations()
+ if len(optimizers) == 0:
+ # a script can access the model very early, and optimizations would not be filled by then
+ current_optimizer = None
+ return ''
+
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
- optimization_method = None
- can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
-
- if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
- print("Applying xformers cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
- optimization_method = 'xformers'
- elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
- print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
- optimization_method = 'sdp-no-mem'
- elif cmd_opts.opt_sdp_attention and can_use_sdp:
- print("Applying scaled dot product cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
- optimization_method = 'sdp'
- elif cmd_opts.opt_sub_quad_attention:
- print("Applying sub-quadratic cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
- optimization_method = 'sub-quadratic'
- elif cmd_opts.opt_split_attention_v1:
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'V1'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'InvokeAI'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
- print("Applying cross attention optimization (Doggettx).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
- ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
- optimization_method = 'Doggettx'
-
- return optimization_method
+ if current_optimizer is not None:
+ current_optimizer.undo()
+ current_optimizer = None
+
+ selection = option or shared.opts.cross_attention_optimization
+ if selection == "Automatic" and len(optimizers) > 0:
+ matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
+ else:
+ matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
+
+ if selection == "None":
+ matching_optimizer = None
+ elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
+ matching_optimizer = None
+ elif matching_optimizer is None:
+ matching_optimizer = optimizers[0]
+
+ if matching_optimizer is not None:
+ print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
+ matching_optimizer.apply()
+ print("done.")
+ current_optimizer = matching_optimizer
+ return current_optimizer.name
+ else:
+ print("Disabling attention optimization")
+ return ''
def undo_optimizations():
- ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
@@ -92,12 +100,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the weight normally, but ignore the mean
loss = sd_model._old_get_loss(pred, target, mean=False)
-
+
#Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None:
loss *= weight
-
+
#Return the loss, as mean if specified
return loss.mean() if mean else loss
@@ -105,7 +113,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try:
#Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w
-
+
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'):
@@ -118,9 +126,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try:
#Delete temporary weights if appended
del sd_model._custom_loss_weight
- except AttributeError as e:
+ except AttributeError:
pass
-
+
#If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss
@@ -133,7 +141,7 @@ def apply_weighted_forward(sd_model):
def undo_weighted_forward(sd_model):
try:
del sd_model.weighted_forward
- except AttributeError as e:
+ except AttributeError:
pass
@@ -150,6 +158,13 @@ class StableDiffusionModelHijack:
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
+ def apply_optimizations(self, option=None):
+ try:
+ self.optimization_method = apply_optimizations(option)
+ except Exception as e:
+ errors.display(e, "applying cross attention optimization")
+ undo_optimizations()
+
def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
@@ -169,7 +184,7 @@ class StableDiffusionModelHijack:
if m.cond_stage_key == "edit":
sd_hijack_unet.hijack_ddpm_edit()
- self.optimization_method = apply_optimizations()
+ self.apply_optimizations()
self.clip = m.cond_stage_model
@@ -182,9 +197,14 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
+ if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
+ ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
+
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
- m.cond_stage_model = m.cond_stage_model.wrapped
+ m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -203,6 +223,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
+ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
+
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
@@ -216,10 +238,17 @@ class StableDiffusionModelHijack:
self.comments = []
def get_prompt_lengths(self, text):
+ if self.clip is None:
+ return "-", "-"
+
_, token_count = self.clip.process_texts([text])
return token_count, self.clip.get_target_prompt_token_count(token_count)
+ def redo_hijack(self, m):
+ self.undo_hijack(m)
+ self.hijack(m)
+
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings):
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 9fa5c5c5..3b5a7666 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
chunk.multipliers += [weight] * emb_len
position += embedding_length_in_tokens
- if len(chunk.tokens) > 0 or len(chunks) == 0:
+ if chunk.tokens or not chunks:
next_chunk(is_last=True)
return chunks, token_count
@@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack.fixes = [x.fixes for x in batch_chunk]
for fixes in self.hijack.fixes:
- for position, embedding in fixes:
+ for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding
z = self.process_tokens(tokens, multipliers)
diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py
index a3476e95..c5c6270b 100644
--- a/modules/sd_hijack_clip_old.py
+++ b/modules/sd_hijack_clip_old.py
@@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments
- if len(used_custom_terms) > 0:
+ if used_custom_terms:
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index 55a2ce4d..c1977b19 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -1,16 +1,10 @@
-import os
import torch
-from einops import repeat
-from omegaconf import ListConfig
-
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
@@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
- c_in = dict()
+ c_in = {}
for k in c:
if isinstance(c[k], list):
c_in[k] = [
diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py
index 3c727d3b..6fe6b6ff 100644
--- a/modules/sd_hijack_ip2p.py
+++ b/modules/sd_hijack_ip2p.py
@@ -1,8 +1,5 @@
-import collections
import os.path
-import sys
-import gc
-import time
+
def should_hijack_ip2p(checkpoint_info):
from modules import sd_models_config
@@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
- return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
+ return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f10865cd..53e27ade 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,6 +1,5 @@
+from __future__ import annotations
import math
-import sys
-import traceback
import psutil
import torch
@@ -9,10 +8,129 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork
-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+
+class SdOptimization:
+ name: str = None
+ label: str | None = None
+ cmd_opt: str | None = None
+ priority: int = 0
+
+ def title(self):
+ if self.label is None:
+ return self.name
+
+ return f"{self.name} - {self.label}"
+
+ def is_available(self):
+ return True
+
+ def apply(self):
+ pass
+
+ def undo(self):
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+ name = "xformers"
+ cmd_opt = "xformers"
+ priority = 100
+
+ def is_available(self):
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+ name = "sdp-no-mem"
+ label = "scaled dot product without memory efficient attention"
+ cmd_opt = "opt_sdp_no_mem_attention"
+ priority = 80
+
+ def is_available(self):
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+ name = "sdp"
+ label = "scaled dot product"
+ cmd_opt = "opt_sdp_attention"
+ priority = 70
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+ name = "sub-quadratic"
+ cmd_opt = "opt_sub_quad_attention"
+ priority = 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+ name = "V1"
+ label = "original v1"
+ cmd_opt = "opt_split_attention_v1"
+ priority = 10
+
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+ name = "InvokeAI"
+ cmd_opt = "opt_split_attention_invokeai"
+
+ @property
+ def priority(self):
+ return 1000 if not torch.cuda.is_available() else 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+ name = "Doggettx"
+ cmd_opt = "opt_split_attention"
+ priority = 90
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+ res.extend([
+ SdOptimizationXformers(),
+ SdOptimizationSdpNoMem(),
+ SdOptimizationSdp(),
+ SdOptimizationSubQuad(),
+ SdOptimizationV1(),
+ SdOptimizationInvokeAI(),
+ SdOptimizationDoggettx(),
+ ])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -20,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
import xformers.ops
shared.xformers_available = True
except Exception:
- print("Cannot import xformers", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Cannot import xformers", exc_info=True)
def get_available_vram():
@@ -49,7 +166,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
v_in = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -62,10 +179,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
-
+
s2 = s1.softmax(dim=-1)
del s1
-
+
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
del q, k, v
@@ -95,43 +212,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale
-
+
del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
-
+
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
+
mem_free_total = get_available_vram()
-
+
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
-
+
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
+
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
+
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
+
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
-
+
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
-
+
del q, k, v
r1 = r1.to(dtype)
@@ -228,8 +345,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -296,11 +413,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
- query_chunk_size = q_tokens
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
+ return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,
@@ -335,7 +451,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -370,7 +486,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
-
+
del q_in, k_in, v_in
dtype = q.dtype
@@ -452,7 +568,7 @@ def cross_attention_attnblock_forward(self, x):
h3 += x
return h3
-
+
def xformers_attnblock_forward(self, x):
try:
h_ = x
@@ -461,7 +577,7 @@ def xformers_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -483,10 +599,10 @@ def sdp_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
- q, k = q.float(), k.float()
+ q, k, v = q.float(), k.float(), v.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
@@ -507,7 +623,7 @@ def sub_quad_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py
index 4ac51c38..28528329 100644
--- a/modules/sd_hijack_xlmr.py
+++ b/modules/sd_hijack_xlmr.py
@@ -1,8 +1,6 @@
-import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
-from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 36f643e1..f65f4e36 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,10 +14,10 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
+import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -87,8 +87,7 @@ class CheckpointInfo:
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging, CLIPModel
+ from transformers import logging, CLIPModel # noqa: F401
logging.set_verbosity_error()
except Exception:
@@ -96,10 +95,8 @@ except Exception:
def setup_model():
- if not os.path.exists(model_path):
- os.makedirs(model_path)
+ os.makedirs(model_path, exist_ok=True)
- list_models()
enable_midas_autodownload()
@@ -166,21 +163,22 @@ def model_hash(filename):
def select_checkpoint():
+ """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint
-
+
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
if len(checkpoints_list) == 0:
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ error_message = "No checkpoints found. When searching for checkpoints, looked at:"
if shared.cmd_opts.ckpt is not None:
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {model_path}", file=sys.stderr)
+ error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
+ error_message += f"\n - directory {model_path}"
if shared.cmd_opts.ckpt_dir is not None:
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
- exit(1)
+ error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
+ error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
+ raise FileNotFoundError(error_message)
checkpoint_info = next(iter(checkpoints_list.values()))
if model_checkpoint is not None:
@@ -239,7 +237,7 @@ def read_metadata_from_safetensors(filename):
if isinstance(v, str) and v[0:1] == '{':
try:
res[k] = json.loads(v)
- except Exception as e:
+ except Exception:
pass
return res
@@ -249,7 +247,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+
+ if not shared.opts.disable_mmap_load_safetensors:
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
+ else:
+ pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
+ pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -315,8 +318,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
timer.record("apply half()")
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
devices.dtype_unet = model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
@@ -374,7 +375,7 @@ def enable_midas_autodownload():
if not os.path.exists(path):
if not os.path.exists(midas_path):
mkdir(midas_path)
-
+
print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded")
@@ -410,15 +411,22 @@ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_w
class SdModelData:
def __init__(self):
self.sd_model = None
+ self.was_loaded_at_least_once = False
self.lock = threading.Lock()
def get_sd_model(self):
+ if self.was_loaded_at_least_once:
+ return self.sd_model
+
if self.sd_model is None:
with self.lock:
+ if self.sd_model is not None or self.was_loaded_at_least_once:
+ return self.sd_model
+
try:
load_model()
except Exception as e:
- errors.display(e, "loading stable diffusion model")
+ errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
@@ -467,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model)
- except Exception as e:
+ except Exception:
pass
if sd_model is None:
@@ -493,6 +501,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model.eval()
model_data.sd_model = sd_model
+ model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -502,6 +511,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
+ with devices.autocast(), torch.no_grad():
+ sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+
+ timer.record("calculate empty prompt")
+
print(f"Model loaded in {timer.summary()}.")
return sd_model
@@ -521,6 +535,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
+ sd_unet.apply_unet("None")
+
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@@ -538,13 +554,12 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
- checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
- except Exception as e:
+ except Exception:
print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise
@@ -565,7 +580,7 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
timer = Timer()
if model_data.sd_model:
@@ -580,3 +595,29 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
+
+
+def apply_token_merging(sd_model, token_merging_ratio):
+ """
+ Applies speed and memory optimizations from tomesd.
+ """
+
+ current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
+
+ if current_token_merging_ratio == token_merging_ratio:
+ return
+
+ if current_token_merging_ratio > 0:
+ tomesd.remove_patch(sd_model)
+
+ if token_merging_ratio > 0:
+ tomesd.apply_patch(
+ sd_model,
+ ratio=token_merging_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+
+ sd_model.applied_token_merged_ratio = token_merging_ratio
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 7a79925a..9bfe1237 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -1,4 +1,3 @@
-import re
import os
import torch
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index ff361f22..f22aad8f 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -1,7 +1,7 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
# imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
@@ -14,12 +14,18 @@ samplers_for_img2img = []
samplers_map = {}
-def create_sampler(name, model):
+def find_sampler_config(name):
if name is not None:
config = all_samplers_map.get(name, None)
else:
config = all_samplers[0]
+ return config
+
+
+def create_sampler(name, model):
+ config = find_sampler_config(name)
+
assert config is not None, f'bad sampler name: {name}'
sampler = config.constructor(model)
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index bc074238..763829f1 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx
+from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
from modules.shared import opts, state
import modules.shared as shared
@@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
-approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
+approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def single_sample_to_image(sample, approximation=None):
@@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
- x_sample = sd_vae_approx.cheap_approximation(sample)
+ x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
elif approximation == 1:
- x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
+ x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
+ elif approximation == 3:
+ x_sample = sample * 1.5
+ x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
- x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
+ x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
- x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
+
return Image.fromarray(x_sample)
@@ -58,6 +62,25 @@ def store_latent(decoded):
shared.state.assign_current_image(sample_to_image(decoded))
+def is_sampler_using_eta_noise_seed_delta(p):
+ """returns whether sampler from config will use eta noise seed delta for image creation"""
+
+ sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
+
+ eta = p.eta
+
+ if eta is None and p.sampler is not None:
+ eta = p.sampler.eta
+
+ if eta is None and sampler_config is not None:
+ eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
+
+ if eta == 0:
+ return False
+
+ return sampler_config.options.get("uses_ensd", False)
+
+
class InterruptedException(BaseException):
pass
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bfcc5574..bdae8b40 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc
samplers_data_compvis = [
- sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
]
@@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
- res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
@@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
- assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+ assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
@@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler:
self.update_step(x)
def initialize(self, p):
- self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ if self.is_ddim:
+ self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
+ else:
+ self.eta = 0.0
+
if self.eta != 0.0:
p.extra_generation_params["Eta DDIM"] = self.eta
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 0fc9f456..71581b76 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,7 +1,6 @@
from collections import deque
import torch
import inspect
-import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
@@ -9,25 +8,28 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
- ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
- ('Heun', 'sample_heun', ['k_heun'], {}),
+ ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
- ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
+ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
- ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
- ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
- ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
+ ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
+ ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
+ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
- ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
- ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
+ ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
+ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
- ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
+ ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
+ ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
]
samplers_data_k_diffusion = [
@@ -42,6 +44,14 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
+k_diffusion_scheduler = {
+ 'Automatic': None,
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -59,6 +69,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
self.image_cfg_scale = None
+ self.padded_cond_uncond = False
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -87,17 +98,17 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
- assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
- make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
- make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -123,6 +134,18 @@ class CFGDenoiser(torch.nn.Module):
x_in = x_in[:-batch_size]
sigma_in = sigma_in[:-batch_size]
+ self.padded_cond_uncond = False
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+ self.padded_cond_uncond = True
+
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond])
@@ -161,7 +184,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
- denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -181,6 +204,10 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
self.step += 1
return denoised
@@ -224,7 +251,7 @@ class KDiffusionSampler:
self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.config = None
+ self.config = None # set by the function calling the constructor
self.last_latent = None
self.s_min_uncond = None
@@ -249,6 +276,13 @@ class KDiffusionSampler:
try:
return func()
+ except RecursionError:
+ print(
+ 'Encountered RecursionError during sampling, returning last latent. '
+ 'rho >5 with a polyexponential scheduler may cause this error. '
+ 'You should try to use a smaller rho value instead.'
+ )
+ return self.last_latent
except sd_samplers_common.InterruptedException:
return self.last_latent
@@ -288,6 +322,31 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif opts.k_sched_type != "Automatic":
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
+ sigmas_kwargs = {
+ 'sigma_min': sigma_min,
+ 'sigma_max': sigma_max,
+ }
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["Schedule type"] = opts.k_sched_type
+
+ if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
+ sigmas_kwargs['sigma_min'] = opts.sigma_min
+ p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
+ if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max
+ p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
+
+ default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
+
+ if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
+ sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["Schedule rho"] = opts.rho
+
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
@@ -317,7 +376,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
-
+
extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters
@@ -333,22 +392,25 @@ class KDiffusionSampler:
if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigma_sched
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
self.model_wrap_cfg.init_latent = x
self.last_latent = x
- extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ extra_args = {
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
@@ -369,18 +431,21 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
- if self.funcname == 'sample_dpmpp_sde':
+ if self.config.options.get('brownian_noise', False):
noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
- 'cond': conditioning,
- 'image_cond': image_conditioning,
- 'uncond': unconditional_conditioning,
+ 'cond': conditioning,
+ 'image_cond': image_conditioning,
+ 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ if self.model_wrap_cfg.padded_cond_uncond:
+ p.extra_generation_params["Pad conds"] = True
+
return samples
diff --git a/modules/sd_unet.py b/modules/sd_unet.py
new file mode 100644
index 00000000..6d708ad2
--- /dev/null
+++ b/modules/sd_unet.py
@@ -0,0 +1,92 @@
+import torch.nn
+import ldm.modules.diffusionmodules.openaimodel
+
+from modules import script_callbacks, shared, devices
+
+unet_options = []
+current_unet_option = None
+current_unet = None
+
+
+def list_unets():
+ new_unets = script_callbacks.list_unets_callback()
+
+ unet_options.clear()
+ unet_options.extend(new_unets)
+
+
+def get_unet_option(option=None):
+ option = option or shared.opts.sd_unet
+
+ if option == "None":
+ return None
+
+ if option == "Automatic":
+ name = shared.sd_model.sd_checkpoint_info.model_name
+
+ options = [x for x in unet_options if x.model_name == name]
+
+ option = options[0].label if options else "None"
+
+ return next(iter([x for x in unet_options if x.label == option]), None)
+
+
+def apply_unet(option=None):
+ global current_unet_option
+ global current_unet
+
+ new_option = get_unet_option(option)
+ if new_option == current_unet_option:
+ return
+
+ if current_unet is not None:
+ print(f"Dectivating unet: {current_unet.option.label}")
+ current_unet.deactivate()
+
+ current_unet_option = new_option
+ if current_unet_option is None:
+ current_unet = None
+
+ if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
+ shared.sd_model.model.diffusion_model.to(devices.device)
+
+ return
+
+ shared.sd_model.model.diffusion_model.to(devices.cpu)
+ devices.torch_gc()
+
+ current_unet = current_unet_option.create_unet()
+ current_unet.option = current_unet_option
+ print(f"Activating unet: {current_unet.option.label}")
+ current_unet.activate()
+
+
+class SdUnetOption:
+ model_name = None
+ """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
+
+ label = None
+ """name of the unet in UI"""
+
+ def create_unet(self):
+ """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
+ raise NotImplementedError()
+
+
+class SdUnet(torch.nn.Module):
+ def forward(self, x, timesteps, context, *args, **kwargs):
+ raise NotImplementedError()
+
+ def activate(self):
+ pass
+
+ def deactivate(self):
+ pass
+
+
+def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
+ if current_unet is not None:
+ return current_unet.forward(x, timesteps, context, *args, **kwargs)
+
+ return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
+
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 521e485a..e4ff2994 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,8 +1,5 @@
-import torch
-import safetensors.torch
import os
import collections
-from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models
import glob
from copy import deepcopy
@@ -88,10 +85,10 @@ def refresh_vae_list():
def find_vae_near_checkpoint(checkpoint_file):
- checkpoint_path = os.path.splitext(checkpoint_file)[0]
- for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
- if os.path.isfile(vae_location):
- return vae_location
+ checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
+ for vae_file in vae_dict.values():
+ if os.path.basename(vae_file).startswith(checkpoint_path):
+ return vae_file
return None
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
new file mode 100644
index 00000000..5e8496e8
--- /dev/null
+++ b/modules/sd_vae_taesd.py
@@ -0,0 +1,88 @@
+"""
+Tiny AutoEncoder for Stable Diffusion
+(DNN for encoding / decoding SD's latent space)
+
+https://github.com/madebyollin/taesd
+"""
+import os
+import torch
+import torch.nn as nn
+
+from modules import devices, paths_internal
+
+sd_vae_taesd = None
+
+
+def conv(n_in, n_out, **kwargs):
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+
+class Clamp(nn.Module):
+ @staticmethod
+ def forward(x):
+ return torch.tanh(x / 3) * 3
+
+
+class Block(nn.Module):
+ def __init__(self, n_in, n_out):
+ super().__init__()
+ self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.fuse = nn.ReLU()
+
+ def forward(self, x):
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+def decoder():
+ return nn.Sequential(
+ Clamp(), conv(4, 64), nn.ReLU(),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), conv(64, 3),
+ )
+
+
+class TAESD(nn.Module):
+ latent_magnitude = 3
+ latent_shift = 0.5
+
+ def __init__(self, decoder_path="taesd_decoder.pth"):
+ """Initialize pretrained TAESD on the given device from the given checkpoints."""
+ super().__init__()
+ self.decoder = decoder()
+ self.decoder.load_state_dict(
+ torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+
+ @staticmethod
+ def unscale_latents(x):
+ """[0, 1] -> raw latents"""
+ return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
+
+
+def download_model(model_path):
+ model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+ print(f'Downloading TAESD decoder to: {model_path}')
+ torch.hub.download_url_to_file(model_url, model_path)
+
+
+def model():
+ global sd_vae_taesd
+
+ if sd_vae_taesd is None:
+ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
+ download_model(model_path)
+
+ if os.path.exists(model_path):
+ sd_vae_taesd = TAESD(model_path)
+ sd_vae_taesd.eval()
+ sd_vae_taesd.to(devices.device, devices.dtype)
+ else:
+ raise FileNotFoundError('TAESD model not found')
+
+ return sd_vae_taesd.decoder
diff --git a/modules/shared.py b/modules/shared.py
index b3508883..48478a68 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -1,13 +1,14 @@
-import argparse
import datetime
import json
import os
+import re
import sys
+import threading
import time
-import requests
+import logging
-from PIL import Image
import gradio as gr
+import torch
import tqdm
import modules.interrogate
@@ -15,8 +16,11 @@ import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
+from typing import Optional
+
+log = logging.getLogger(__name__)
demo = None
@@ -44,19 +48,6 @@ restricted_opts = {
"outdir_init_images"
}
-ui_reorder_categories = [
- "inpaint",
- "sampler",
- "checkboxes",
- "hires_fix",
- "dimensions",
- "cfg",
- "seed",
- "batch",
- "override_settings",
- "scripts",
-]
-
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/glass",
@@ -77,6 +68,9 @@ cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_op
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
+devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
+devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+
device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
@@ -113,14 +107,56 @@ class State:
id_live_preview = 0
textinfo = None
time_start = None
- need_restart = False
server_start = None
+ _server_command_signal = threading.Event()
+ _server_command: Optional[str] = None
+
+ @property
+ def need_restart(self) -> bool:
+ # Compatibility getter for need_restart.
+ return self.server_command == "restart"
+
+ @need_restart.setter
+ def need_restart(self, value: bool) -> None:
+ # Compatibility setter for need_restart.
+ if value:
+ self.server_command = "restart"
+
+ @property
+ def server_command(self):
+ return self._server_command
+
+ @server_command.setter
+ def server_command(self, value: Optional[str]) -> None:
+ """
+ Set the server command to `value` and signal that it's been set.
+ """
+ self._server_command = value
+ self._server_command_signal.set()
+
+ def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
+ """
+ Wait for server command to get set; return and clear the value and signal.
+ """
+ if self._server_command_signal.wait(timeout):
+ self._server_command_signal.clear()
+ req = self._server_command
+ self._server_command = None
+ return req
+ return None
+
+ def request_restart(self) -> None:
+ self.interrupt()
+ self.server_command = "restart"
+ log.info("Received restart request")
def skip(self):
self.skipped = True
+ log.info("Received skip request")
def interrupt(self):
self.interrupted = True
+ log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -144,7 +180,7 @@ class State:
return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -158,10 +194,13 @@ class State:
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
+ log.info("Starting job %s", job)
def end(self):
+ duration = time.time() - self.time_start
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
@@ -202,8 +241,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
+
class OptionInfo:
- def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
+ def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
self.default = default
self.label = label
self.component = component
@@ -212,9 +252,37 @@ class OptionInfo:
self.section = section
self.refresh = refresh
+ self.comment_before = comment_before
+ """HTML text that will be added after label in UI"""
+
+ self.comment_after = comment_after
+ """HTML text that will be added before label in UI"""
+
+ def link(self, label, url):
+ self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
+ return self
+
+ def js(self, label, js_func):
+ self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
+ return self
+
+ def info(self, info):
+ self.comment_after += f"<span class='info'>({info})</span>"
+ return self
+
+ def html(self, html):
+ self.comment_after += html
+ return self
+
+ def needs_restart(self):
+ self.comment_after += " <span class='info'>(requires restart)</span>"
+ return self
+
+
+
def options_section(section_identifier, options_dict):
- for k, v in options_dict.items():
+ for v in options_dict.values():
v.section = section_identifier
return options_dict
@@ -243,7 +311,7 @@ options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'),
- "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
+ "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"),
@@ -251,7 +319,12 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
+ "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
+ "font": OptionInfo("", "Font for image grids that have text"),
+ "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
+ "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@@ -262,10 +335,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
- "export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
+ "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
- "img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
+ "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
@@ -293,31 +366,31 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
- "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
+ "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
- "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+ "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+ "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
- "SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
- "SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
- "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console."),
- "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
+ "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
+ "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
+ "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -339,20 +412,31 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
- "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
+ "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
- "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+ "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
- "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
- "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
+ "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
+ "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
- "randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
+ "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+}))
+
+options_templates.update(options_section(('optimizations', "Optimizations"), {
+ "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
+ "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+ "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
+ "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
+ "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+ "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -361,89 +445,109 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
+ "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
- "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
- "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
- "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
- "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
- "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
- "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+ "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
+ "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
+ "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
+ "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
+ "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+ "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
- "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
- "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
- "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
- "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
- "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
+ "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
+ "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
+ "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
+ "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
+ "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
}))
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
+ "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
+ "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
- "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
- "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
- "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+ "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
+ "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
+ "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
+ "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
+ "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
+ "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
+ "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
+ "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
- "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
- "js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
+ "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
- "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
- "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
+ "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
+ "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
- "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
- "hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
- "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
- "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
- "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
- "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
+ "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
+ "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
+ "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
+ "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
+ "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
+ "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
+ "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
+ "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
}))
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
+ "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
- "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
+ "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
+ "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
+<li>Ignore: keep prompt and styles dropdown as it is.</li>
+<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
+<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
+<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
+</ul>"""),
+
}))
options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+ "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
- "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
- "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
+ "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
+ "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
- "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
+ "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
- "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
- "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+ "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
+ "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
+ "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
- 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
- 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
+ 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
+ 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
+ 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
+ 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
+ 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
+ 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
- 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
+ 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
}))
@@ -460,6 +564,7 @@ options_templates.update(options_section((None, "Hidden options"), {
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
+
options_templates.update()
@@ -553,6 +658,10 @@ class Options:
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+ # 1.4.0 ui_reorder
+ if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
+ self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
+
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
@@ -571,7 +680,9 @@ class Options:
func()
def dumpjson(self):
- d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
+ d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
+ d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
+ d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
@@ -582,11 +693,11 @@ class Options:
section_ids = {}
settings_items = self.data_labels.items()
- for k, item in settings_items:
+ for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
- self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+ self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key
@@ -722,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
+def natural_sort_key(s, regex=re.compile('([0-9]+)')):
+ return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
+
+
def listfiles(dirname):
- filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
+ filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
@@ -748,11 +863,17 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
- for root, dirs, files in os.walk(path):
- for filename in files:
+ items = list(os.walk(path, followlinks=True))
+ items = sorted(items, key=lambda x: natural_sort_key(x[0]))
+
+ for root, _, files in items:
+ for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
+ if not opts.list_hidden_files and ("/." in root or "\\." in root):
+ continue
+
yield os.path.join(root, filename)
diff --git a/modules/shared_items.py b/modules/shared_items.py
index e792a134..89792e88 100644
--- a/modules/shared_items.py
+++ b/modules/shared_items.py
@@ -21,3 +21,49 @@ def refresh_vae_list():
import modules.sd_vae
modules.sd_vae.refresh_vae_list()
+
+
+def cross_attention_optimizations():
+ import modules.sd_hijack
+
+ return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
+
+
+def sd_unet_items():
+ import modules.sd_unet
+
+ return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
+
+
+def refresh_unet_list():
+ import modules.sd_unet
+
+ modules.sd_unet.list_unets()
+
+
+ui_reorder_categories_builtin_items = [
+ "inpaint",
+ "sampler",
+ "checkboxes",
+ "hires_fix",
+ "dimensions",
+ "cfg",
+ "seed",
+ "batch",
+ "override_settings",
+]
+
+
+def ui_reorder_categories():
+ from modules import scripts
+
+ yield from ui_reorder_categories_builtin_items
+
+ sections = {}
+ for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
+ if isinstance(script.section, str):
+ sections[script.section] = 1
+
+ yield from sections
+
+ yield "scripts"
diff --git a/modules/styles.py b/modules/styles.py
index 11642075..ec0e1bc5 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -1,18 +1,10 @@
-# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
-from __future__ import annotations
-
import csv
import os
import os.path
+import re
import typing
-import collections.abc as abc
-import tempfile
import shutil
-if typing.TYPE_CHECKING:
- # Only import this when code is being type-checked, it doesn't have any effect at runtime
- from .processing import StableDiffusionProcessing
-
class PromptStyle(typing.NamedTuple):
name: str
@@ -37,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles):
return prompt
+re_spaces = re.compile(" +")
+
+
+def extract_style_text_from_prompt(style_text, prompt):
+ stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
+ stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
+ if "{prompt}" in stripped_style_text:
+ left, right = stripped_style_text.split("{prompt}", 2)
+ if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
+ prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
+ return True, prompt
+ else:
+ if stripped_prompt.endswith(stripped_style_text):
+ prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+
+ if prompt.endswith(', '):
+ prompt = prompt[:-2]
+
+ return True, prompt
+
+ return False, prompt
+
+
+def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
+ if not style.prompt and not style.negative_prompt:
+ return False, prompt, negative_prompt
+
+ match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
+ if not match_positive:
+ return False, prompt, negative_prompt
+
+ match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
+ if not match_negative:
+ return False, prompt, negative_prompt
+
+ return True, extracted_positive, extracted_negative
+
+
class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "")
@@ -52,7 +82,7 @@ class StyleDatabase:
return
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
- reader = csv.DictReader(file)
+ reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
# Support loading old CSV format with "name, text"-columns
prompt = row["prompt"] if "prompt" in row else row["text"]
@@ -76,10 +106,34 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")
- fd = os.open(path, os.O_RDWR|os.O_CREAT)
+ fd = os.open(path, os.O_RDWR | os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
- writer.writerows(style._asdict() for k, style in self.styles.items())
+ writer.writerows(style._asdict() for k, style in self.styles.items())
+
+ def extract_styles_from_prompt(self, prompt, negative_prompt):
+ extracted = []
+
+ applicable_styles = list(self.styles.values())
+
+ while True:
+ found_style = None
+
+ for style in applicable_styles:
+ is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
+ if is_match:
+ found_style = style
+ prompt = new_prompt
+ negative_prompt = new_neg_prompt
+ break
+
+ if not found_style:
+ break
+
+ applicable_styles.remove(found_style)
+ extracted.append(found_style.name)
+
+ return list(reversed(extracted)), prompt, negative_prompt
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 05595323..497568eb 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
chunk_idx,
min(query_chunk_size, q_tokens)
)
-
+
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -201,14 +201,15 @@ def efficient_dot_product_attention(
key=key,
value=value,
)
-
- # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
- # and pass slices to be mutated, instead of torch.cat()ing the returned slices
- res = torch.cat([
- compute_query_chunk_attn(
+
+ res = torch.zeros_like(query)
+ for i in range(math.ceil(q_tokens / query_chunk_size)):
+ attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
- ) for i in range(math.ceil(q_tokens / query_chunk_size))
- ], dim=1)
+ )
+
+ res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
+
return res
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
new file mode 100644
index 00000000..5f15ac4f
--- /dev/null
+++ b/modules/sysinfo.py
@@ -0,0 +1,162 @@
+import json
+import os
+import sys
+import traceback
+
+import platform
+import hashlib
+import pkg_resources
+import psutil
+import re
+
+import launch
+from modules import paths_internal, timer
+
+checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
+environment_whitelist = {
+ "GIT",
+ "INDEX_URL",
+ "WEBUI_LAUNCH_LIVE_OUTPUT",
+ "GRADIO_ANALYTICS_ENABLED",
+ "PYTHONPATH",
+ "TORCH_INDEX_URL",
+ "TORCH_COMMAND",
+ "REQS_FILE",
+ "XFORMERS_PACKAGE",
+ "GFPGAN_PACKAGE",
+ "CLIP_PACKAGE",
+ "OPENCLIP_PACKAGE",
+ "STABLE_DIFFUSION_REPO",
+ "K_DIFFUSION_REPO",
+ "CODEFORMER_REPO",
+ "BLIP_REPO",
+ "STABLE_DIFFUSION_COMMIT_HASH",
+ "K_DIFFUSION_COMMIT_HASH",
+ "CODEFORMER_COMMIT_HASH",
+ "BLIP_COMMIT_HASH",
+ "COMMANDLINE_ARGS",
+ "IGNORE_CMD_ARGS_ERRORS",
+}
+
+
+def pretty_bytes(num, suffix="B"):
+ for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]:
+ if abs(num) < 1024 or unit == 'Y':
+ return f"{num:.0f}{unit}{suffix}"
+ num /= 1024
+
+
+def get():
+ res = get_dict()
+
+ text = json.dumps(res, ensure_ascii=False, indent=4)
+
+ h = hashlib.sha256(text.encode("utf8"))
+ text = text.replace(checksum_token, h.hexdigest())
+
+ return text
+
+
+re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"')
+
+
+def check(x):
+ m = re.search(re_checksum, x)
+ if not m:
+ return False
+
+ replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x)
+
+ h = hashlib.sha256(replaced.encode("utf8"))
+ return h.hexdigest() == m.group(1)
+
+
+def get_dict():
+ ram = psutil.virtual_memory()
+
+ res = {
+ "Platform": platform.platform(),
+ "Python": platform.python_version(),
+ "Version": launch.git_tag(),
+ "Commit": launch.commit_hash(),
+ "Script path": paths_internal.script_path,
+ "Data path": paths_internal.data_path,
+ "Extensions dir": paths_internal.extensions_dir,
+ "Checksum": checksum_token,
+ "Commandline": sys.argv,
+ "Torch env info": get_torch_sysinfo(),
+ "Exceptions": get_exceptions(),
+ "CPU": {
+ "model": platform.processor(),
+ "count logical": psutil.cpu_count(logical=True),
+ "count physical": psutil.cpu_count(logical=False),
+ },
+ "RAM": {
+ x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0
+ },
+ "Extensions": get_extensions(enabled=True),
+ "Inactive extensions": get_extensions(enabled=False),
+ "Environment": get_environment(),
+ "Config": get_config(),
+ "Startup": timer.startup_record,
+ "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]),
+ }
+
+ return res
+
+
+def format_traceback(tb):
+ return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+
+
+def get_exceptions():
+ try:
+ from modules import errors
+
+ return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
+ except Exception as e:
+ return str(e)
+
+
+def get_environment():
+ return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
+
+
+re_newline = re.compile(r"\r*\n")
+
+
+def get_torch_sysinfo():
+ try:
+ import torch.utils.collect_env
+ info = torch.utils.collect_env.get_env_info()._asdict()
+
+ return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()}
+ except Exception as e:
+ return str(e)
+
+
+def get_extensions(*, enabled):
+
+ try:
+ from modules import extensions
+
+ def to_json(x: extensions.Extension):
+ return {
+ "name": x.name,
+ "path": x.path,
+ "version": x.version,
+ "branch": x.branch,
+ "remote": x.remote,
+ }
+
+ return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]
+ except Exception as e:
+ return str(e)
+
+
+def get_config():
+ try:
+ from modules import shared
+ return shared.opts.data
+ except Exception as e:
+ return str(e)
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py
index ba1bdcd4..1675e39a 100644
--- a/modules/textual_inversion/autocrop.py
+++ b/modules/textual_inversion/autocrop.py
@@ -1,10 +1,8 @@
import cv2
import requests
import os
-from collections import defaultdict
-from math import log, sqrt
import numpy as np
-from PIL import Image, ImageDraw
+from PIL import ImageDraw
GREEN = "#0F0"
BLUE = "#00F"
@@ -12,63 +10,64 @@ RED = "#F00"
def crop_image(im, settings):
- """ Intelligently crop an image to the subject matter """
-
- scale_by = 1
- if is_landscape(im.width, im.height):
- scale_by = settings.crop_height / im.height
- elif is_portrait(im.width, im.height):
- scale_by = settings.crop_width / im.width
- elif is_square(im.width, im.height):
- if is_square(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_landscape(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_width / im.width
- elif is_portrait(settings.crop_width, settings.crop_height):
- scale_by = settings.crop_height / im.height
-
- im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
- im_debug = im.copy()
-
- focus = focal_point(im_debug, settings)
-
- # take the focal point and turn it into crop coordinates that try to center over the focal
- # point but then get adjusted back into the frame
- y_half = int(settings.crop_height / 2)
- x_half = int(settings.crop_width / 2)
-
- x1 = focus.x - x_half
- if x1 < 0:
- x1 = 0
- elif x1 + settings.crop_width > im.width:
- x1 = im.width - settings.crop_width
-
- y1 = focus.y - y_half
- if y1 < 0:
- y1 = 0
- elif y1 + settings.crop_height > im.height:
- y1 = im.height - settings.crop_height
-
- x2 = x1 + settings.crop_width
- y2 = y1 + settings.crop_height
-
- crop = [x1, y1, x2, y2]
-
- results = []
-
- results.append(im.crop(tuple(crop)))
-
- if settings.annotate_image:
- d = ImageDraw.Draw(im_debug)
- rect = list(crop)
- rect[2] -= 1
- rect[3] -= 1
- d.rectangle(rect, outline=GREEN)
- results.append(im_debug)
- if settings.destop_view_image:
- im_debug.show()
-
- return results
+ """ Intelligently crop an image to the subject matter """
+
+ scale_by = 1
+ if is_landscape(im.width, im.height):
+ scale_by = settings.crop_height / im.height
+ elif is_portrait(im.width, im.height):
+ scale_by = settings.crop_width / im.width
+ elif is_square(im.width, im.height):
+ if is_square(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_landscape(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_width / im.width
+ elif is_portrait(settings.crop_width, settings.crop_height):
+ scale_by = settings.crop_height / im.height
+
+
+ im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+ im_debug = im.copy()
+
+ focus = focal_point(im_debug, settings)
+
+ # take the focal point and turn it into crop coordinates that try to center over the focal
+ # point but then get adjusted back into the frame
+ y_half = int(settings.crop_height / 2)
+ x_half = int(settings.crop_width / 2)
+
+ x1 = focus.x - x_half
+ if x1 < 0:
+ x1 = 0
+ elif x1 + settings.crop_width > im.width:
+ x1 = im.width - settings.crop_width
+
+ y1 = focus.y - y_half
+ if y1 < 0:
+ y1 = 0
+ elif y1 + settings.crop_height > im.height:
+ y1 = im.height - settings.crop_height
+
+ x2 = x1 + settings.crop_width
+ y2 = y1 + settings.crop_height
+
+ crop = [x1, y1, x2, y2]
+
+ results = []
+
+ results.append(im.crop(tuple(crop)))
+
+ if settings.annotate_image:
+ d = ImageDraw.Draw(im_debug)
+ rect = list(crop)
+ rect[2] -= 1
+ rect[3] -= 1
+ d.rectangle(rect, outline=GREEN)
+ results.append(im_debug)
+ if settings.destop_view_image:
+ im_debug.show()
+
+ return results
def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@@ -78,29 +77,29 @@ def focal_point(im, settings):
pois = []
weight_pref_total = 0
- if len(corner_points) > 0:
+ if corner_points:
weight_pref_total += settings.corner_points_weight
- if len(entropy_points) > 0:
+ if entropy_points:
weight_pref_total += settings.entropy_points_weight
- if len(face_points) > 0:
+ if face_points:
weight_pref_total += settings.face_points_weight
corner_centroid = None
- if len(corner_points) > 0:
+ if corner_points:
corner_centroid = centroid(corner_points)
- corner_centroid.weight = settings.corner_points_weight / weight_pref_total
+ corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)
entropy_centroid = None
- if len(entropy_points) > 0:
+ if entropy_points:
entropy_centroid = centroid(entropy_points)
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
pois.append(entropy_centroid)
face_centroid = None
- if len(face_points) > 0:
+ if face_points:
face_centroid = centroid(face_points)
- face_centroid.weight = settings.face_points_weight / weight_pref_total
+ face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)
average_point = poi_average(pois, settings)
@@ -134,7 +133,7 @@ def focal_point(im, settings):
d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN)
-
+
return average_point
@@ -185,10 +184,10 @@ def image_face_points(im, settings):
try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
- except:
+ except Exception:
continue
- if len(faces) > 0:
+ if faces:
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
return []
@@ -262,10 +261,11 @@ def image_entropy(im):
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()
+
def centroid(pois):
- x = [poi.x for poi in pois]
- y = [poi.y for poi in pois]
- return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+ x = [poi.x for poi in pois]
+ y = [poi.y for poi in pois]
+ return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
def poi_average(pois, settings):
@@ -283,59 +283,58 @@ def poi_average(pois, settings):
def is_landscape(w, h):
- return w > h
+ return w > h
def is_portrait(w, h):
- return h > w
+ return h > w
def is_square(w, h):
- return w == h
+ return w == h
def download_and_cache_models(dirname):
- download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
- model_file_name = 'face_detection_yunet.onnx'
+ download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
+ model_file_name = 'face_detection_yunet.onnx'
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ os.makedirs(dirname, exist_ok=True)
- cache_file = os.path.join(dirname, model_file_name)
- if not os.path.exists(cache_file):
- print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
- response = requests.get(download_url)
- with open(cache_file, "wb") as f:
- f.write(response.content)
+ cache_file = os.path.join(dirname, model_file_name)
+ if not os.path.exists(cache_file):
+ print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
+ response = requests.get(download_url)
+ with open(cache_file, "wb") as f:
+ f.write(response.content)
- if os.path.exists(cache_file):
- return cache_file
- return None
+ if os.path.exists(cache_file):
+ return cache_file
+ return None
class PointOfInterest:
- def __init__(self, x, y, weight=1.0, size=10):
- self.x = x
- self.y = y
- self.weight = weight
- self.size = size
+ def __init__(self, x, y, weight=1.0, size=10):
+ self.x = x
+ self.y = y
+ self.weight = weight
+ self.size = size
- def bounding(self, size):
- return [
- self.x - size//2,
- self.y - size//2,
- self.x + size//2,
- self.y + size//2
- ]
+ def bounding(self, size):
+ return [
+ self.x - size // 2,
+ self.y - size // 2,
+ self.x + size // 2,
+ self.y + size // 2
+ ]
class Settings:
- def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
- self.crop_width = crop_width
- self.crop_height = crop_height
- self.corner_points_weight = corner_points_weight
- self.entropy_points_weight = entropy_points_weight
- self.face_points_weight = face_points_weight
- self.annotate_image = annotate_image
- self.destop_view_image = False
- self.dnn_model_path = dnn_model_path
+ def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
+ self.crop_width = crop_width
+ self.crop_height = crop_height
+ self.corner_points_weight = corner_points_weight
+ self.entropy_points_weight = entropy_points_weight
+ self.face_points_weight = face_points_weight
+ self.annotate_image = annotate_image
+ self.destop_view_image = False
+ self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 41610e03..7ee05061 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -32,7 +32,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):
- re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
+ re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None
self.placeholder_token = placeholder_token
@@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
weight = torch.ones(latent_sample.shape)
else:
weight = None
-
+
if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
@@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
return self
def collate_wrapper_random(batch):
- return BatchLoaderRandom(batch) \ No newline at end of file
+ return BatchLoaderRandom(batch)
diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py
index 5593f88c..81cff7bf 100644
--- a/modules/textual_inversion/image_embedding.py
+++ b/modules/textual_inversion/image_embedding.py
@@ -1,11 +1,11 @@
import base64
import json
+import warnings
+
import numpy as np
import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
-from fonts.ttf import Roboto
+from PIL import Image, ImageDraw
import torch
-from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
@@ -17,7 +17,7 @@ class EmbeddingEncoder(json.JSONEncoder):
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
- json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+ json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
def object_hook(self, d):
if 'TORCHTENSOR' in d:
@@ -131,17 +131,17 @@ def extract_image_data_embed(image):
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
+ from modules.images import get_font
+ if textfont:
+ warnings.warn(
+ 'passing in a textfont to caption_image_overlay is deprecated and does nothing',
+ DeprecationWarning,
+ stacklevel=2,
+ )
from math import cos
image = srcimage.copy()
fontsize = 32
- if textfont is None:
- try:
- textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
- textfont = opts.font or Roboto
- except Exception:
- textfont = Roboto
-
factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
for y in range(image.size[1]):
@@ -152,12 +152,12 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
draw = ImageDraw.Draw(image)
- font = ImageFont.truetype(textfont, fontsize)
+ font = get_font(fontsize)
padding = 10
_, _, w, h = draw.textbbox((0, 0), title, font=font)
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
- font = ImageFont.truetype(textfont, fontsize)
+ font = get_font(fontsize)
_, _, w, h = draw.textbbox((0, 0), title, font=font)
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
@@ -168,7 +168,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
- font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
+ font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index f63fc72f..c56bea45 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -12,7 +12,7 @@ class LearnScheduleIterator:
self.it = 0
self.maxit = 0
try:
- for i, pair in enumerate(pairs):
+ for pair in pairs:
if not pair.strip():
continue
tmp = pair.split(':')
@@ -32,8 +32,8 @@ class LearnScheduleIterator:
self.maxit += 1
return
assert self.rates
- except (ValueError, AssertionError):
- raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+ except (ValueError, AssertionError) as e:
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e
def __iter__(self):
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
index 734a4b6f..45823eb1 100644
--- a/modules/textual_inversion/logging.py
+++ b/modules/textual_inversion/logging.py
@@ -2,11 +2,51 @@ import datetime
import json
import os
-saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"}
-saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
-saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_shared = {
+ "batch_size",
+ "clip_grad_mode",
+ "clip_grad_value",
+ "create_image_every",
+ "data_root",
+ "gradient_step",
+ "initial_step",
+ "latent_sampling_method",
+ "learn_rate",
+ "log_directory",
+ "model_hash",
+ "model_name",
+ "num_of_dataset_images",
+ "steps",
+ "template_file",
+ "training_height",
+ "training_width",
+}
+saved_params_ti = {
+ "embedding_name",
+ "num_vectors_per_token",
+ "save_embedding_every",
+ "save_image_with_stored_embedding",
+}
+saved_params_hypernet = {
+ "activation_func",
+ "add_layer_norm",
+ "hypernetwork_name",
+ "layer_structure",
+ "save_hypernetwork_every",
+ "use_dropout",
+ "weight_init",
+}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
-saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+saved_params_previews = {
+ "preview_cfg_scale",
+ "preview_height",
+ "preview_negative_prompt",
+ "preview_prompt",
+ "preview_sampler_index",
+ "preview_seed",
+ "preview_steps",
+ "preview_width",
+}
def save_settings_to_file(log_directory, all_params):
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index da0bcb26..dbd856bd 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -1,17 +1,13 @@
import os
from PIL import Image, ImageOps
import math
-import platform
-import sys
import tqdm
-import time
from modules import paths, shared, images, deepbooru
-from modules.shared import opts, cmd_opts
from modules.textual_inversion import autocrop
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try:
if process_caption:
shared.interrogator.load()
@@ -51,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption += shared.interrogator.generate_caption(image)
if params.process_caption_deepbooru:
- if len(caption) > 0:
+ if caption:
caption += ", "
caption += deepbooru.model.tag_multi(image)
@@ -71,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
caption = caption.strip()
- if len(caption) > 0:
+ if caption:
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
@@ -129,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
default=None
)
return wh and center_crop(image, *wh)
-
+
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4368eb63..bb6f211c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,7 +1,4 @@
import os
-import sys
-import traceback
-import inspect
from collections import namedtuple
import torch
@@ -15,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -30,7 +27,7 @@ textual_inversion_templates = {}
def list_textual_inversion_templates():
textual_inversion_templates.clear()
- for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
for fn in fns:
path = os.path.join(root, fn)
@@ -121,16 +118,29 @@ class EmbeddingDatabase:
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ return self.register_embedding_by_name(embedding, model, embedding.name)
+ def register_embedding_by_name(self, embedding, model, name):
+ ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+ if name in self.word_embeddings:
+ # remove old one from the lookup list
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+ else:
+ lookup = self.ids_lookup[first_id]
+ if embedding is not None:
+ lookup += [(ids, embedding)]
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+ if embedding is None:
+ # unregister embedding with specified name
+ if name in self.word_embeddings:
+ del self.word_embeddings[name]
+ if len(self.ids_lookup[first_id])==0:
+ del self.ids_lookup[first_id]
+ return None
+ self.word_embeddings[name] = embedding
return embedding
def get_expected_shape(self):
@@ -167,8 +177,7 @@ class EmbeddingDatabase:
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
@@ -199,7 +208,7 @@ class EmbeddingDatabase:
if not os.path.isdir(embdir.path):
return
- for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+ for root, _, fns in os.walk(embdir.path, followlinks=True):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
@@ -209,14 +218,13 @@ class EmbeddingDatabase:
self.load_from_file(fullfn, fn)
except Exception:
- print(f"Error loading embedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error loading embedding {fn}", exc_info=True)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
if embdir.has_changed():
need_reload = True
break
@@ -229,7 +237,7 @@ class EmbeddingDatabase:
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
- for path, embdir in self.embedding_dirs.items():
+ for embdir in self.embedding_dirs.values():
self.load_from_dir(embdir)
embdir.update()
@@ -243,7 +251,7 @@ class EmbeddingDatabase:
if self.previously_displayed_embeddings != displayed_embeddings:
self.previously_displayed_embeddings = displayed_embeddings
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
- if len(self.skipped_embeddings) > 0:
+ if self.skipped_embeddings:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
@@ -325,16 +333,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
- tensorboard_writer.add_scalar(tag=tag,
+ tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
- img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
+ img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))
-
+
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@@ -404,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
-
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@@ -414,7 +422,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
-
+
if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)
@@ -441,7 +449,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
-
+
if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
@@ -470,7 +478,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
sd_hijack_checkpoint.add()
- for i in range((steps-initial_step) * gradient_step):
+ for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
@@ -487,7 +495,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if clip_grad:
clip_grad_sched.step(embedding.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -515,7 +523,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
-
+
if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
@@ -603,7 +611,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
+ except Exception:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint()
@@ -634,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
- print(traceback.format_exc(), file=sys.stderr)
- pass
+ errors.report("Error training embedding", exc_info=True)
finally:
pbar.leave = False
pbar.close()
diff --git a/modules/timer.py b/modules/timer.py
index ba92be33..da99e49f 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -1,11 +1,30 @@
import time
+class TimerSubcategory:
+ def __init__(self, timer, category):
+ self.timer = timer
+ self.category = category
+ self.start = None
+ self.original_base_category = timer.base_category
+
+ def __enter__(self):
+ self.start = time.time()
+ self.timer.base_category = self.original_base_category + self.category + "/"
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ elapsed_for_subcategroy = time.time() - self.start
+ self.timer.base_category = self.original_base_category
+ self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
+ self.timer.record(self.category)
+
+
class Timer:
def __init__(self):
self.start = time.time()
self.records = {}
self.total = 0
+ self.base_category = ''
def elapsed(self):
end = time.time()
@@ -13,18 +32,29 @@ class Timer:
self.start = end
return res
- def record(self, category, extra_time=0):
- e = self.elapsed()
+ def add_time_to_record(self, category, amount):
if category not in self.records:
self.records[category] = 0
- self.records[category] += e + extra_time
+ self.records[category] += amount
+
+ def record(self, category, extra_time=0):
+ e = self.elapsed()
+
+ self.add_time_to_record(self.base_category + category, e + extra_time)
+
self.total += e + extra_time
+ def subcategory(self, name):
+ self.elapsed()
+
+ subcat = TimerSubcategory(self, name)
+ return subcat
+
def summary(self):
res = f"{self.total:.1f}s"
- additions = [x for x in self.records.items() if x[1] >= 0.1]
+ additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
if not additions:
return res
@@ -34,5 +64,13 @@ class Timer:
return res
+ def dump(self):
+ return {'total': self.total, 'records': self.records}
+
def reset(self):
self.__init__()
+
+
+startup_timer = Timer()
+
+startup_record = None
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 16841d0f..6aa79f23 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -1,18 +1,16 @@
import modules.scripts
-from modules import sd_samplers
+from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
- StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
-import modules.processing as processing
from modules.ui import plaintext_to_html
+import gradio as gr
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
- p = StableDiffusionProcessingTxt2Img(
+ p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
@@ -41,19 +39,24 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_second_pass_steps=hr_second_pass_steps,
hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y,
+ hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
+ hr_prompt=hr_prompt,
+ hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
+ p.user = request.username
+
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
- processed = process_images(p)
+ processed = processing.process_images(p)
p.close()
diff --git a/modules/ui.py b/modules/ui.py
index f07bcc41..39d226ad 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1,29 +1,25 @@
-import html
+import datetime
import json
-import math
import mimetypes
import os
-import platform
-import random
import sys
-import tempfile
-import time
-import traceback
-from functools import partial, reduce
+from functools import reduce
import warnings
import gradio as gr
-import gradio.routes
import gradio.utils
import numpy as np
-from PIL import Image, PngImagePlugin
+from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
-from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
-from modules.paths import script_path, data_path
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
+from modules.paths import script_path
+from modules.ui_common import create_refresh_button
+from modules.ui_gradio_extensions import reload_javascript
-from modules.shared import opts, cmd_opts, restricted_opts
+
+from modules.shared import opts, cmd_opts
import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste
@@ -34,7 +30,6 @@ import modules.shared as shared
import modules.styles
import modules.textual_inversion.ui
from modules import prompt_parser
-from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.textual_inversion import textual_inversion
@@ -42,6 +37,8 @@ import modules.hypernetworks.ui
from modules.generation_parameters_copypaste import image_from_url_text
import modules.extras
+create_setting_component = ui_settings.create_setting_component
+
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
@@ -59,7 +56,7 @@ if cmd_opts.ngrok is not None:
ngrok.connect(
cmd_opts.ngrok,
cmd_opts.port if cmd_opts.port is not None else 7860,
- cmd_opts.ngrok_region
+ cmd_opts.ngrok_options
)
@@ -82,6 +79,8 @@ clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
+detect_image_size_symbol = '\U0001F4D0' # 📐
+up_down_symbol = '\u2195\ufe0f' # ↕️
def plaintext_to_html(text):
@@ -93,16 +92,6 @@ def send_gradio_gallery_to_image(x):
return None
return image_from_url_text(x[0])
-def visit(x, func, path=""):
- if hasattr(x, 'children'):
- if isinstance(x, gr.Tabs) and x.elem_id is not None:
- # Tabs element can't have a label, have to use elem_id instead
- func(f"{path}/Tabs@{x.elem_id}", x)
- for c in x.children:
- visit(c, func, path)
- elif x.label is not None:
- func(f"{path}/{x.label}", x)
-
def add_style(name: str, prompt: str, negative_prompt: str):
if name is None:
@@ -166,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image)
filename = os.path.basename(image)
left, _ = os.path.splitext(filename)
- print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))
+ print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
return [gr.update(), None]
@@ -206,8 +195,8 @@ def create_seed_inputs(target_interface):
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
- random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
- random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+ random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
+ random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
def change_visibility(show):
return {comp: gr_show(show) for comp in seed_extras}
@@ -246,10 +235,9 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
- except json.decoder.JSONDecodeError as e:
- if gen_info_string != '':
- print("Error parsing JSON generation info:", file=sys.stderr)
- print(gen_info_string, file=sys.stderr)
+ except json.decoder.JSONDecodeError:
+ if gen_info_string:
+ errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
@@ -288,12 +276,12 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
button_interrogate = None
button_deepbooru = None
@@ -384,25 +372,6 @@ def apply_setting(key, value):
return getattr(opts, key)
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
- def refresh():
- refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
- for k, v in args.items():
- setattr(refresh_component, k, v)
-
- return gr.update(**(args or {}))
-
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
- refresh_button.click(
- fn=refresh,
- inputs=[],
- outputs=[refresh_component]
- )
- return refresh_button
-
-
def create_output_panel(tabname, outdir):
return ui_common.create_output_panel(tabname, outdir)
@@ -421,27 +390,17 @@ def create_sampler_and_steps_selection(choices, tabname):
def ordered_ui_categories():
- user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}
+ user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
- for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+ for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
yield category
-def get_value_for_setting(key):
- value = getattr(opts, key)
-
- info = opts.data_labels[key]
- args = info.component_args() if callable(info.component_args) else info.component_args or {}
- args = {k: v for k, v in args.items() if k not in {'precision'}}
-
- return gr.update(value=value, **args)
-
-
def create_override_settings_dropdown(tabname, row):
dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
dropdown.change(
- fn=lambda x: gr.Dropdown.update(visible=len(x) > 0),
+ fn=lambda x: gr.Dropdown.update(visible=bool(x)),
inputs=[dropdown],
outputs=[dropdown],
)
@@ -472,6 +431,8 @@ def create_ui():
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
+ modules.scripts.scripts_txt2img.prepare_ui()
+
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -515,6 +476,17 @@ def create_ui():
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
+ with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+ hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
+
+ with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
+ with gr.Column(scale=80):
+ with gr.Row():
+ hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
+ with gr.Column(scale=80):
+ with gr.Row():
+ hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
+
elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
@@ -529,15 +501,21 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+ else:
+ modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
- for input in hr_resolution_preview_inputs:
- input.change(
+
+ for component in hr_resolution_preview_inputs:
+ event = component.release if isinstance(component, gr.Slider) else component.change
+
+ event(
fn=calc_resolution_hires,
inputs=hr_resolution_preview_inputs,
outputs=[hr_final_resolution],
show_progress=False,
)
- input.change(
+ event(
None,
_js="onCalcResolutionHires",
inputs=hr_resolution_preview_inputs,
@@ -576,7 +554,11 @@ def create_ui():
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
+ hr_sampler_index,
+ hr_prompt,
+ hr_negative_prompt,
override_settings,
+
] + custom_inputs,
outputs=[
@@ -591,7 +573,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
restore_progress_button.click(
fn=progress.restore_progress,
@@ -614,7 +596,8 @@ def create_ui():
outputs=[
txt2img_prompt,
txt_prompt_img
- ]
+ ],
+ show_progress=False,
)
enable_hr.change(
@@ -639,6 +622,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
+ (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -647,6 +631,11 @@ def create_ui():
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
+ (hr_sampler_index, "Hires sampler"),
+ (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
+ (hr_prompt, "Hires prompt"),
+ (hr_negative_prompt, "Hires negative prompt"),
+ (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*modules.scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
@@ -704,19 +693,19 @@ def create_ui():
img2img_selected_tab = gr.State(0)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
- init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+ init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
- sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
- init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+ init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
- inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+ inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
inpaint_color_sketch_orig = gr.State(None)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -736,17 +725,20 @@ def create_ui():
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
- f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
- f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+ "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+ "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}</p>"
)
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
+ with gr.Accordion("PNG info", open=False):
+ img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
+ img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
+ img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
- img2img_image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch]
for i, tab in enumerate(img2img_tabs):
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
@@ -773,6 +765,8 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
+ modules.scripts.scripts_img2img.prepare_ui()
+
for category in ordered_ui_categories():
if category == "sampler":
steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
@@ -783,15 +777,16 @@ def create_ui():
selected_scale_tab = gr.State(value=0)
with gr.Tabs():
- with gr.Tab(label="Resize to") as tab_scale_to:
+ with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
+ detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
- with gr.Tab(label="Resize by") as tab_scale_by:
+ with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow():
@@ -881,6 +876,8 @@ def create_ui():
inputs=[],
outputs=[inpaint_controls, mask_alpha],
)
+ else:
+ modules.scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
@@ -895,7 +892,8 @@ def create_ui():
outputs=[
img2img_prompt,
img2img_prompt_img
- ]
+ ],
+ show_progress=False,
)
img2img_args = dict(
@@ -940,6 +938,9 @@ def create_ui():
img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir,
override_settings,
+ img2img_batch_use_png_info,
+ img2img_batch_png_info_props,
+ img2img_batch_png_info_dir,
] + custom_inputs,
outputs=[
img2img_gallery,
@@ -967,7 +968,16 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
- res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
+
+ res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
+
+ detect_image_size_btn.click(
+ fn=lambda w, h, _: (w or gr.update(), h or gr.update()),
+ _js="currentImg2imgSourceResolution",
+ inputs=[dummy_component, dummy_component, dummy_component],
+ outputs=[width, height],
+ show_progress=False,
+ )
restore_progress_button.click(
fn=progress.restore_progress,
@@ -1035,6 +1045,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
+ (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
@@ -1189,7 +1200,7 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
+
with gr.Column(visible=False) as process_multicrop_col:
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
with gr.Row():
@@ -1201,7 +1212,7 @@ def create_ui():
with gr.Row():
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
-
+
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@@ -1230,7 +1241,7 @@ def create_ui():
)
def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
+ return sorted(textual_inversion.textual_inversion_templates)
with gr.Tab(label="Train", id="train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
@@ -1238,13 +1249,13 @@ def create_ui():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
-
+
with FormRow():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
@@ -1290,8 +1301,8 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
- ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+ gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
create_embedding.click(
@@ -1444,194 +1455,10 @@ def create_ui():
outputs=[],
)
- def create_setting_component(key, is_quicksettings=False):
- def fun():
- return opts.data[key] if key in opts.data else opts.data_labels[key].default
-
- info = opts.data_labels[key]
- t = type(info.default)
-
- args = info.component_args() if callable(info.component_args) else info.component_args
-
- if info.component is not None:
- comp = info.component
- elif t == str:
- comp = gr.Textbox
- elif t == int:
- comp = gr.Number
- elif t == bool:
- comp = gr.Checkbox
- else:
- raise Exception(f'bad options item type: {t} for key {key}')
-
- elem_id = f"setting_{key}"
-
- if info.refresh is not None:
- if is_quicksettings:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
- else:
- with FormRow():
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
- else:
- res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-
- return res
-
- components = []
- component_dict = {}
- shared.settings_components = component_dict
-
- script_callbacks.ui_settings_callback()
- opts.reorder()
-
- def run_settings(*args):
- changed = []
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
-
- for key, value, comp in zip(opts.data_labels.keys(), args, components):
- if comp == dummy_component:
- continue
-
- if opts.set(key, value):
- changed.append(key)
-
- try:
- opts.save(shared.config_filename)
- except RuntimeError:
- return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
- return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
-
- def run_settings_single(value, key):
- if not opts.same_type(value, opts.data_labels[key].default):
- return gr.update(visible=True), opts.dumpjson()
-
- if not opts.set(key, value):
- return gr.update(value=getattr(opts, key)), opts.dumpjson()
-
- opts.save(shared.config_filename)
-
- return get_value_for_setting(key), opts.dumpjson()
-
- with gr.Blocks(analytics_enabled=False) as settings_interface:
- with gr.Row():
- with gr.Column(scale=6):
- settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
- with gr.Column():
- restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
-
- result = gr.HTML(elem_id="settings_result")
-
- quicksettings_names = opts.quicksettings_list
- quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
-
- quicksettings_list = []
-
- previous_section = None
- current_tab = None
- current_row = None
- with gr.Tabs(elem_id="settings"):
- for i, (k, item) in enumerate(opts.data_labels.items()):
- section_must_be_skipped = item.section[0] is None
-
- if previous_section != item.section and not section_must_be_skipped:
- elem_id, text = item.section
-
- if current_tab is not None:
- current_row.__exit__()
- current_tab.__exit__()
-
- gr.Group()
- current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
- current_tab.__enter__()
- current_row = gr.Column(variant='compact')
- current_row.__enter__()
-
- previous_section = item.section
-
- if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
- quicksettings_list.append((i, k, item))
- components.append(dummy_component)
- elif section_must_be_skipped:
- components.append(dummy_component)
- else:
- component = create_setting_component(k)
- component_dict[k] = component
- components.append(component)
-
- if current_tab is not None:
- current_row.__exit__()
- current_tab.__exit__()
-
- with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
- request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
- download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
- reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- with gr.Row():
- unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
- reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
-
- with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
- gr.HTML(shared.html("licenses.html"), elem_id="licenses")
-
- gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
-
-
- def unload_sd_weights():
- modules.sd_models.unload_model_weights()
-
- def reload_sd_weights():
- modules.sd_models.reload_model_weights()
-
- unload_sd_model.click(
- fn=unload_sd_weights,
- inputs=[],
- outputs=[]
- )
-
- reload_sd_model.click(
- fn=reload_sd_weights,
- inputs=[],
- outputs=[]
- )
-
- request_notifications.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='function(){}'
- )
-
- download_localization.click(
- fn=lambda: None,
- inputs=[],
- outputs=[],
- _js='download_localization'
- )
+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
- def reload_scripts():
- modules.scripts.reload_script_body_only()
- reload_javascript() # need to refresh the html page
-
- reload_script_bodies.click(
- fn=reload_scripts,
- inputs=[],
- outputs=[]
- )
-
- def request_restart():
- shared.state.interrupt()
- shared.state.need_restart = True
-
- restart_gradio.click(
- fn=request_restart,
- _js='restart_reload',
- inputs=[],
- outputs=[],
- )
+ settings = ui_settings.UiSettings()
+ settings.create_ui(loadsave, dummy_component)
interfaces = [
(txt2img_interface, "txt2img", "txt2img"),
@@ -1639,11 +1466,11 @@ def create_ui():
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
+ (train_interface, "Train", "train"),
]
interfaces += script_callbacks.ui_tabs_callback()
- interfaces += [(settings_interface, "Settings", "settings")]
+ interfaces += [(settings.interface, "Settings", "settings")]
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
@@ -1653,76 +1480,48 @@ def create_ui():
shared.tab_names.append(label)
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
- with gr.Row(elem_id="quicksettings", variant="compact"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
- component = create_setting_component(k, is_quicksettings=True)
- component_dict[k] = component
+ settings.add_quicksettings()
parameters_copypaste.connect_paste_params_buttons()
with gr.Tabs(elem_id="tabs") as tabs:
- for interface, label, ifid in interfaces:
+ tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}
+ sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))
+
+ for interface, label, ifid in sorted_interfaces:
if label in shared.opts.hidden_tabs:
continue
with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()
+ for interface, _label, ifid in interfaces:
+ if ifid in ["extensions", "settings"]:
+ continue
+
+ loadsave.add_block(interface, ifid)
+
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+
+ loadsave.setup_ui()
+
if os.path.exists(os.path.join(script_path, "notification.mp3")):
- audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
footer = shared.html("footer.html")
- footer = footer.format(versions=versions_html())
+ footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API")
gr.HTML(footer, elem_id="footer")
- text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
- settings_submit.click(
- fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
- inputs=components,
- outputs=[text_settings, result],
- )
-
- for i, k, item in quicksettings_list:
- component = component_dict[k]
- info = opts.data_labels[k]
-
- change_handler = component.release if hasattr(component, 'release') else component.change
- change_handler(
- fn=lambda value, k=k: run_settings_single(value, key=k),
- inputs=[component],
- outputs=[component, text_settings],
- show_progress=info.refresh is not None,
- )
+ settings.add_functionality(demo)
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
- text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+ settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
- button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
- button_set_checkpoint.click(
- fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
- _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
- inputs=[component_dict['sd_model_checkpoint'], dummy_component],
- outputs=[component_dict['sd_model_checkpoint'], text_settings],
- )
-
- component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
-
- def get_settings_values():
- return [get_value_for_setting(key) for key in component_keys]
-
- demo.load(
- fn=get_settings_values,
- inputs=[],
- outputs=[component_dict[k] for k in component_keys],
- queue=False,
- )
-
def modelmerger(*args):
try:
results = modules.extras.run_modelmerger(*args)
except Exception as e:
- print("Error loading/saving model file:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report("Error loading/saving model file", exc_info=True)
modules.sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
@@ -1750,102 +1549,13 @@ def create_ui():
primary_model_name,
secondary_model_name,
tertiary_model_name,
- component_dict['sd_model_checkpoint'],
+ settings.component_dict['sd_model_checkpoint'],
modelmerger_result,
]
)
- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = f"{path}/{field}"
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = f"customscript/{obj.custom_script_source}/{key}"
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- pass
-
- # this warning is generally not useful;
- # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- def check_dropdown(val):
- if getattr(x, 'multiselect', False):
- return all([value in x.choices for value in val])
- else:
- return val in x.choices
-
- apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
-
- def check_tab_id(tab_id):
- tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
- if type(tab_id) == str:
- tab_ids = [t.id for t in tab_items]
- return tab_id in tab_ids
- elif type(tab_id) == int:
- return tab_id >= 0 and tab_id < len(tab_items)
- else:
- return False
-
- if type(x) == gr.Tabs:
- apply_field(x, 'selected', check_tab_id)
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
-
- loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
+ loadsave.dump_defaults()
+ demo.ui_loadsave = loadsave
# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description.value = update_interp_description(interp_method.value)
@@ -1853,70 +1563,6 @@ def create_ui():
return demo
-def webpath(fn):
- if fn.startswith(script_path):
- web_path = os.path.relpath(fn, script_path).replace('\\', '/')
- else:
- web_path = os.path.abspath(fn)
-
- return f'file={web_path}?{os.path.getmtime(fn)}'
-
-
-def javascript_html():
- # Ensure localization is in `window` before scripts
- head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
-
- script_js = os.path.join(script_path, "script.js")
- head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
-
- for script in modules.scripts.list_scripts("javascript", ".js"):
- head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
-
- for script in modules.scripts.list_scripts("javascript", ".mjs"):
- head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
-
- if cmd_opts.theme:
- head += f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n'
-
- return head
-
-
-def css_html():
- head = ""
-
- def stylesheet(fn):
- return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
-
- for cssfile in modules.scripts.list_files_with_name("style.css"):
- if not os.path.isfile(cssfile):
- continue
-
- head += stylesheet(cssfile)
-
- if os.path.exists(os.path.join(data_path, "user.css")):
- head += stylesheet(os.path.join(data_path, "user.css"))
-
- return head
-
-
-def reload_javascript():
- js = javascript_html()
- css = css_html()
-
- def template_response(*args, **kwargs):
- res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
- res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
- res.init_headers()
- return res
-
- gradio.routes.templates.TemplateResponse = template_response
-
-
-if not hasattr(shared, 'GradioTemplateResponseOriginal'):
- shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
-
-
def versions_html():
import torch
import launch
@@ -1933,15 +1579,15 @@ def versions_html():
return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
- • 
+&#x2000;•&#x2000;
python: <span title="{sys.version}">{python_version}</span>
- • 
+&#x2000;•&#x2000;
torch: {getattr(torch, '__long_version__',torch.__version__)}
- • 
+&#x2000;•&#x2000;
xformers: {xformers_version}
- • 
+&#x2000;•&#x2000;
gradio: {gr.__version__}
- • 
+&#x2000;•&#x2000;
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
@@ -1960,3 +1606,17 @@ def setup_ui_api(app):
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
+
+ app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])
+
+ def download_sysinfo(attachment=False):
+ from fastapi.responses import PlainTextResponse
+
+ text = sysinfo.get()
+ filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
+
+ return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
+
+ app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
+ app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])
+
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 27ab3ebb..57c2d0ad 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -10,8 +10,11 @@ import subprocess as sp
from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text
import modules.images
+from modules.ui_components import ToolButton
+
folder_symbol = '\U0001f4c2' # 📂
+refresh_symbol = '\U0001f504' # 🔄
def update_generation_info(generation_info, html_info, img_index):
@@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index):
save_to_dirs = shared.opts.use_save_to_dirs_for_ui
extension: str = shared.opts.samples_format
start_index = 0
+ only_one = False
if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
+ only_one = True
images = [images[index]]
start_index = index
@@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index):
is_grid = image_index < p.index_of_first_image
i = 0 if is_grid else (image_index - p.index_of_first_image)
+ p.batch_index = image_index-1
fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
filename = os.path.relpath(fullfn, path)
@@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index):
# Make Zip
if do_make_zip:
- zip_filepath = os.path.join(path, "images.zip")
+ zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
+ namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+ zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
+ zip_filepath = os.path.join(path, f"{zip_filename}.zip")
from zipfile import ZipFile
with ZipFile(zip_filepath, "w") as zip_file:
@@ -211,3 +219,23 @@ Requested path was: {f}
))
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
+
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index d9faf85a..dff522ef 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,9 +1,8 @@
import json
import os.path
-import sys
+import threading
import time
from datetime import datetime
-import traceback
import git
@@ -12,7 +11,7 @@ import html
import shutil
import errno
-from modules import extensions, shared, paths, config_states
+from modules import extensions, shared, paths, config_states, errors, restart
from modules.paths_internal import config_states_dir
from modules.call_queue import wrap_gradio_gpu_call
@@ -45,15 +44,16 @@ def apply_and_restart(disable_list, update_list, disable_all):
try:
ext.fetch_and_reset_hard()
except Exception:
- print(f"Error getting updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error getting updates for {ext.name}", exc_info=True)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)
- shared.state.interrupt()
- shared.state.need_restart = True
+ if restart.is_restartable():
+ restart.restart_program()
+ else:
+ restart.stop_program()
def save_config_state(name):
@@ -91,8 +91,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
return ""
@@ -115,8 +114,7 @@ def check_updates(id_task, disable_list):
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
- print(f"Error checking updates for {ext.name}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error checking updates for {ext.name}", exc_info=True)
shared.state.nextjob()
@@ -127,7 +125,9 @@ def make_commit_link(commit_hash, remote, text=None):
if text is None:
text = commit_hash[:8]
if remote.startswith("https://github.com/"):
- href = os.path.join(remote, "commit", commit_hash)
+ if remote.endswith(".git"):
+ remote = remote[:-4]
+ href = remote + "/commit/" + commit_hash
return f'<a href="{href}" target="_blank">{text}</a>'
else:
return text
@@ -138,9 +138,14 @@ def extension_table():
<table id="extensions">
<thead>
<tr>
- <th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
+ <th>
+ <input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
+ <abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
+ </th>
<th>URL</th>
- <th><abbr title="Extension version">Version</abbr></th>
+ <th>Branch</th>
+ <th>Version</th>
+ <th>Date</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
@@ -148,6 +153,7 @@ def extension_table():
"""
for ext in extensions.extensions:
+ ext: extensions.Extension
ext.read_info_from_repo()
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
@@ -167,9 +173,11 @@ def extension_table():
code += f"""
<tr>
- <td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
+ <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td>
+ <td>{ext.branch}</td>
<td>{version_link}</td>
+ <td>{time.asctime(time.gmtime(ext.commit_date))}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
@@ -320,6 +328,11 @@ def normalize_git_url(url):
def install_extension_from_url(dirname, url, branch_name=None):
check_access()
+ if isinstance(dirname, str):
+ dirname = dirname.strip()
+ if isinstance(url, str):
+ url = url.strip()
+
assert url, 'No URL specified'
if dirname is None or dirname == "":
@@ -332,7 +345,8 @@ def install_extension_from_url(dirname, url, branch_name=None):
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
normalized_url = normalize_git_url(url)
- assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
+ if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):
+ raise Exception(f'Extension with this URL is already installed: {url}')
tmpdir = os.path.join(paths.data_path, "tmp", dirname)
@@ -340,12 +354,12 @@ def install_extension_from_url(dirname, url, branch_name=None):
shutil.rmtree(tmpdir, True)
if not branch_name:
# if no branch is specified, use the default branch
- with git.Repo.clone_from(url, tmpdir) as repo:
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
else:
- with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
+ with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
@@ -410,9 +424,19 @@ sort_ordering = [
(False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'),
+ (True, lambda x: x.get('commit_time', '')),
+ (True, lambda x: x.get('created_at', '')),
+ (True, lambda x: x.get('stars', 0)),
]
+def get_date(info: dict, key):
+ try:
+ return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
+ except (ValueError, TypeError):
+ return ''
+
+
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@@ -437,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname")
+ stars = int(ext.get("stars", 0))
added = ext.get('added', 'unknown')
+ update_time = get_date(ext, 'commit_time')
+ create_time = get_date(ext, 'created_at')
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
@@ -448,7 +475,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
existing = installed_extension_urls.get(normalize_git_url(url), None)
extension_tags = extension_tags + ["installed"] if existing else extension_tags
- if len([x for x in extension_tags if x in tags_to_hide]) > 0:
+ if any(x for x in extension_tags if x in tags_to_hide):
hidden += 1
continue
@@ -464,10 +491,11 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
- <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
+ <td>{html.escape(description)}<p class="info">
+ <span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
<td>{install_code}</td>
</tr>
-
+
"""
for tag in [x for x in extension_tags if x not in tags]:
@@ -484,17 +512,31 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
return code, list(tags)
+def preload_extensions_git_metadata():
+ t0 = time.time()
+ for extension in extensions.extensions:
+ extension.read_info_from_repo()
+ print(
+ f"preload_extensions_git_metadata for "
+ f"{len(extensions.extensions)} extensions took "
+ f"{time.time() - t0:.2f}s"
+ )
+
+
def create_ui():
import modules.ui
config_states.list_config_states()
+ threading.Thread(target=preload_extensions_git_metadata).start()
+
with gr.Blocks(analytics_enabled=False) as ui:
- with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.Tabs(elem_id="tabs_extensions"):
with gr.TabItem("Installed", id="installed"):
with gr.Row(elem_id="extensions_installed_top"):
- apply = gr.Button(value="Apply and restart UI", variant="primary")
+ apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
+ apply = gr.Button(value=apply_label, variant="primary")
check = gr.Button(value="Check for updates")
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
@@ -508,7 +550,8 @@ def create_ui():
</span>
"""
info = gr.HTML(html)
- extensions_table = gr.HTML(lambda: extension_table())
+ extensions_table = gr.HTML('Loading...')
+ ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
apply.click(
fn=apply_and_restart,
@@ -533,18 +576,18 @@ def create_ui():
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
- sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
+ sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
- with gr.Row():
+ with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)
-
+
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
- fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
+ fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column],
- outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
+ outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
)
install_extension_button.click(
@@ -579,9 +622,9 @@ def create_ui():
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
- fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
+ fn=modules.ui.wrap_gradio_call(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),
inputs=[install_dirname, install_url, install_branch],
- outputs=[extensions_table, install_result],
+ outputs=[install_url, extensions_table, install_result],
)
with gr.TabItem("Backup/Restore"):
@@ -595,7 +638,8 @@ def create_ui():
config_save_button = gr.Button(value="Save Current Config")
config_states_info = gr.HTML("")
- config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
+ config_states_table = gr.HTML("Loading...")
+ ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
@@ -608,4 +652,5 @@ def create_ui():
outputs=[config_states_table],
)
+
return ui
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8c3dea56..693cafb6 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -1,11 +1,10 @@
-import glob
import os.path
import urllib.parse
from pathlib import Path
-from PIL import PngImagePlugin
from modules import shared
-from modules.images import read_info_from_image
+from modules.images import read_info_from_image, save_image_with_geninfo
+from modules.ui import up_down_symbol
import gradio as gr
import json
import html
@@ -27,12 +26,12 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
- if ext not in (".png", ".jpg", ".webp"):
- raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
+ if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
+ raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@@ -91,8 +90,8 @@ class ExtraNetworksPage:
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for root, dirs, files in os.walk(parentdir):
- for dirname in dirs:
+ for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
+ for dirname in sorted(dirs, key=shared.natural_sort_key):
x = os.path.join(root, dirname)
if not os.path.isdir(x):
@@ -106,6 +105,9 @@ class ExtraNetworksPage:
if not is_empty and not subdir.endswith("/"):
subdir = subdir + "/"
+ if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
+ continue
+
subdirs[subdir] = 1
if subdirs:
@@ -148,6 +150,10 @@ class ExtraNetworksPage:
return []
def create_html_for_item(self, item, tabname):
+ """
+ Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
+ """
+
preview = item.get("preview", None)
onclick = item.get("onclick", None)
@@ -156,7 +162,7 @@ class ExtraNetworksPage:
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
- background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
+ background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
metadata_button = ""
metadata = item.get("metadata")
if metadata:
@@ -170,12 +176,21 @@ class ExtraNetworksPage:
if filename.startswith(absdir):
local_path = filename[len(absdir):]
- # if this is true, the item must not be show in the default view, and must instead only be
+ # if this is true, the item must not be shown in the default view, and must instead only be
# shown when searching for it
- serach_only = "/." in local_path or "\\." in local_path
+ if shared.opts.extra_networks_hidden_models == "Always":
+ search_only = False
+ else:
+ search_only = "/." in local_path or "\\." in local_path
+
+ if search_only and shared.opts.extra_networks_hidden_models == "Never":
+ return ""
+
+ sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
args = {
- "style": f"'display: none; {height}{width}{background_image}'",
+ "background_image": background_image,
+ "style": f"'display: none; {height}{width}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
@@ -185,17 +200,30 @@ class ExtraNetworksPage:
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
- "serach_only": " search_only" if serach_only else "",
+ "search_only": " search_only" if search_only else "",
+ "sort_keys": sort_keys,
}
return self.card_page.format(**args)
+ def get_sort_keys(self, path):
+ """
+ List of default keys used for sorting in the UI.
+ """
+ pth = Path(path)
+ stat = pth.stat()
+ return {
+ "date_created": int(stat.st_ctime or 0),
+ "date_modified": int(stat.st_mtime or 0),
+ "name": pth.name.lower(),
+ }
+
def find_preview(self, path):
"""
Find a preview PNG for a given path (without extension) and call link_preview on it.
"""
- preview_extensions = ["png", "jpg", "webp"]
+ preview_extensions = ["png", "jpg", "jpeg", "webp"]
if shared.opts.samples_format not in preview_extensions:
preview_extensions.append(shared.opts.samples_format)
@@ -220,10 +248,19 @@ class ExtraNetworksPage:
return None
-def intialize():
+def initialize():
extra_pages.clear()
+def register_default_pages():
+ from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
+ from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
+ from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
+ register_page(ExtraNetworksPageTextualInversion())
+ register_page(ExtraNetworksPageHypernetworks())
+ register_page(ExtraNetworksPageCheckpoints())
+
+
class ExtraNetworksUi:
def __init__(self):
self.pages = None
@@ -263,18 +300,20 @@ def create_ui(container, button, tabname):
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
- with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
page_id = page.title.lower().replace(" ", "_")
with gr.Tab(page.title, id=page_id):
elem_id = f"{tabname}_{page_id}_cards_html"
- page_elem = gr.HTML('', elem_id=elem_id)
+ page_elem = gr.HTML('Loading...', elem_id=elem_id)
ui.pages.append(page_elem)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True)
+ gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder")
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
@@ -283,13 +322,24 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible):
is_visible = not is_visible
- if is_visible and not ui.pages_contents:
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ def fill_tabs(is_empty):
+ """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time."""
+
+ if not ui.pages_contents:
refresh()
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
+ if is_empty:
+ return True, *ui.pages_contents
+
+ return True, *[gr.update() for _ in ui.pages_contents]
state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False)
+
+ state_empty = gr.State(value=True)
+ button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False)
def refresh():
for pg in ui.stored_extra_pages:
@@ -327,18 +377,13 @@ def setup_ui(ui, gallery):
is_allowed = False
for extra_page in ui.stored_extra_pages:
- if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
is_allowed = True
break
assert is_allowed, f'writing to {filename} is not allowed'
- if geninfo:
- pnginfo_data = PngImagePlugin.PngInfo()
- pnginfo_data.add_text('parameters', geninfo)
- image.save(filename, pnginfo=pnginfo_data)
- else:
- image.save(filename)
+ save_image_with_geninfo(image, geninfo, filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index a17aa9c9..8b9ab71b 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -14,7 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def list_items(self):
checkpoint: sd_models.CheckpointInfo
- for name, checkpoint in sd_models.checkpoints_list.items():
+ for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()):
path, ext = os.path.splitext(checkpoint.filename)
yield {
"name": checkpoint.name_for_extra,
@@ -24,6 +24,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"local_preview": f"{path}.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
+
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index 6187e000..7c19b532 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
shared.reload_hypernetworks()
def list_items(self):
- for name, path in shared.hypernetworks.items():
+ for index, (name, path) in enumerate(shared.hypernetworks.items()):
path, ext = os.path.splitext(path)
yield {
@@ -23,6 +23,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(path),
"prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
+
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index 6944d559..58a61c55 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -13,7 +13,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
def list_items(self):
- for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()):
path, ext = os.path.splitext(embedding.filename)
yield {
"name": embedding.name,
@@ -23,6 +23,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(embedding.filename),
"prompt": json.dumps(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
+ "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
+
}
def allowed_directories_for_previews(self):
diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py
new file mode 100644
index 00000000..b824b113
--- /dev/null
+++ b/modules/ui_gradio_extensions.py
@@ -0,0 +1,69 @@
+import os
+import gradio as gr
+
+from modules import localization, shared, scripts
+from modules.paths import script_path, data_path
+
+
+def webpath(fn):
+ if fn.startswith(script_path):
+ web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+ else:
+ web_path = os.path.abspath(fn)
+
+ return f'file={web_path}?{os.path.getmtime(fn)}'
+
+
+def javascript_html():
+ # Ensure localization is in `window` before scripts
+ head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
+
+ script_js = os.path.join(script_path, "script.js")
+ head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
+
+ for script in scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
+
+ for script in scripts.list_scripts("javascript", ".mjs"):
+ head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
+
+ if shared.cmd_opts.theme:
+ head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
+
+ return head
+
+
+def css_html():
+ head = ""
+
+ def stylesheet(fn):
+ return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
+
+ for cssfile in scripts.list_files_with_name("style.css"):
+ if not os.path.isfile(cssfile):
+ continue
+
+ head += stylesheet(cssfile)
+
+ if os.path.exists(os.path.join(data_path, "user.css")):
+ head += stylesheet(os.path.join(data_path, "user.css"))
+
+ return head
+
+
+def reload_javascript():
+ js = javascript_html()
+ css = css_html()
+
+ def template_response(*args, **kwargs):
+ res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
+ res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
+ res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
+ res.init_headers()
+ return res
+
+ gr.routes.templates.TemplateResponse = template_response
+
+
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+ shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse
diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py
new file mode 100644
index 00000000..0052a5cc
--- /dev/null
+++ b/modules/ui_loadsave.py
@@ -0,0 +1,210 @@
+import json
+import os
+
+import gradio as gr
+
+from modules import errors
+from modules.ui_components import ToolButton
+
+
+class UiLoadsave:
+ """allows saving and restorig default values for gradio components"""
+
+ def __init__(self, filename):
+ self.filename = filename
+ self.ui_settings = {}
+ self.component_mapping = {}
+ self.error_loading = False
+ self.finalized_ui = False
+
+ self.ui_defaults_view = None
+ self.ui_defaults_apply = None
+ self.ui_defaults_review = None
+
+ try:
+ if os.path.exists(self.filename):
+ self.ui_settings = self.read_from_file()
+ except Exception as e:
+ self.error_loading = True
+ errors.display(e, "loading settings")
+
+ def add_component(self, path, x):
+ """adds component to the registry of tracked components"""
+
+ assert not self.finalized_ui
+
+ def apply_field(obj, field, condition=None, init_field=None):
+ key = f"{path}/{field}"
+
+ if getattr(obj, 'custom_script_source', None) is not None:
+ key = f"customscript/{obj.custom_script_source}/{key}"
+
+ if getattr(obj, 'do_not_save_to_config', False):
+ return
+
+ saved_value = self.ui_settings.get(key, None)
+ if saved_value is None:
+ self.ui_settings[key] = getattr(obj, field)
+ elif condition and not condition(saved_value):
+ pass
+ else:
+ setattr(obj, field, saved_value)
+ if init_field is not None:
+ init_field(saved_value)
+
+ if field == 'value' and key not in self.component_mapping:
+ self.component_mapping[key] = x
+
+ if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
+ apply_field(x, 'visible')
+
+ if type(x) == gr.Slider:
+ apply_field(x, 'value')
+ apply_field(x, 'minimum')
+ apply_field(x, 'maximum')
+ apply_field(x, 'step')
+
+ if type(x) == gr.Radio:
+ apply_field(x, 'value', lambda val: val in x.choices)
+
+ if type(x) == gr.Checkbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Textbox:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Number:
+ apply_field(x, 'value')
+
+ if type(x) == gr.Dropdown:
+ def check_dropdown(val):
+ if getattr(x, 'multiselect', False):
+ return all(value in x.choices for value in val)
+ else:
+ return val in x.choices
+
+ apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+ def check_tab_id(tab_id):
+ tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
+ if type(tab_id) == str:
+ tab_ids = [t.id for t in tab_items]
+ return tab_id in tab_ids
+ elif type(tab_id) == int:
+ return 0 <= tab_id < len(tab_items)
+ else:
+ return False
+
+ if type(x) == gr.Tabs:
+ apply_field(x, 'selected', check_tab_id)
+
+ def add_block(self, x, path=""):
+ """adds all components inside a gradio block x to the registry of tracked components"""
+
+ if hasattr(x, 'children'):
+ if isinstance(x, gr.Tabs) and x.elem_id is not None:
+ # Tabs element can't have a label, have to use elem_id instead
+ self.add_component(f"{path}/Tabs@{x.elem_id}", x)
+ for c in x.children:
+ self.add_block(c, path)
+ elif x.label is not None:
+ self.add_component(f"{path}/{x.label}", x)
+ elif isinstance(x, gr.Button) and x.value is not None:
+ self.add_component(f"{path}/{x.value}", x)
+
+ def read_from_file(self):
+ with open(self.filename, "r", encoding="utf8") as file:
+ return json.load(file)
+
+ def write_to_file(self, current_ui_settings):
+ with open(self.filename, "w", encoding="utf8") as file:
+ json.dump(current_ui_settings, file, indent=4)
+
+ def dump_defaults(self):
+ """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+
+ if self.error_loading and os.path.exists(self.filename):
+ return
+
+ self.write_to_file(self.ui_settings)
+
+ def iter_changes(self, current_ui_settings, values):
+ """
+ given a dictionary with defaults from a file and current values from gradio elements, returns
+ an iterator over tuples of values that are not the same between the file and the current;
+ tuple contents are: path, old value, new value
+ """
+
+ for (path, component), new_value in zip(self.component_mapping.items(), values):
+ old_value = current_ui_settings.get(path)
+
+ choices = getattr(component, 'choices', None)
+ if isinstance(new_value, int) and choices:
+ if new_value >= len(choices):
+ continue
+
+ new_value = choices[new_value]
+
+ if new_value == old_value:
+ continue
+
+ if old_value is None and new_value == '' or new_value == []:
+ continue
+
+ yield path, old_value, new_value
+
+ def ui_view(self, *values):
+ text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
+
+ for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
+ if old_value is None:
+ old_value = "<span class='ui-defaults-none'>None</span>"
+
+ text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
+
+ if len(text) == 1:
+ text.append("<tr><td colspan=3>No changes</td></tr>")
+
+ text.append("</tbody>")
+ return "".join(text)
+
+ def ui_apply(self, *values):
+ num_changed = 0
+
+ current_ui_settings = self.read_from_file()
+
+ for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
+ num_changed += 1
+ current_ui_settings[path] = new_value
+
+ if num_changed == 0:
+ return "No changes."
+
+ self.write_to_file(current_ui_settings)
+
+ return f"Wrote {num_changed} changes."
+
+ def create_ui(self):
+ """creates ui elements for editing defaults UI, without adding any logic to them"""
+
+ gr.HTML(
+ f"This page allows you to change default values in UI elements on other tabs.<br />"
+ f"Make your changes, press 'View changes' to review the changed default values,<br />"
+ f"then press 'Apply' to write them to {self.filename}.<br />"
+ f"New defaults will apply after you restart the UI.<br />"
+ )
+
+ with gr.Row():
+ self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
+ self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
+
+ self.ui_defaults_review = gr.HTML("")
+
+ def setup_ui(self):
+ """adds logic to elements created with create_ui; all add_block class must be made before this"""
+
+ assert not self.finalized_ui
+ self.finalized_ui = True
+
+ self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
+ self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index f25639e5..c7dc1154 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -1,5 +1,5 @@
import gradio as gr
-from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue
import modules.generation_parameters_copypaste as parameters_copypaste
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
new file mode 100644
index 00000000..0c560b30
--- /dev/null
+++ b/modules/ui_settings.py
@@ -0,0 +1,289 @@
+import gradio as gr
+
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
+from modules.call_queue import wrap_gradio_call
+from modules.shared import opts
+from modules.ui_components import FormRow
+from modules.ui_gradio_extensions import reload_javascript
+
+
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
+
+def create_setting_component(key, is_quicksettings=False):
+ def fun():
+ return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+ info = opts.data_labels[key]
+ t = type(info.default)
+
+ args = info.component_args() if callable(info.component_args) else info.component_args
+
+ if info.component is not None:
+ comp = info.component
+ elif t == str:
+ comp = gr.Textbox
+ elif t == int:
+ comp = gr.Number
+ elif t == bool:
+ comp = gr.Checkbox
+ else:
+ raise Exception(f'bad options item type: {t} for key {key}')
+
+ elem_id = f"setting_{key}"
+
+ if info.refresh is not None:
+ if is_quicksettings:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+ else:
+ with FormRow():
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+ ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+ else:
+ res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+
+ return res
+
+
+class UiSettings:
+ submit = None
+ result = None
+ interface = None
+ components = None
+ component_dict = None
+ dummy_component = None
+ quicksettings_list = None
+ quicksettings_names = None
+ text_settings = None
+
+ def run_settings(self, *args):
+ changed = []
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
+ assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+
+ for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
+ if comp == self.dummy_component:
+ continue
+
+ if opts.set(key, value):
+ changed.append(key)
+
+ try:
+ opts.save(shared.config_filename)
+ except RuntimeError:
+ return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+ return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.'
+
+ def run_settings_single(self, value, key):
+ if not opts.same_type(value, opts.data_labels[key].default):
+ return gr.update(visible=True), opts.dumpjson()
+
+ if not opts.set(key, value):
+ return gr.update(value=getattr(opts, key)), opts.dumpjson()
+
+ opts.save(shared.config_filename)
+
+ return get_value_for_setting(key), opts.dumpjson()
+
+ def create_ui(self, loadsave, dummy_component):
+ self.components = []
+ self.component_dict = {}
+ self.dummy_component = dummy_component
+
+ shared.settings_components = self.component_dict
+
+ script_callbacks.ui_settings_callback()
+ opts.reorder()
+
+ with gr.Blocks(analytics_enabled=False) as settings_interface:
+ with gr.Row():
+ with gr.Column(scale=6):
+ self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+ with gr.Column():
+ restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
+
+ self.result = gr.HTML(elem_id="settings_result")
+
+ self.quicksettings_names = opts.quicksettings_list
+ self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
+
+ self.quicksettings_list = []
+
+ previous_section = None
+ current_tab = None
+ current_row = None
+ with gr.Tabs(elem_id="settings"):
+ for i, (k, item) in enumerate(opts.data_labels.items()):
+ section_must_be_skipped = item.section[0] is None
+
+ if previous_section != item.section and not section_must_be_skipped:
+ elem_id, text = item.section
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ gr.Group()
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
+ current_tab.__enter__()
+ current_row = gr.Column(variant='compact')
+ current_row.__enter__()
+
+ previous_section = item.section
+
+ if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
+ self.quicksettings_list.append((i, k, item))
+ self.components.append(dummy_component)
+ elif section_must_be_skipped:
+ self.components.append(dummy_component)
+ else:
+ component = create_setting_component(k)
+ self.component_dict[k] = component
+ self.components.append(component)
+
+ if current_tab is not None:
+ current_row.__exit__()
+ current_tab.__exit__()
+
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+ loadsave.create_ui()
+
+ with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
+ gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ sysinfo_check_file = gr.File(label="Check system info for validity", type='binary')
+ with gr.Column(scale=1):
+ sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity")
+ with gr.Column(scale=100):
+ pass
+
+ with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
+ request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+ download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+ reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+ with gr.Row():
+ unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
+ reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+
+ with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
+
+ gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+ self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+
+ unload_sd_model.click(
+ fn=sd_models.unload_model_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ reload_sd_model.click(
+ fn=sd_models.reload_model_weights,
+ inputs=[],
+ outputs=[]
+ )
+
+ request_notifications.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='function(){}'
+ )
+
+ download_localization.click(
+ fn=lambda: None,
+ inputs=[],
+ outputs=[],
+ _js='download_localization'
+ )
+
+ def reload_scripts():
+ scripts.reload_script_body_only()
+ reload_javascript() # need to refresh the html page
+
+ reload_script_bodies.click(
+ fn=reload_scripts,
+ inputs=[],
+ outputs=[]
+ )
+
+ restart_gradio.click(
+ fn=shared.state.request_restart,
+ _js='restart_reload',
+ inputs=[],
+ outputs=[],
+ )
+
+ def check_file(x):
+ if x is None:
+ return ''
+
+ if sysinfo.check(x.decode('utf8', errors='ignore')):
+ return 'Valid'
+
+ return 'Invalid'
+
+ sysinfo_check_file.change(
+ fn=check_file,
+ inputs=[sysinfo_check_file],
+ outputs=[sysinfo_check_output],
+ )
+
+ self.interface = settings_interface
+
+ def add_quicksettings(self):
+ with gr.Row(elem_id="quicksettings", variant="compact"):
+ for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
+ component = create_setting_component(k, is_quicksettings=True)
+ self.component_dict[k] = component
+
+ def add_functionality(self, demo):
+ self.submit.click(
+ fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
+ inputs=self.components,
+ outputs=[self.text_settings, self.result],
+ )
+
+ for _i, k, _item in self.quicksettings_list:
+ component = self.component_dict[k]
+ info = opts.data_labels[k]
+
+ change_handler = component.release if hasattr(component, 'release') else component.change
+ change_handler(
+ fn=lambda value, k=k: self.run_settings_single(value, key=k),
+ inputs=[component],
+ outputs=[component, self.text_settings],
+ show_progress=info.refresh is not None,
+ )
+
+ button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+ button_set_checkpoint.click(
+ fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
+ _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+ inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
+ outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
+ )
+
+ component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
+
+ def get_settings_values():
+ return [get_value_for_setting(key) for key in component_keys]
+
+ demo.load(
+ fn=get_settings_values,
+ inputs=[],
+ outputs=[self.component_dict[k] for k in component_keys],
+ queue=False,
+ )
diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py
index 46fa9cb0..fb75137e 100644
--- a/modules/ui_tempdir.py
+++ b/modules/ui_tempdir.py
@@ -3,7 +3,7 @@ import tempfile
from collections import namedtuple
from pathlib import Path
-import gradio as gr
+import gradio.components
from PIL import PngImagePlugin
@@ -23,7 +23,7 @@ def register_tmp_file(gradio, filename):
def check_tmp_file(gradio, filename):
if hasattr(gradio, 'temp_file_sets'):
- return any([filename in fileset for fileset in gradio.temp_file_sets])
+ return any(filename in fileset for fileset in gradio.temp_file_sets)
if hasattr(gradio, 'temp_dirs'):
return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
@@ -31,13 +31,16 @@ def check_tmp_file(gradio, filename):
return False
-def save_pil_to_file(pil_image, dir=None):
+def save_pil_to_file(self, pil_image, dir=None, format="png"):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
+ filename = already_saved_as
- file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
- return file_obj
+ if not shared.opts.save_images_add_number:
+ filename += f'?{os.path.getmtime(already_saved_as)}'
+
+ return filename
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
@@ -51,11 +54,11 @@ def save_pil_to_file(pil_image, dir=None):
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
- return file_obj
+ return file_obj.name
# override save to file function so that it also writes PNG info
-gr.processing_utils.save_pil_to_file = save_pil_to_file
+gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed():
@@ -72,7 +75,7 @@ def cleanup_tmpdr():
if temp_dir == "" or not os.path.isdir(temp_dir):
return
- for root, dirs, files in os.walk(temp_dir, topdown=False):
+ for root, _, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
diff --git a/modules/upscaler.py b/modules/upscaler.py
index e2eaa730..e682bbaa 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -2,8 +2,6 @@ import os
from abc import abstractmethod
import PIL
-import numpy as np
-import torch
from PIL import Image
import modules.shared
@@ -36,6 +34,7 @@ class Upscaler:
self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0
self.mod_scale = None
+ self.model_download_path = None
if self.model_path is None and self.name:
self.model_path = os.path.join(shared.models_path, self.name)
@@ -43,9 +42,9 @@ class Upscaler:
os.makedirs(self.model_path, exist_ok=True)
try:
- import cv2
+ import cv2 # noqa: F401
self.can_tile = True
- except:
+ except Exception:
pass
@abstractmethod
@@ -54,10 +53,10 @@ class Upscaler:
def upscale(self, img: PIL.Image, scale, selected_model: str = None):
self.scale = scale
- dest_w = int(img.width * scale)
- dest_h = int(img.height * scale)
+ dest_w = int((img.width * scale) // 8 * 8)
+ dest_h = int((img.height * scale) // 8 * 8)
- for i in range(3):
+ for _ in range(3):
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
@@ -78,7 +77,7 @@ class Upscaler:
pass
def find_models(self, ext_filter=None) -> list:
- return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
+ return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)
def update_status(self, prompt):
print(f"\nextras: {prompt}", file=shared.progress_print_out)
diff --git a/modules/xlmr.py b/modules/xlmr.py
index beab3fdf..a407a3ca 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
- # modify initialization for autoloading
+ # modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
@@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
- return features['projection_state']
+ return features['projection_state']
def forward(
self,
@@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
- config_class= RobertaSeriesConfig \ No newline at end of file
+ config_class= RobertaSeriesConfig
diff --git a/package.json b/package.json
new file mode 100644
index 00000000..c0ba4067
--- /dev/null
+++ b/package.json
@@ -0,0 +1,11 @@
+{
+ "name": "stable-diffusion-webui",
+ "version": "0.0.0",
+ "devDependencies": {
+ "eslint": "^8.40.0"
+ },
+ "scripts": {
+ "lint": "eslint .",
+ "fix": "eslint --fix ."
+ }
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..80541a8f
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,35 @@
+[tool.ruff]
+
+target-version = "py39"
+
+extend-select = [
+ "B",
+ "C",
+ "I",
+ "W",
+]
+
+exclude = [
+ "extensions",
+ "extensions-disabled",
+]
+
+ignore = [
+ "E501", # Line too long
+ "E731", # Do not assign a `lambda` expression, use a `def`
+
+ "I001", # Import block is un-sorted or un-formatted
+ "C901", # Function is too complex
+ "C408", # Rewrite as a literal
+ "W605", # invalid escape sequence, messes with some docstrings
+]
+
+[tool.ruff.per-file-ignores]
+"webui.py" = ["E402"] # Module level import not at top of file
+
+[tool.ruff.flake8-bugbear]
+# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
+extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
+
+[tool.pytest.ini_options]
+base_url = "http://127.0.0.1:7860"
diff --git a/requirements-test.txt b/requirements-test.txt
new file mode 100644
index 00000000..37838ca2
--- /dev/null
+++ b/requirements-test.txt
@@ -0,0 +1,3 @@
+pytest-base-url~=2.0
+pytest-cov~=4.0
+pytest~=7.3
diff --git a/requirements.txt b/requirements.txt
index 35987a13..3142085e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,33 +1,32 @@
-astunparse
-blendmodes
+GitPython
+Pillow
accelerate
+
basicsr
-fonts
-font-roboto
+blendmodes
+clean-fid
+einops
gfpgan
-gradio==3.29.0
+gradio==3.32.0
+inflection
+jsonmerge
+kornia
+lark
numpy
omegaconf
-opencv-contrib-python
-requests
+
piexif
-Pillow
-pytorch_lightning==1.7.7
+psutil
+pytorch_lightning
realesrgan
+requests
+resize-right
+
+safetensors
scikit-image>=0.19
-timm==0.4.12
-transformers==4.25.1
+timm
+tomesd
torch
-einops
-jsonmerge
-clean-fid
-resize-right
torchdiffeq
-kornia
-lark
-inflection
-GitPython
torchsde
-safetensors
-psutil
-rich
+transformers==4.25.1
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 7bce02e5..f71b9d6c 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,30 +1,30 @@
-blendmodes==2022
-transformers==4.25.1
+GitPython==3.1.30
+Pillow==9.5.0
accelerate==0.18.0
basicsr==1.4.2
+blendmodes==2022
+clean-fid==0.1.35
+einops==0.4.1
+fastapi==0.94.0
gfpgan==1.3.8
-gradio==3.29.0
+gradio==3.32.0
+httpcore<=0.15
+inflection==0.5.1
+jsonmerge==1.8.0
+kornia==0.6.7
+lark==1.1.2
numpy==1.23.5
-Pillow==9.4.0
-realesrgan==0.3.0
-torch
omegaconf==2.2.3
-pytorch_lightning==1.9.4
-scikit-image==0.19.2
-fonts
-font-roboto
-timm==0.6.7
piexif==1.1.3
-einops==0.4.1
-jsonmerge==1.8.0
-clean-fid==0.1.29
+psutil~=5.9.5
+pytorch_lightning==1.9.4
+realesrgan==0.3.0
resize-right==0.0.2
+safetensors==0.3.1
+scikit-image==0.20.0
+timm==0.6.7
+tomesd==0.1.2
+torch
torchdiffeq==0.2.3
-kornia==0.6.7
-lark==1.1.2
-inflection==0.5.1
-GitPython==3.1.30
torchsde==0.2.5
-safetensors==0.3.1
-httpcore<=0.15
-fastapi==0.94.0
+transformers==4.25.1
diff --git a/script.js b/script.js
index 03afe844..34cca765 100644
--- a/script.js
+++ b/script.js
@@ -1,66 +1,123 @@
function gradioApp() {
- const elems = document.getElementsByTagName('gradio-app')
- const elem = elems.length == 0 ? document : elems[0]
+ const elems = document.getElementsByTagName('gradio-app');
+ const elem = elems.length == 0 ? document : elems[0];
- if (elem !== document) elem.getElementById = function(id){ return document.getElementById(id) }
- return elem.shadowRoot ? elem.shadowRoot : elem
+ if (elem !== document) {
+ elem.getElementById = function(id) {
+ return document.getElementById(id);
+ };
+ }
+ return elem.shadowRoot ? elem.shadowRoot : elem;
}
+/**
+ * Get the currently selected top-level UI tab button (e.g. the button that says "Extras").
+ */
function get_uiCurrentTab() {
- return gradioApp().querySelector('#tabs button.selected')
+ return gradioApp().querySelector('#tabs > .tab-nav > button.selected');
}
+/**
+ * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI).
+ */
function get_uiCurrentTabContent() {
- return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])')
+ return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])');
}
-uiUpdateCallbacks = []
-uiLoadedCallbacks = []
-uiTabChangeCallbacks = []
-optionsChangedCallbacks = []
-let uiCurrentTab = null
+var uiUpdateCallbacks = [];
+var uiAfterUpdateCallbacks = [];
+var uiLoadedCallbacks = [];
+var uiTabChangeCallbacks = [];
+var optionsChangedCallbacks = [];
+var uiAfterUpdateTimeout = null;
+var uiCurrentTab = null;
-function onUiUpdate(callback){
- uiUpdateCallbacks.push(callback)
+/**
+ * Register callback to be called at each UI update.
+ * The callback receives an array of MutationRecords as an argument.
+ */
+function onUiUpdate(callback) {
+ uiUpdateCallbacks.push(callback);
}
-function onUiLoaded(callback){
- uiLoadedCallbacks.push(callback)
+
+/**
+ * Register callback to be called soon after UI updates.
+ * The callback receives no arguments.
+ *
+ * This is preferred over `onUiUpdate` if you don't need
+ * access to the MutationRecords, as your function will
+ * not be called quite as often.
+ */
+function onAfterUiUpdate(callback) {
+ uiAfterUpdateCallbacks.push(callback);
}
-function onUiTabChange(callback){
- uiTabChangeCallbacks.push(callback)
+
+/**
+ * Register callback to be called when the UI is loaded.
+ * The callback receives no arguments.
+ */
+function onUiLoaded(callback) {
+ uiLoadedCallbacks.push(callback);
}
-function onOptionsChanged(callback){
- optionsChangedCallbacks.push(callback)
+
+/**
+ * Register callback to be called when the UI tab is changed.
+ * The callback receives no arguments.
+ */
+function onUiTabChange(callback) {
+ uiTabChangeCallbacks.push(callback);
}
-function runCallback(x, m){
- try {
- x(m)
- } catch (e) {
- (console.error || console.log).call(console, e.message, e);
+/**
+ * Register callback to be called when the options are changed.
+ * The callback receives no arguments.
+ * @param callback
+ */
+function onOptionsChanged(callback) {
+ optionsChangedCallbacks.push(callback);
+}
+
+function executeCallbacks(queue, arg) {
+ for (const callback of queue) {
+ try {
+ callback(arg);
+ } catch (e) {
+ console.error("error running callback", callback, ":", e);
+ }
}
}
-function executeCallbacks(queue, m) {
- queue.forEach(function(x){runCallback(x, m)})
+
+/**
+ * Schedule the execution of the callbacks registered with onAfterUiUpdate.
+ * The callbacks are executed after a short while, unless another call to this function
+ * is made before that time. IOW, the callbacks are executed only once, even
+ * when there are multiple mutations observed.
+ */
+function scheduleAfterUiUpdateCallbacks() {
+ clearTimeout(uiAfterUpdateTimeout);
+ uiAfterUpdateTimeout = setTimeout(function() {
+ executeCallbacks(uiAfterUpdateCallbacks);
+ }, 200);
}
var executedOnLoaded = false;
document.addEventListener("DOMContentLoaded", function() {
- var mutationObserver = new MutationObserver(function(m){
- if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
+ var mutationObserver = new MutationObserver(function(m) {
+ if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) {
executedOnLoaded = true;
executeCallbacks(uiLoadedCallbacks);
}
executeCallbacks(uiUpdateCallbacks, m);
+ scheduleAfterUiUpdateCallbacks();
const newTab = get_uiCurrentTab();
- if ( newTab && ( newTab !== uiCurrentTab ) ) {
+ if (newTab && (newTab !== uiCurrentTab)) {
uiCurrentTab = newTab;
executeCallbacks(uiTabChangeCallbacks);
}
});
- mutationObserver.observe( gradioApp(), { childList:true, subtree:true })
+ mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
});
/**
@@ -69,36 +126,38 @@ document.addEventListener("DOMContentLoaded", function() {
document.addEventListener('keydown', function(e) {
var handled = false;
if (e.key !== undefined) {
- if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
+ if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) {
- if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
+ if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
}
if (handled) {
- button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
+ var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
if (button) {
button.click();
}
e.preventDefault();
}
-})
+});
/**
* checks that a UI element is not in another hidden element or tab content
*/
function uiElementIsVisible(el) {
- let isVisible = !el.closest('.\\!hidden');
- if ( ! isVisible ) {
- return false;
+ if (el === document) {
+ return true;
}
- while( isVisible = el.closest('.tabitem')?.style.display !== 'none' ) {
- if ( ! isVisible ) {
- return false;
- } else if ( el.parentElement ) {
- el = el.parentElement
- } else {
- break;
- }
- }
- return isVisible;
+ const computedStyle = getComputedStyle(el);
+ const isVisible = computedStyle.display !== 'none';
+
+ if (!isVisible) return false;
+ return uiElementIsVisible(el.parentNode);
+}
+
+function uiElementInSight(el) {
+ const clRect = el.getBoundingClientRect();
+ const windowHeight = window.innerHeight;
+ const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight;
+
+ return isOnScreen;
}
diff --git a/scripts/custom_code.py b/scripts/custom_code.py
index f36a3675..cc6f0d49 100644
--- a/scripts/custom_code.py
+++ b/scripts/custom_code.py
@@ -4,7 +4,7 @@ import ast
import copy
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import cmd_opts
def convertExpr2Expression(expr):
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py
index bb00fb3f..1e833fa8 100644
--- a/scripts/img2imgalt.py
+++ b/scripts/img2imgalt.py
@@ -149,9 +149,9 @@ class Script(scripts.Script):
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
return [
- info,
+ info,
override_sampler,
- override_prompt, original_prompt, original_negative_prompt,
+ override_prompt, original_prompt, original_negative_prompt,
override_steps, st,
override_strength,
cfg, randomness, sigma_adjustment,
@@ -191,17 +191,17 @@ class Script(scripts.Script):
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
-
+
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
-
+
sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps)
-
+
noise_dt = combined_noise - (p.init_latent / sigmas[0])
-
+
p.seed = p.seed + 1
-
+
return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
p.sample = sample_extra
diff --git a/scripts/loopback.py b/scripts/loopback.py
index ad6609be..2d5feaf9 100644
--- a/scripts/loopback.py
+++ b/scripts/loopback.py
@@ -14,7 +14,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img
- def ui(self, is_img2img):
+ def ui(self, is_img2img):
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
@@ -104,7 +104,7 @@ class Script(scripts.Script):
p.seed = processed.seed + 1
p.denoising_strength = calculate_denoising_strength(i + 1)
-
+
if state.skipped:
break
@@ -121,7 +121,7 @@ class Script(scripts.Script):
all_images.append(last_image)
p.inpainting_fill = original_inpainting_fill
-
+
if state.interrupted:
break
@@ -132,7 +132,7 @@ class Script(scripts.Script):
if opts.return_grid:
grids.append(grid)
-
+
all_images = grids + all_images
processed = Processed(p, all_images, initial_seed, initial_info)
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py
index 670bb8ac..c98ab480 100644
--- a/scripts/outpainting_mk_2.py
+++ b/scripts/outpainting_mk_2.py
@@ -7,9 +7,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
@@ -72,7 +72,7 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
height = _np_src_image.shape[1]
num_channels = _np_src_image.shape[2]
- np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ _np_src_image[:] * (1. - np_mask_rgb)
np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)
img_mask = np_mask_grey > 1e-6
ref_mask = np_mask_grey < 1e-3
@@ -145,7 +145,6 @@ class Script(scripts.Script):
process_width = p.width
process_height = p.height
- p.mask_blur = mask_blur*4
p.inpaint_full_res = False
p.inpainting_fill = 1
p.do_not_save_samples = True
@@ -156,6 +155,19 @@ class Script(scripts.Script):
up = pixels if "up" in direction else 0
down = pixels if "down" in direction else 0
+ if left > 0 or right > 0:
+ mask_blur_x = mask_blur
+ else:
+ mask_blur_x = 0
+
+ if up > 0 or down > 0:
+ mask_blur_y = mask_blur
+ else:
+ mask_blur_y = 0
+
+ p.mask_blur_x = mask_blur_x*4
+ p.mask_blur_y = mask_blur_y*4
+
init_img = p.init_images[0]
target_w = math.ceil((init_img.width + left + right) / 64) * 64
target_h = math.ceil((init_img.height + up + down) / 64) * 64
@@ -191,10 +203,10 @@ class Script(scripts.Script):
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(mask)
draw.rectangle((
- expand_pixels + mask_blur if is_left else 0,
- expand_pixels + mask_blur if is_top else 0,
- mask.width - expand_pixels - mask_blur if is_right else res_w,
- mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+ expand_pixels + mask_blur_x if is_left else 0,
+ expand_pixels + mask_blur_y if is_top else 0,
+ mask.width - expand_pixels - mask_blur_x if is_right else res_w,
+ mask.height - expand_pixels - mask_blur_y if is_bottom else res_h,
), fill="black")
np_image = (np.asarray(img) / 255.0).astype(np.float64)
@@ -224,10 +236,10 @@ class Script(scripts.Script):
latent_mask = Image.new("RGB", (p.width, p.height), "white")
draw = ImageDraw.Draw(latent_mask)
draw.rectangle((
- expand_pixels + mask_blur * 2 if is_left else 0,
- expand_pixels + mask_blur * 2 if is_top else 0,
- mask.width - expand_pixels - mask_blur * 2 if is_right else res_w,
- mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h,
+ expand_pixels + mask_blur_x * 2 if is_left else 0,
+ expand_pixels + mask_blur_y * 2 if is_top else 0,
+ mask.width - expand_pixels - mask_blur_x * 2 if is_right else res_w,
+ mask.height - expand_pixels - mask_blur_y * 2 if is_bottom else res_h,
), fill="black")
p.latent_mask = latent_mask
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
index ddcbd2d3..ea0632b6 100644
--- a/scripts/poor_mans_outpainting.py
+++ b/scripts/poor_mans_outpainting.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
-from modules import images, processing, devices
+from modules import images, devices
from modules.processing import Processed, process_images
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
@@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
if not is_img2img:
return None
-
+
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
index ef1186ac..edb70ac0 100644
--- a/scripts/postprocessing_upscale.py
+++ b/scripts/postprocessing_upscale.py
@@ -98,13 +98,13 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
+ pp.info["Postprocess upscaler"] = upscaler1.name
if upscaler2 and upscaler_2_visibility > 0:
second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
- pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+ pp.info["Postprocess upscaler 2"] = upscaler2.name
pp.image = upscaled_image
@@ -134,4 +134,4 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
assert upscaler1, f'could not find upscaler named {upscaler_name}'
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
- pp.info[f"Postprocess upscaler"] = upscaler1.name
+ pp.info["Postprocess upscaler"] = upscaler1.name
diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py
index e9b11517..88324fe6 100644
--- a/scripts/prompt_matrix.py
+++ b/scripts/prompt_matrix.py
@@ -1,14 +1,11 @@
import math
-from collections import namedtuple
-from copy import copy
-import random
import modules.scripts as scripts
import gradio as gr
from modules import images
-from modules.processing import process_images, Processed
-from modules.shared import opts, cmd_opts, state
+from modules.processing import process_images
+from modules.shared import opts, state
import modules.sd_samplers
@@ -99,7 +96,7 @@ class Script(scripts.Script):
p.prompt_for_display = positive_prompt
processed = process_images(p)
- grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
+ grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid)
processed.index_of_first_image = 1
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py
index f168389c..50320d55 100644
--- a/scripts/prompts_from_file.py
+++ b/scripts/prompts_from_file.py
@@ -1,18 +1,13 @@
import copy
-import math
-import os
import random
-import sys
-import traceback
import shlex
import modules.scripts as scripts
import gradio as gr
-from modules import sd_samplers
+from modules import sd_samplers, errors
from modules.processing import Processed, process_images
-from PIL import Image
-from modules.shared import opts, cmd_opts, state
+from modules.shared import state
def process_string_tag(tag):
@@ -126,8 +121,7 @@ class Script(scripts.Script):
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
- lines = [x.strip() for x in prompt_txt.splitlines()]
- lines = [x for x in lines if len(x) > 0]
+ lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
p.do_not_save_grid = True
@@ -139,8 +133,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
- print(f"Error parsing line {line} as commandline:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ errors.report(f"Error parsing line {line} as commandline", exc_info=True)
args = {"prompt": line}
else:
args = {"prompt": line}
@@ -158,7 +151,7 @@ class Script(scripts.Script):
images = []
all_prompts = []
infotexts = []
- for n, args in enumerate(jobs):
+ for args in jobs:
state.job = f"{state.job_no + 1} out of {state.job_count}"
copy_p = copy.copy(p)
@@ -167,7 +160,7 @@ class Script(scripts.Script):
proc = process_images(copy_p)
images += proc.images
-
+
if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter)
all_prompts += proc.all_prompts
diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index 332d76d9..e614c23b 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -4,9 +4,9 @@ import modules.scripts as scripts
import gradio as gr
from PIL import Image
-from modules import processing, shared, sd_samplers, images, devices
+from modules import processing, shared, images, devices
from modules.processing import Processed
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
class Script(scripts.Script):
@@ -16,7 +16,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img
- def ui(self, is_img2img):
+ def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
@@ -56,7 +56,7 @@ class Script(scripts.Script):
work = []
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
work.append(tiledata[2])
@@ -85,7 +85,7 @@ class Script(scripts.Script):
work_results += processed.images
image_index = 0
- for y, h, row in grid.tiles:
+ for _y, _h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
image_index += 1
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index a725d74a..7821cc65 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -10,15 +10,13 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
-from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
-from modules.shared import opts, cmd_opts, state
+from modules.shared import opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
-import glob
-import os
import re
from modules.ui_components import ToolButton
@@ -86,7 +84,7 @@ def apply_checkpoint(p, x, xs):
info = modules.sd_models.get_closet_checkpoint_match(x)
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
- p.override_settings['sd_model_checkpoint'] = info.hash
+ p.override_settings['sd_model_checkpoint'] = info.name
def confirm_checkpoints(p, xs):
@@ -146,6 +144,11 @@ def apply_face_restore(p, opt, x):
p.restore_faces = is_active
+def apply_override(field):
+ def fun(p, x, xs):
+ p.override_settings[field] = x
+ return fun
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -217,6 +220,10 @@ axis_options = [
AxisOption("Sigma min", float, apply_field("s_tmin")),
AxisOption("Sigma max", float, apply_field("s_tmax")),
AxisOption("Sigma noise", float, apply_field("s_noise")),
+ AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)),
+ AxisOption("Schedule min sigma", float, apply_override("sigma_min")),
+ AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
+ AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip),
AxisOption("Denoising", float, apply_field("denoising_strength")),
@@ -226,6 +233,8 @@ axis_options = [
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
+ AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
+ AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
]
@@ -316,7 +325,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
return Processed(p, [])
z_count = len(zs)
- sub_grids = [None] * z_count
+
for i in range(z_count):
start_index = (i * len(xs) * len(ys)) + i
end_index = start_index + len(xs) * len(ys)
@@ -706,7 +715,7 @@ class Script(scripts.Script):
if not include_sub_grids:
# Done with sub-grids, drop all related information:
- for sg in range(z_count):
+ for _ in range(z_count):
del processed.images[1]
del processed.all_prompts[1]
del processed.all_seeds[1]
diff --git a/style.css b/style.css
index 31b2ed5a..5073f0f0 100644
--- a/style.css
+++ b/style.css
@@ -309,6 +309,11 @@ button.custom-button{
font-weight: bold;
}
+#txtimg_hr_finalres div.pending, #img2img_scale_resolution_preview div.pending {
+ opacity: 1;
+ transition: opacity 0s;
+}
+
.inactive{
opacity: 0.5;
}
@@ -320,20 +325,14 @@ button.custom-button{
div.dimensions-tools{
min-width: 0 !important;
max-width: fit-content;
- flex-direction: row;
- align-content: center;
+ flex-direction: column;
+ place-content: center;
}
div#extras_scale_to_tab div.form{
flex-direction: row;
}
-#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
- height: 480px !important;
- max-height: 480px !important;
- min-height: 480px !important;
-}
-
#img2img_sketch, #img2maskimg, #inpaint_sketch {
overflow: overlay !important;
resize: auto;
@@ -404,19 +403,66 @@ div#extras_scale_to_tab div.form{
margin: 0 1.2em;
}
-table.settings-value-table{
+table.popup-table{
background: white;
border-collapse: collapse;
margin: 1em;
border: 4px solid white;
}
-table.settings-value-table td{
+table.popup-table td{
padding: 0.4em;
border: 1px solid #ccc;
max-width: 36em;
}
+table.popup-table .muted{
+ color: #aaa;
+}
+
+table.popup-table .link{
+ text-decoration: underline;
+ cursor: pointer;
+ font-weight: bold;
+}
+
+.ui-defaults-none{
+ color: #aaa !important;
+}
+
+#settings span{
+ color: var(--body-text-color);
+}
+
+#settings .gradio-textbox, #settings .gradio-slider, #settings .gradio-number, #settings .gradio-dropdown, #settings .gradio-checkboxgroup, #settings .gradio-radio{
+ margin-top: 0.75em;
+}
+
+#settings span .settings-comment {
+ display: inline
+}
+
+.settings-comment a{
+ text-decoration: underline;
+}
+
+.settings-comment .info{
+ opacity: 0.75;
+}
+
+#sysinfo_download a.sysinfo_big_link{
+ font-size: 24pt;
+}
+
+#sysinfo_download a{
+ text-decoration: underline;
+}
+
+#sysinfo_validity{
+ font-size: 18pt;
+}
+
+
/* live preview */
.progressDiv{
position: relative;
@@ -658,11 +704,24 @@ table.settings-value-table td{
margin: 0;
}
-#available_extensions .date_added{
- opacity: 0.85;
+#available_extensions .info{
+ margin: 0.5em 0;
+ display: flex;
+ margin-top: auto;
+ opacity: 0.80;
font-size: 90%;
}
+#available_extensions .date_added{
+ margin-right: auto;
+ display: inline-block;
+}
+
+#available_extensions .star_count{
+ margin-left: auto;
+ display: inline-block;
+}
+
/* replace original footer with ours */
footer {
@@ -701,12 +760,22 @@ footer {
.extra-network-subdirs button{
margin: 0 0.15em;
}
-.extra-networks .tab-nav .search{
+.extra-networks .tab-nav .search,
+.extra-networks .tab-nav .sort,
+.extra-networks .tab-nav .sortorder{
display: inline-block;
- max-width: 16em;
margin: 0.3em;
align-self: center;
+}
+
+.extra-networks .tab-nav .search {
width: 16em;
+ max-width: 16em;
+}
+
+.extra-networks .tab-nav .sort {
+ width: 12em;
+ max-width: 12em;
}
#txt2img_extra_view, #img2img_extra_view {
@@ -733,13 +802,22 @@ footer {
.extra-network-cards .card .metadata-button, .extra-network-thumbs .card .metadata-button{
display: none;
position: absolute;
- right: 0;
color: white;
+ right: 0;
+}
+.extra-network-cards .card .metadata-button {
text-shadow: 2px 2px 3px black;
padding: 0.25em;
font-size: 22pt;
width: 1.5em;
}
+.extra-network-thumbs .card .metadata-button {
+ text-shadow: 1px 1px 2px black;
+ padding: 0;
+ font-size: 16pt;
+ width: 1em;
+ top: -0.25em;
+}
.extra-network-cards .card:hover .metadata-button, .extra-network-thumbs .card:hover .metadata-button{
display: inline-block;
}
@@ -764,6 +842,13 @@ footer {
position: relative;
}
+.extra-network-thumbs .card .preview{
+ position: absolute;
+ object-fit: cover;
+ width: 100%;
+ height:100%;
+}
+
.extra-network-thumbs .card:hover .additional a {
display: inline-block;
}
@@ -878,3 +963,10 @@ footer {
.extra-network-cards .card ul a:hover{
color: red;
}
+
+.extra-network-cards .card .preview{
+ position: absolute;
+ object-fit: cover;
+ width: 100%;
+ height:100%;
+}
diff --git a/test/basic_features/__init__.py b/test/basic_features/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/test/basic_features/__init__.py
+++ /dev/null
diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py
deleted file mode 100644
index 8ed98747..00000000
--- a/test/basic_features/extras_test.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import os
-import unittest
-import requests
-from gradio.processing_utils import encode_pil_to_base64
-from PIL import Image
-from modules.paths import script_path
-
-class TestExtrasWorking(unittest.TestCase):
- def setUp(self):
- self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image"
- self.extras_single = {
- "resize_mode": 0,
- "show_extras_results": True,
- "gfpgan_visibility": 0,
- "codeformer_visibility": 0,
- "codeformer_weight": 0,
- "upscaling_resize": 2,
- "upscaling_resize_w": 128,
- "upscaling_resize_h": 128,
- "upscaling_crop": True,
- "upscaler_1": "None",
- "upscaler_2": "None",
- "extras_upscaler_2_visibility": 0,
- "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
- }
-
- def test_simple_upscaling_performed(self):
- self.extras_single["upscaler_1"] = "Lanczos"
- self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200)
-
-
-class TestPngInfoWorking(unittest.TestCase):
- def setUp(self):
- self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image"
- self.png_info = {
- "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
- }
-
- def test_png_info_performed(self):
- self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200)
-
-
-class TestInterrogateWorking(unittest.TestCase):
- def setUp(self):
- self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image"
- self.interrogate = {
- "image": encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png"))),
- "model": "clip"
- }
-
- def test_interrogate_performed(self):
- self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py
deleted file mode 100644
index 5240ec36..00000000
--- a/test/basic_features/img2img_test.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os
-import unittest
-import requests
-from gradio.processing_utils import encode_pil_to_base64
-from PIL import Image
-from modules.paths import script_path
-
-
-class TestImg2ImgWorking(unittest.TestCase):
- def setUp(self):
- self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
- self.simple_img2img = {
- "init_images": [encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))],
- "resize_mode": 0,
- "denoising_strength": 0.75,
- "mask": None,
- "mask_blur": 4,
- "inpainting_fill": 0,
- "inpaint_full_res": False,
- "inpaint_full_res_padding": 0,
- "inpainting_mask_invert": False,
- "prompt": "example prompt",
- "styles": [],
- "seed": -1,
- "subseed": -1,
- "subseed_strength": 0,
- "seed_resize_from_h": -1,
- "seed_resize_from_w": -1,
- "batch_size": 1,
- "n_iter": 1,
- "steps": 3,
- "cfg_scale": 7,
- "width": 64,
- "height": 64,
- "restore_faces": False,
- "tiling": False,
- "negative_prompt": "",
- "eta": 0,
- "s_churn": 0,
- "s_tmax": 0,
- "s_tmin": 0,
- "s_noise": 1,
- "override_settings": {},
- "sampler_index": "Euler a",
- "include_init_images": False
- }
-
- def test_img2img_simple_performed(self):
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
- def test_inpainting_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
- def test_inpainting_with_inverted_masked_performed(self):
- self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(os.path.join(script_path, r"test/test_files/img2img_basic.png")))
- self.simple_img2img["inpainting_mask_invert"] = True
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
- def test_img2img_sd_upscale_performed(self):
- self.simple_img2img["script_name"] = "sd upscale"
- self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]
-
- self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py
deleted file mode 100644
index cb525fbb..00000000
--- a/test/basic_features/txt2img_test.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import unittest
-import requests
-
-
-class TestTxt2ImgWorking(unittest.TestCase):
- def setUp(self):
- self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
- self.simple_txt2img = {
- "enable_hr": False,
- "denoising_strength": 0,
- "firstphase_width": 0,
- "firstphase_height": 0,
- "prompt": "example prompt",
- "styles": [],
- "seed": -1,
- "subseed": -1,
- "subseed_strength": 0,
- "seed_resize_from_h": -1,
- "seed_resize_from_w": -1,
- "batch_size": 1,
- "n_iter": 1,
- "steps": 3,
- "cfg_scale": 7,
- "width": 64,
- "height": 64,
- "restore_faces": False,
- "tiling": False,
- "negative_prompt": "",
- "eta": 0,
- "s_churn": 0,
- "s_tmax": 0,
- "s_tmin": 0,
- "s_noise": 1,
- "sampler_index": "Euler a"
- }
-
- def test_txt2img_simple_performed(self):
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_negative_prompt_performed(self):
- self.simple_txt2img["negative_prompt"] = "example negative prompt"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_complex_prompt_performed(self):
- self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_not_square_image_performed(self):
- self.simple_txt2img["height"] = 128
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_hrfix_performed(self):
- self.simple_txt2img["enable_hr"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_tiling_performed(self):
- self.simple_txt2img["tiling"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_restore_faces_performed(self):
- self.simple_txt2img["restore_faces"] = True
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_with_vanilla_sampler_performed(self):
- self.simple_txt2img["sampler_index"] = "PLMS"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
- self.simple_txt2img["sampler_index"] = "DDIM"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
- self.simple_txt2img["sampler_index"] = "UniPC"
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_multiple_batches_performed(self):
- self.simple_txt2img["n_iter"] = 2
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
- def test_txt2img_batch_performed(self):
- self.simple_txt2img["batch_size"] = 2
- self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py
deleted file mode 100644
index 0bfc28a0..00000000
--- a/test/basic_features/utils_test.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import unittest
-import requests
-
-class UtilsTests(unittest.TestCase):
- def setUp(self):
- self.url_options = "http://localhost:7860/sdapi/v1/options"
- self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
- self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
- self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
- self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
- self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
- self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
- self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
- self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
- self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
-
- def test_options_get(self):
- self.assertEqual(requests.get(self.url_options).status_code, 200)
-
- def test_options_write(self):
- response = requests.get(self.url_options)
- self.assertEqual(response.status_code, 200)
-
- pre_value = response.json()["send_seed"]
-
- self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
-
- response = requests.get(self.url_options)
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json()["send_seed"], not pre_value)
-
- requests.post(self.url_options, json={"send_seed": pre_value})
-
- def test_cmd_flags(self):
- self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
-
- def test_samplers(self):
- self.assertEqual(requests.get(self.url_samplers).status_code, 200)
-
- def test_upscalers(self):
- self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
-
- def test_sd_models(self):
- self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
-
- def test_hypernetworks(self):
- self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
-
- def test_face_restorers(self):
- self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
-
- def test_realesrgan_models(self):
- self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
-
- def test_prompt_styles(self):
- self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
-
- def test_embeddings(self):
- self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 00000000..0723f62a
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,17 @@
+import os
+
+import pytest
+from PIL import Image
+from gradio.processing_utils import encode_pil_to_base64
+
+test_files_path = os.path.dirname(__file__) + "/test_files"
+
+
+@pytest.fixture(scope="session") # session so we don't read this over and over
+def img2img_basic_image_base64() -> str:
+ return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png")))
+
+
+@pytest.fixture(scope="session") # session so we don't read this over and over
+def mask_basic_image_base64() -> str:
+ return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "mask_basic.png")))
diff --git a/test/server_poll.py b/test/server_poll.py
deleted file mode 100644
index c732630f..00000000
--- a/test/server_poll.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import unittest
-import requests
-import time
-import os
-from modules.paths import script_path
-
-
-def run_tests(proc, test_dir):
- timeout_threshold = 240
- start_time = time.time()
- while time.time()-start_time < timeout_threshold:
- try:
- requests.head("http://localhost:7860/")
- break
- except requests.exceptions.ConnectionError:
- if proc.poll() is not None:
- break
- if proc.poll() is None:
- if test_dir is None:
- test_dir = os.path.join(script_path, "test")
- suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir=test_dir)
- result = unittest.TextTestRunner(verbosity=2).run(suite)
- return len(result.failures) + len(result.errors)
- else:
- print("Launch unsuccessful")
- return 1
diff --git a/test/test_extras.py b/test/test_extras.py
new file mode 100644
index 00000000..799d9fad
--- /dev/null
+++ b/test/test_extras.py
@@ -0,0 +1,35 @@
+import requests
+
+
+def test_simple_upscaling_performed(base_url, img2img_basic_image_base64):
+ payload = {
+ "resize_mode": 0,
+ "show_extras_results": True,
+ "gfpgan_visibility": 0,
+ "codeformer_visibility": 0,
+ "codeformer_weight": 0,
+ "upscaling_resize": 2,
+ "upscaling_resize_w": 128,
+ "upscaling_resize_h": 128,
+ "upscaling_crop": True,
+ "upscaler_1": "Lanczos",
+ "upscaler_2": "None",
+ "extras_upscaler_2_visibility": 0,
+ "image": img2img_basic_image_base64,
+ }
+ assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
+
+
+def test_png_info_performed(base_url, img2img_basic_image_base64):
+ payload = {
+ "image": img2img_basic_image_base64,
+ }
+ assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
+
+
+def test_interrogate_performed(base_url, img2img_basic_image_base64):
+ payload = {
+ "image": img2img_basic_image_base64,
+ "model": "clip",
+ }
+ assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200
diff --git a/test/test_img2img.py b/test/test_img2img.py
new file mode 100644
index 00000000..117d2d1e
--- /dev/null
+++ b/test/test_img2img.py
@@ -0,0 +1,68 @@
+
+import pytest
+import requests
+
+
+@pytest.fixture()
+def url_img2img(base_url):
+ return f"{base_url}/sdapi/v1/img2img"
+
+
+@pytest.fixture()
+def simple_img2img_request(img2img_basic_image_base64):
+ return {
+ "batch_size": 1,
+ "cfg_scale": 7,
+ "denoising_strength": 0.75,
+ "eta": 0,
+ "height": 64,
+ "include_init_images": False,
+ "init_images": [img2img_basic_image_base64],
+ "inpaint_full_res": False,
+ "inpaint_full_res_padding": 0,
+ "inpainting_fill": 0,
+ "inpainting_mask_invert": False,
+ "mask": None,
+ "mask_blur": 4,
+ "n_iter": 1,
+ "negative_prompt": "",
+ "override_settings": {},
+ "prompt": "example prompt",
+ "resize_mode": 0,
+ "restore_faces": False,
+ "s_churn": 0,
+ "s_noise": 1,
+ "s_tmax": 0,
+ "s_tmin": 0,
+ "sampler_index": "Euler a",
+ "seed": -1,
+ "seed_resize_from_h": -1,
+ "seed_resize_from_w": -1,
+ "steps": 3,
+ "styles": [],
+ "subseed": -1,
+ "subseed_strength": 0,
+ "tiling": False,
+ "width": 64,
+ }
+
+
+def test_img2img_simple_performed(url_img2img, simple_img2img_request):
+ assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
+
+
+def test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):
+ simple_img2img_request["mask"] = mask_basic_image_base64
+ assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
+
+
+def test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):
+ simple_img2img_request["mask"] = mask_basic_image_base64
+ simple_img2img_request["inpainting_mask_invert"] = True
+ assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
+
+
+def test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request):
+ simple_img2img_request["script_name"] = "sd upscale"
+ simple_img2img_request["script_args"] = ["", 8, "Lanczos", 2.0]
+ assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200
diff --git a/test/test_txt2img.py b/test/test_txt2img.py
new file mode 100644
index 00000000..6eb94f0a
--- /dev/null
+++ b/test/test_txt2img.py
@@ -0,0 +1,90 @@
+
+import pytest
+import requests
+
+
+@pytest.fixture()
+def url_txt2img(base_url):
+ return f"{base_url}/sdapi/v1/txt2img"
+
+
+@pytest.fixture()
+def simple_txt2img_request():
+ return {
+ "batch_size": 1,
+ "cfg_scale": 7,
+ "denoising_strength": 0,
+ "enable_hr": False,
+ "eta": 0,
+ "firstphase_height": 0,
+ "firstphase_width": 0,
+ "height": 64,
+ "n_iter": 1,
+ "negative_prompt": "",
+ "prompt": "example prompt",
+ "restore_faces": False,
+ "s_churn": 0,
+ "s_noise": 1,
+ "s_tmax": 0,
+ "s_tmin": 0,
+ "sampler_index": "Euler a",
+ "seed": -1,
+ "seed_resize_from_h": -1,
+ "seed_resize_from_w": -1,
+ "steps": 3,
+ "styles": [],
+ "subseed": -1,
+ "subseed_strength": 0,
+ "tiling": False,
+ "width": 64,
+ }
+
+
+def test_txt2img_simple_performed(url_txt2img, simple_txt2img_request):
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_with_negative_prompt_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["negative_prompt"] = "example negative prompt"
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_with_complex_prompt_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]"
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_not_square_image_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["height"] = 128
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_with_hrfix_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["enable_hr"] = True
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_with_tiling_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["tiling"] = True
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_with_restore_faces_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["restore_faces"] = True
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+@pytest.mark.parametrize("sampler", ["PLMS", "DDIM", "UniPC"])
+def test_txt2img_with_vanilla_sampler_performed(url_txt2img, simple_txt2img_request, sampler):
+ simple_txt2img_request["sampler_index"] = sampler
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_multiple_batches_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["n_iter"] = 2
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
+
+
+def test_txt2img_batch_performed(url_txt2img, simple_txt2img_request):
+ simple_txt2img_request["batch_size"] = 2
+ assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200
diff --git a/test/test_utils.py b/test/test_utils.py
new file mode 100644
index 00000000..edba0b18
--- /dev/null
+++ b/test/test_utils.py
@@ -0,0 +1,33 @@
+import pytest
+import requests
+
+
+def test_options_write(base_url):
+ url_options = f"{base_url}/sdapi/v1/options"
+ response = requests.get(url_options)
+ assert response.status_code == 200
+
+ pre_value = response.json()["send_seed"]
+
+ assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200
+
+ response = requests.get(url_options)
+ assert response.status_code == 200
+ assert response.json()['send_seed'] == (not pre_value)
+
+ requests.post(url_options, json={"send_seed": pre_value})
+
+
+@pytest.mark.parametrize("url", [
+ "sdapi/v1/cmd-flags",
+ "sdapi/v1/samplers",
+ "sdapi/v1/upscalers",
+ "sdapi/v1/sd-models",
+ "sdapi/v1/hypernetworks",
+ "sdapi/v1/face-restorers",
+ "sdapi/v1/realesrgan-models",
+ "sdapi/v1/prompt-styles",
+ "sdapi/v1/embeddings",
+])
+def test_get_api_url(base_url, url):
+ assert requests.get(f"{base_url}/{url}").status_code == 200
diff --git a/webui-macos-env.sh b/webui-macos-env.sh
index 10ab81c9..6354e73b 100644
--- a/webui-macos-env.sh
+++ b/webui-macos-env.sh
@@ -11,7 +11,7 @@ fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
-export TORCH_COMMAND="pip install torch torchvision"
+export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1
diff --git a/webui-user.sh b/webui-user.sh
index 49a426ff..70306c60 100644
--- a/webui-user.sh
+++ b/webui-user.sh
@@ -36,7 +36,6 @@
# Fixed git commits
#export STABLE_DIFFUSION_COMMIT_HASH=""
-#export TAMING_TRANSFORMERS_COMMIT_HASH=""
#export CODEFORMER_COMMIT_HASH=""
#export BLIP_COMMIT_HASH=""
diff --git a/webui.bat b/webui.bat
index 209d972b..42e7d517 100644
--- a/webui.bat
+++ b/webui.bat
@@ -3,7 +3,7 @@
if not defined PYTHON (set PYTHON=python)
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
-
+set SD_WEBUI_RESTART=tmp/restart
set ERROR_REPORTING=FALSE
mkdir tmp 2>NUL
@@ -51,12 +51,14 @@ if EXIST %ACCELERATE% goto :accelerate_launch
:launch
%PYTHON% launch.py %*
+if EXIST tmp/restart goto :skip_venv
pause
exit /b
:accelerate_launch
echo Accelerating
%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py
+if EXIST tmp/restart goto :skip_venv
pause
exit /b
diff --git a/webui.py b/webui.py
index e8f0a63d..34c2fd18 100644
--- a/webui.py
+++ b/webui.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import sys
import time
@@ -7,6 +9,7 @@ import re
import warnings
import json
from threading import Thread
+from typing import Iterable
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -14,36 +17,47 @@ from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging
+
+# We can't use cmd_opts for this because it will not have been initialized at this point.
+log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
+if log_level:
+ log_level = getattr(logging, log_level.upper(), None) or logging.INFO
+ logging.basicConfig(
+ level=log_level,
+ format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ )
+
+logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
-from modules import paths, timer, import_hook, errors
+from modules import paths, timer, import_hook, errors, devices # noqa: F401
-startup_timer = timer.Timer()
+startup_timer = timer.startup_timer
import torch
-import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
+import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")
-import gradio
+import gradio # noqa: F401
startup_timer.record("import gradio")
-import ldm.modules.encoders.modules
+import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
-from modules import extra_networks, ui_extra_networks_checkpoints
-from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules import extra_networks
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -52,8 +66,10 @@ import modules.img2img
import modules.lowvram
import modules.scripts
import modules.sd_hijack
+import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
+import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
@@ -130,7 +146,7 @@ there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
- expected_xformers_version = "0.0.17"
+ expected_xformers_version = "0.0.20"
if shared.xformers_available:
import xformers
@@ -144,16 +160,11 @@ Use --skip-version-check commandline argument to disable this check.
""".strip())
-def initialize():
- fix_asyncio_event_loop_policy()
-
- check_versions()
-
- extensions.list_extensions()
- localization.list_localizations(cmd_opts.localizations_dir)
- startup_timer.record("list extensions")
-
+def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
+ if config_state_file == "":
+ return
+
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
@@ -166,14 +177,82 @@ def initialize():
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
- if cmd_opts.ui_debug_mode:
- shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
- modules.scripts.load_scripts()
+
+def validate_tls_options():
+ if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
return
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
+ startup_timer.record("TLS")
+
+
+def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
+ """
+ Convert the gradio_auth and gradio_auth_path commandline arguments into
+ an iterable of (username, password) tuples.
+ """
+ def process_credential_line(s) -> tuple[str, ...] | None:
+ s = s.strip()
+ if not s:
+ return None
+ return tuple(s.split(':', 1))
+
+ if cmd_opts.gradio_auth:
+ for cred in cmd_opts.gradio_auth.split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+ if cmd_opts.gradio_auth_path:
+ with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ for cred in line.strip().split(','):
+ cred = process_credential_line(cred)
+ if cred:
+ yield cred
+
+
+def configure_sigint_handler():
+ # make the program just exit at ctrl+c without waiting for anything
+ def sigint_handler(sig, frame):
+ print(f'Interrupted with signal {sig} in {frame}')
+ os._exit(0)
+
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
+
+
+def configure_opts_onchange():
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+ shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
+ shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+ startup_timer.record("opts onchange")
+
+
+def initialize():
+ fix_asyncio_event_loop_policy()
+ validate_tls_options()
+ configure_sigint_handler()
+ check_versions()
modelloader.cleanup_models()
+ configure_opts_onchange()
+
modules.sd_models.setup_model()
- startup_timer.record("list SD models")
+ startup_timer.record("setup SD model")
codeformer.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
@@ -181,72 +260,98 @@ def initialize():
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
- modules.scripts.load_scripts()
- startup_timer.record("load scripts")
+ initialize_rest(reload_script_modules=False)
+
+
+def initialize_rest(*, reload_script_modules=False):
+ """
+ Called both from initialize() and when reloading the webui.
+ """
+ sd_samplers.set_samplers()
+ extensions.list_extensions()
+ startup_timer.record("list extensions")
+
+ restore_config_state_file()
+
+ if cmd_opts.ui_debug_mode:
+ shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
+ modules.scripts.load_scripts()
+ return
+
+ modules.sd_models.list_models()
+ startup_timer.record("list SD models")
+
+ localization.list_localizations(cmd_opts.localizations_dir)
+
+ with startup_timer.subcategory("load scripts"):
+ modules.scripts.load_scripts()
+
+ if reload_script_modules:
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
+ startup_timer.record("reload script modules")
modelloader.load_upscalers()
- startup_timer.record("load upscalers") #Is this necessary? I don't know.
+ startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
-
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
- # load model in parallel to other startup stuff
- Thread(target=lambda: shared.sd_model).start()
+ modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
+ modules.sd_hijack.list_optimizers()
+ startup_timer.record("scripts list_optimizers")
- shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
- shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
- shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
- startup_timer.record("opts onchange")
+ modules.sd_unet.list_unets()
+ startup_timer.record("scripts list_unets")
- shared.reload_hypernetworks()
- startup_timer.record("reload hypernets")
+ def load_model():
+ """
+ Accesses shared.sd_model property to load model.
+ After it's available, if it has been loaded before this access by some extension,
+ its optimization may be None because the list of optimizaers has neet been filled
+ by that time, so we apply optimization again.
+ """
- ui_extra_networks.intialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
+ shared.sd_model # noqa: B018
- extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
- startup_timer.record("extra networks")
+ if modules.sd_hijack.current_optimizer is None:
+ modules.sd_hijack.apply_optimizations()
- if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
+ Thread(target=load_model).start()
- try:
- if not os.path.exists(cmd_opts.tls_keyfile):
- print("Invalid path to TLS keyfile given")
- if not os.path.exists(cmd_opts.tls_certfile):
- print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
- except TypeError:
- cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
- print("TLS setup invalid, running webui without TLS")
- else:
- print("Running with TLS")
- startup_timer.record("TLS")
+ Thread(target=devices.first_time_calculation).start()
- # make the program just exit at ctrl+c without waiting for anything
- def sigint_handler(sig, frame):
- print(f'Interrupted with signal {sig} in {frame}')
- os._exit(0)
+ shared.reload_hypernetworks()
+ startup_timer.record("reload hypernetworks")
+
+ ui_extra_networks.initialize()
+ ui_extra_networks.register_default_pages()
- signal.signal(signal.SIGINT, sigint_handler)
+ extra_networks.initialize()
+ extra_networks.register_default_extra_networks()
+ startup_timer.record("initialize extra networks")
def setup_middleware(app):
- app.middleware_stack = None # reset current middleware to allow modifying user provided list
+ app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
- if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins:
- app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- elif cmd_opts.cors_allow_origins_regex:
- app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
- app.build_middleware_stack() # rebuild middleware stack on-the-fly
+ configure_cors_middleware(app)
+ app.build_middleware_stack() # rebuild middleware stack on-the-fly
+
+
+def configure_cors_middleware(app):
+ cors_options = {
+ "allow_methods": ["*"],
+ "allow_headers": ["*"],
+ "allow_credentials": True,
+ }
+ if cmd_opts.cors_allow_origins:
+ cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
+ if cmd_opts.cors_allow_origins_regex:
+ cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
+ app.add_middleware(CORSMiddleware, **cors_options)
def create_api(app):
@@ -255,19 +360,6 @@ def create_api(app):
return api
-def wait_on_server(demo=None):
- while 1:
- time.sleep(0.5)
- if shared.state.need_restart:
- shared.state.need_restart = False
- time.sleep(0.5)
- demo.close()
- time.sleep(0.5)
-
- modules.script_callbacks.app_reload_callback()
- break
-
-
def api_only():
initialize()
@@ -278,7 +370,12 @@ def api_only():
modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
- api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
+ api.launch(
+ server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
+ port=cmd_opts.port if cmd_opts.port else 7861,
+ root_path = f"/{cmd_opts.subpath}"
+ )
+
def webui():
launch_api = cmd_opts.api
@@ -298,23 +395,7 @@ def webui():
if not cmd_opts.no_gradio_queue:
shared.demo.queue(64)
- gradio_auth_creds = []
- if cmd_opts.gradio_auth:
- gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
- if cmd_opts.gradio_auth_path:
- with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
- for line in file.readlines():
- gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
-
- # this restores the missing /docs endpoint
- if launch_api and not hasattr(FastAPI, 'original_setup'):
- def fastapi_setup(self):
- self.docs_url = "/docs"
- self.redoc_url = "/redoc"
- self.original_setup()
-
- FastAPI.original_setup = FastAPI.setup
- FastAPI.setup = fastapi_setup
+ gradio_auth_creds = list(get_gradio_auth_creds()) or None
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
@@ -324,10 +405,17 @@ def webui():
ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
- auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
- inbrowser=cmd_opts.autolaunch,
- prevent_thread_lock=True
+ auth=gradio_auth_creds,
+ inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1',
+ prevent_thread_lock=True,
+ allowed_paths=cmd_opts.gradio_allowed_path,
+ app_kwargs={
+ "docs_url": "/docs",
+ "redoc_url": "/redoc",
+ },
+ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
)
+
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -349,69 +437,41 @@ def webui():
ui_extra_networks.add_pages_to_demo(app)
- modules.script_callbacks.app_started_callback(shared.demo, app)
- startup_timer.record("scripts app_started_callback")
+ startup_timer.record("add APIs")
+ with startup_timer.subcategory("app_started_callback"):
+ modules.script_callbacks.app_started_callback(shared.demo, app)
+
+ timer.startup_record = startup_timer.dump()
print(f"Startup time: {startup_timer.summary()}.")
- if cmd_opts.subpath:
- redirector = FastAPI()
- redirector.get("/")
- mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
+ try:
+ while True:
+ server_command = shared.state.wait_for_server_command(timeout=5)
+ if server_command:
+ if server_command in ("stop", "restart"):
+ break
+ else:
+ print(f"Unknown server command: {server_command}")
+ except KeyboardInterrupt:
+ print('Caught KeyboardInterrupt, stopping...')
+ server_command = "stop"
+
+ if server_command == "stop":
+ print("Stopping server...")
+ # If we catch a keyboard interrupt, we want to stop the server and exit.
+ shared.demo.close()
+ break
- wait_on_server(shared.demo)
print('Restarting UI...')
-
+ shared.demo.close()
+ time.sleep(0.5)
startup_timer.reset()
-
- sd_samplers.set_samplers()
-
+ modules.script_callbacks.app_reload_callback()
+ startup_timer.record("app reload callback")
modules.script_callbacks.script_unloaded_callback()
- extensions.list_extensions()
- startup_timer.record("list extensions")
-
- config_state_file = shared.opts.restore_config_state_file
- shared.opts.restore_config_state_file = ""
- shared.opts.save(shared.config_filename)
-
- if os.path.isfile(config_state_file):
- print(f"*** About to restore extension state from file: {config_state_file}")
- with open(config_state_file, "r", encoding="utf-8") as f:
- config_state = json.load(f)
- config_states.restore_extension_config(config_state)
- startup_timer.record("restore extension config")
- elif config_state_file:
- print(f"!!! Config state backup not found: {config_state_file}")
-
- localization.list_localizations(cmd_opts.localizations_dir)
-
- modules.scripts.reload_scripts()
- startup_timer.record("load scripts")
-
- modules.script_callbacks.model_loaded_callback(shared.sd_model)
- startup_timer.record("model loaded callback")
-
- modelloader.load_upscalers()
- startup_timer.record("load upscalers")
-
- for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
- importlib.reload(module)
- startup_timer.record("reload script modules")
-
- modules.sd_models.list_models()
- startup_timer.record("list SD models")
-
- shared.reload_hypernetworks()
- startup_timer.record("reload hypernetworks")
-
- ui_extra_networks.intialize()
- ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
- ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
- ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
-
- extra_networks.initialize()
- extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
- startup_timer.record("initialize extra networks")
+ startup_timer.record("scripts unloaded callback")
+ initialize_rest(reload_script_modules=True)
if __name__ == "__main__":
diff --git a/webui.sh b/webui.sh
index 19cf2f78..246381fc 100755
--- a/webui.sh
+++ b/webui.sh
@@ -4,26 +4,28 @@
# change the variables in webui-user.sh instead #
#################################################
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
# If run from macOS, load defaults from webui-macos-env.sh
if [[ "$OSTYPE" == "darwin"* ]]; then
- if [[ -f webui-macos-env.sh ]]
+ if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
then
- source ./webui-macos-env.sh
+ source "$SCRIPT_DIR"/webui-macos-env.sh
fi
fi
# Read variables from webui-user.sh
# shellcheck source=/dev/null
-if [[ -f webui-user.sh ]]
+if [[ -f "$SCRIPT_DIR"/webui-user.sh ]]
then
- source ./webui-user.sh
+ source "$SCRIPT_DIR"/webui-user.sh
fi
# Set defaults
# Install directory without trailing slash
if [[ -z "${install_dir}" ]]
then
- install_dir="$(pwd)"
+ install_dir="$SCRIPT_DIR"
fi
# Name of the subdirectory (defaults to stable-diffusion-webui)
@@ -94,6 +96,14 @@ else
printf "\n%s\n" "${delimiter}"
fi
+if [[ $(getconf LONG_BIT) = 32 ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Unsupported Running on a 32bit OS\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
if [[ -d .git ]]
then
printf "\n%s\n" "${delimiter}"
@@ -104,9 +114,24 @@ then
fi
# Check prerequisites
-gpu_info=$(lspci 2>/dev/null | grep VGA)
+gpu_info=$(lspci 2>/dev/null | grep -E "VGA|Display")
case "$gpu_info" in
- *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ *"Navi 1"*)
+ export HSA_OVERRIDE_GFX_VERSION=10.3.0
+ if [[ -z "${TORCH_COMMAND}" ]]
+ then
+ pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')"
+ if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]]
+ then
+ # Navi users will still use torch 1.13 because 2.0 does not seem to work.
+ export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
+ else
+ printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m"
+ exit 1
+ fi
+ fi
+ ;;
+ *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}"
@@ -116,11 +141,13 @@ case "$gpu_info" in
*)
;;
esac
-if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+if ! echo "$gpu_info" | grep -q "NVIDIA";
then
- # AMD users will still use torch 1.13 because 2.0 does not seem to work.
- export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
-fi
+ if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
+ then
+ export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
+ fi
+fi
for preq in "${GIT}" "${python_cmd}"
do
@@ -183,7 +210,7 @@ fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
- TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
+ TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"
@@ -193,17 +220,24 @@ prepare_tcmalloc() {
fi
}
-if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
-then
- printf "\n%s\n" "${delimiter}"
- printf "Accelerating launch.py..."
- printf "\n%s\n" "${delimiter}"
- prepare_tcmalloc
- exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
-else
- printf "\n%s\n" "${delimiter}"
- printf "Launching launch.py..."
- printf "\n%s\n" "${delimiter}"
- prepare_tcmalloc
- exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
-fi
+KEEP_GOING=1
+export SD_WEBUI_RESTART=tmp/restart
+while [[ "$KEEP_GOING" -eq "1" ]]; do
+ if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]; then
+ printf "\n%s\n" "${delimiter}"
+ printf "Accelerating launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ prepare_tcmalloc
+ accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
+ else
+ printf "\n%s\n" "${delimiter}"
+ printf "Launching launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ prepare_tcmalloc
+ "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+ fi
+
+ if [[ ! -f tmp/restart ]]; then
+ KEEP_GOING=0
+ fi
+done