diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 31 | ||||
-rw-r--r-- | modules/config_states.py | 16 | ||||
-rw-r--r-- | modules/progress.py | 3 | ||||
-rw-r--r-- | modules/sd_models.py | 13 | ||||
-rw-r--r-- | modules/sd_vae.py | 4 | ||||
-rw-r--r-- | modules/shared_options.py | 6 | ||||
-rw-r--r-- | modules/ui_extensions.py | 226 |
7 files changed, 178 insertions, 121 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 6e8d21a3..42fbbe3d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,6 +4,8 @@ import os import time import datetime import uvicorn +import ipaddress +import requests import gradio as gr from threading import Lock from io import BytesIO @@ -55,10 +57,35 @@ def setUpscalers(req: dict): return reqDict +def verify_url(url): + """Returns True if the url refers to a global resource.""" + + import socket + from urllib.parse import urlparse + try: + parsed_url = urlparse(url) + domain_name = parsed_url.netloc + host = socket.gethostbyname_ex(domain_name) + for ip in host[2]: + ip_addr = ipaddress.ip_address(ip) + if not ip_addr.is_global: + return False + except Exception: + return False + + return True + + def decode_base64_to_image(encoding): if encoding.startswith("http://") or encoding.startswith("https://"): - import requests - response = requests.get(encoding, timeout=30, headers={'user-agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/62.0.3202.94 Safari/537.36'}) + if not opts.api_enable_requests: + raise HTTPException(status_code=500, detail="Requests not allowed") + + if opts.api_forbid_local_requests and not verify_url(encoding): + raise HTTPException(status_code=500, detail="Request to local resource not allowed") + + headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} + response = requests.get(encoding, timeout=30, headers=headers) try: image = Image.open(BytesIO(response.content)) return image diff --git a/modules/config_states.py b/modules/config_states.py index 6f1ab53f..b766aef1 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -8,14 +8,12 @@ import time import tqdm from datetime import datetime -from collections import OrderedDict import git from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir - -all_config_states = OrderedDict() +all_config_states = {} def list_config_states(): @@ -28,10 +26,14 @@ def list_config_states(): for filename in os.listdir(config_states_dir): if filename.endswith(".json"): path = os.path.join(config_states_dir, filename) - with open(path, "r", encoding="utf-8") as f: - j = json.load(f) - j["filepath"] = path - config_states.append(j) + try: + with open(path, "r", encoding="utf-8") as f: + j = json.load(f) + assert "created_at" in j, '"created_at" does not exist' + j["filepath"] = path + config_states.append(j) + except Exception as e: + print(f'[ERROR]: Config states {path}, {e}') config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True) diff --git a/modules/progress.py b/modules/progress.py index f405f07f..e32b59dd 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -48,6 +48,7 @@ def add_task_to_queue(id_job): class ProgressRequest(BaseModel):
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
+ live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
class ProgressResponse(BaseModel):
@@ -91,7 +92,7 @@ def progressapi(req: ProgressRequest): id_live_preview = req.id_live_preview
shared.state.set_current_image()
- if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+ if opts.live_previews_enable and req.live_preview and shared.state.id_live_preview != req.id_live_preview:
image = shared.state.current_image
if image is not None:
buffered = io.BytesIO()
diff --git a/modules/sd_models.py b/modules/sd_models.py index 685585b1..27d15e66 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -485,8 +485,12 @@ class SdModelData: return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -660,13 +664,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
+ model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -678,6 +683,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index dbade067..ee118656 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -192,7 +192,7 @@ def load_vae_dict(filename, map_location): def load_vae(model, vae_file=None, vae_source="from unknown source"): - global vae_dict, loaded_vae_file + global vae_dict, base_vae, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -230,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): restore_base_vae(model) loaded_vae_file = vae_file + model.base_vae = base_vae + model.loaded_vae_file = loaded_vae_file # don't call this from outside diff --git a/modules/shared_options.py b/modules/shared_options.py index 8630d474..5f30e8e9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -111,6 +111,12 @@ options_templates.update(options_section(('system', "System"), { "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
+options_templates.update(options_section(('API', "API"), {
+ "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API"),
+ "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources"),
+ "api_useragent": OptionInfo("", "User agent for requests"),
+}))
+
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 15a8b0bf..e0138267 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -65,7 +65,7 @@ def save_config_state(name): filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
- json.dump(current_config_state, f)
+ json.dump(current_config_state, f, indent=4)
config_states.list_config_states()
new_value = next(iter(config_states.all_config_states.keys()), "Current")
new_choices = ["Current"] + list(config_states.all_config_states.keys())
@@ -200,119 +200,129 @@ def update_config_states_table(state_name): created_date = time.asctime(time.gmtime(config_state["created_at"]))
filepath = config_state.get("filepath", "<unknown>")
- code = f"""<!-- {time.time()} -->"""
-
- webui_remote = config_state["webui"]["remote"] or ""
- webui_branch = config_state["webui"]["branch"]
- webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
- webui_commit_date = config_state["webui"]["commit_date"]
- if webui_commit_date:
- webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
- else:
- webui_commit_date = "<unknown>"
-
- remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
- commit_link = make_commit_link(webui_commit_hash, webui_remote)
- date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
-
- current_webui = config_states.get_webui_config()
-
- style_remote = ""
- style_branch = ""
- style_commit = ""
- if current_webui["remote"] != webui_remote:
- style_remote = STYLE_PRIMARY
- if current_webui["branch"] != webui_branch:
- style_branch = STYLE_PRIMARY
- if current_webui["commit_hash"] != webui_commit_hash:
- style_commit = STYLE_PRIMARY
-
- code += f"""<h2>Config Backup: {config_name}</h2>
- <div><b>Filepath:</b> {filepath}</div>
- <div><b>Created at:</b> {created_date}</div>"""
-
- code += f"""<h2>WebUI State</h2>
- <table id="config_state_webui">
- <thead>
- <tr>
- <th>URL</th>
- <th>Branch</th>
- <th>Commit</th>
- <th>Date</th>
- </tr>
- </thead>
- <tbody>
- <tr>
- <td><label{style_remote}>{remote}</label></td>
- <td><label{style_branch}>{webui_branch}</label></td>
- <td><label{style_commit}>{commit_link}</label></td>
- <td><label{style_commit}>{date_link}</label></td>
- </tr>
- </tbody>
- </table>
- """
-
- code += """<h2>Extension State</h2>
- <table id="config_state_extensions">
- <thead>
- <tr>
- <th>Extension</th>
- <th>URL</th>
- <th>Branch</th>
- <th>Commit</th>
- <th>Date</th>
- </tr>
- </thead>
- <tbody>
- """
-
- ext_map = {ext.name: ext for ext in extensions.extensions}
-
- for ext_name, ext_conf in config_state["extensions"].items():
- ext_remote = ext_conf["remote"] or ""
- ext_branch = ext_conf["branch"] or "<unknown>"
- ext_enabled = ext_conf["enabled"]
- ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
- ext_commit_date = ext_conf["commit_date"]
- if ext_commit_date:
- ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ try:
+ webui_remote = config_state["webui"]["remote"] or ""
+ webui_branch = config_state["webui"]["branch"]
+ webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
+ webui_commit_date = config_state["webui"]["commit_date"]
+ if webui_commit_date:
+ webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
else:
- ext_commit_date = "<unknown>"
+ webui_commit_date = "<unknown>"
- remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
- commit_link = make_commit_link(ext_commit_hash, ext_remote)
- date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+ remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
+ commit_link = make_commit_link(webui_commit_hash, webui_remote)
+ date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
+
+ current_webui = config_states.get_webui_config()
- style_enabled = ""
style_remote = ""
style_branch = ""
style_commit = ""
- if ext_name in ext_map:
- current_ext = ext_map[ext_name]
- current_ext.read_info_from_repo()
- if current_ext.enabled != ext_enabled:
- style_enabled = STYLE_PRIMARY
- if current_ext.remote != ext_remote:
- style_remote = STYLE_PRIMARY
- if current_ext.branch != ext_branch:
- style_branch = STYLE_PRIMARY
- if current_ext.commit_hash != ext_commit_hash:
- style_commit = STYLE_PRIMARY
-
- code += f"""
- <tr>
- <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
- <td><label{style_remote}>{remote}</label></td>
- <td><label{style_branch}>{ext_branch}</label></td>
- <td><label{style_commit}>{commit_link}</label></td>
- <td><label{style_commit}>{date_link}</label></td>
- </tr>
- """
-
- code += """
- </tbody>
- </table>
- """
+ if current_webui["remote"] != webui_remote:
+ style_remote = STYLE_PRIMARY
+ if current_webui["branch"] != webui_branch:
+ style_branch = STYLE_PRIMARY
+ if current_webui["commit_hash"] != webui_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code = f"""<!-- {time.time()} -->
+<h2>Config Backup: {config_name}</h2>
+<div><b>Filepath:</b> {filepath}</div>
+<div><b>Created at:</b> {created_date}</div>
+<h2>WebUI State</h2>
+<table id="config_state_webui">
+ <thead>
+ <tr>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>
+ <label{style_remote}>{remote}</label>
+ </td>
+ <td>
+ <label{style_branch}>{webui_branch}</label>
+ </td>
+ <td>
+ <label{style_commit}>{commit_link}</label>
+ </td>
+ <td>
+ <label{style_commit}>{date_link}</label>
+ </td>
+ </tr>
+ </tbody>
+</table>
+<h2>Extension State</h2>
+<table id="config_state_extensions">
+ <thead>
+ <tr>
+ <th>Extension</th>
+ <th>URL</th>
+ <th>Branch</th>
+ <th>Commit</th>
+ <th>Date</th>
+ </tr>
+ </thead>
+ <tbody>
+"""
+
+ ext_map = {ext.name: ext for ext in extensions.extensions}
+
+ for ext_name, ext_conf in config_state["extensions"].items():
+ ext_remote = ext_conf["remote"] or ""
+ ext_branch = ext_conf["branch"] or "<unknown>"
+ ext_enabled = ext_conf["enabled"]
+ ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
+ ext_commit_date = ext_conf["commit_date"]
+ if ext_commit_date:
+ ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
+ else:
+ ext_commit_date = "<unknown>"
+
+ remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
+ commit_link = make_commit_link(ext_commit_hash, ext_remote)
+ date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
+
+ style_enabled = ""
+ style_remote = ""
+ style_branch = ""
+ style_commit = ""
+ if ext_name in ext_map:
+ current_ext = ext_map[ext_name]
+ current_ext.read_info_from_repo()
+ if current_ext.enabled != ext_enabled:
+ style_enabled = STYLE_PRIMARY
+ if current_ext.remote != ext_remote:
+ style_remote = STYLE_PRIMARY
+ if current_ext.branch != ext_branch:
+ style_branch = STYLE_PRIMARY
+ if current_ext.commit_hash != ext_commit_hash:
+ style_commit = STYLE_PRIMARY
+
+ code += f""" <tr>
+ <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
+ <td><label{style_remote}>{remote}</label></td>
+ <td><label{style_branch}>{ext_branch}</label></td>
+ <td><label{style_commit}>{commit_link}</label></td>
+ <td><label{style_commit}>{date_link}</label></td>
+ </tr>
+"""
+
+ code += """ </tbody>
+</table>"""
+
+ except Exception as e:
+ print(f"[ERROR]: Config states {filepath}, {e}")
+ code = f"""<!-- {time.time()} -->
+<h2>Config Backup: {config_name}</h2>
+<div><b>Filepath:</b> {filepath}</div>
+<div><b>Created at:</b> {created_date}</div>
+<h2>This file is corrupted</h2>"""
return code
|