diff options
-rw-r--r-- | extensions-builtin/LDSR/ldsr_model_arch.py | 2 | ||||
-rw-r--r-- | modules/api/api.py | 94 | ||||
-rw-r--r-- | modules/api/models.py | 9 | ||||
-rw-r--r-- | modules/codeformer/vqgan_arch.py | 4 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 30 | ||||
-rw-r--r-- | modules/hypernetworks/ui.py | 31 | ||||
-rw-r--r-- | modules/images.py | 2 | ||||
-rw-r--r-- | modules/interrogate.py | 2 | ||||
-rw-r--r-- | modules/safe.py | 47 | ||||
-rw-r--r-- | modules/sd_models.py | 8 | ||||
-rw-r--r-- | modules/sd_vae.py | 2 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 2 | ||||
-rw-r--r-- | modules/ui.py | 2 | ||||
-rw-r--r-- | scripts/prompts_from_file.py | 2 |
14 files changed, 189 insertions, 48 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index f5bd8ae4..0ad49f4e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -26,7 +26,7 @@ class LDSR: global cached_ldsr_model if shared.opts.ldsr_cached and cached_ldsr_model is not None: - print(f"Loading model from cache") + print("Loading model from cache") model: torch.nn.Module = cached_ldsr_model else: print(f"Loading model from {self.modelPath}") diff --git a/modules/api/api.py b/modules/api/api.py index b43dd16b..1ceba75d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru +from modules import sd_samplers, deepbooru, sd_hijack from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras, run_pnginfo +from modules.textual_inversion.textual_inversion import create_embedding, train_embedding +from modules.textual_inversion.preprocess import preprocess +from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list from modules.realesrgan_model import get_realesrgan_models +from modules import devices from typing import List def upscaler_to_index(name: str): @@ -97,6 +101,11 @@ class Api: self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) + self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) + self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse) + self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) + self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -326,6 +335,89 @@ class Api: def refresh_checkpoints(self): shared.refresh_checkpoints() + def create_embedding(self, args: dict): + try: + shared.state.begin() + filename = create_embedding(**args) # create empty embedding + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used + shared.state.end() + return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename)) + except AssertionError as e: + shared.state.end() + return TrainResponse(info = "create embedding error: {error}".format(error = e)) + + def create_hypernetwork(self, args: dict): + try: + shared.state.begin() + filename = create_hypernetwork(**args) # create empty embedding + shared.state.end() + return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename)) + except AssertionError as e: + shared.state.end() + return TrainResponse(info = "create hypernetwork error: {error}".format(error = e)) + + def preprocess(self, args: dict): + try: + shared.state.begin() + preprocess(**args) # quick operation unless blip/booru interrogation is enabled + shared.state.end() + return PreprocessResponse(info = 'preprocess complete') + except KeyError as e: + shared.state.end() + return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e)) + except AssertionError as e: + shared.state.end() + return PreprocessResponse(info = "preprocess error: {error}".format(error = e)) + except FileNotFoundError as e: + shared.state.end() + return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) + + def train_embedding(self, args: dict): + try: + shared.state.begin() + apply_optimizations = shared.opts.training_xattention_optimizations + error = None + filename = '' + if not apply_optimizations: + sd_hijack.undo_optimizations() + try: + embedding, filename = train_embedding(**args) # can take a long time to complete + except Exception as e: + error = e + finally: + if not apply_optimizations: + sd_hijack.apply_optimizations() + shared.state.end() + return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + except AssertionError as msg: + shared.state.end() + return TrainResponse(info = "train embedding error: {msg}".format(msg = msg)) + + def train_hypernetwork(self, args: dict): + try: + shared.state.begin() + initial_hypernetwork = shared.loaded_hypernetwork + apply_optimizations = shared.opts.training_xattention_optimizations + error = None + filename = '' + if not apply_optimizations: + sd_hijack.undo_optimizations() + try: + hypernetwork, filename = train_hypernetwork(*args) + except Exception as e: + error = e + finally: + shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + if not apply_optimizations: + sd_hijack.apply_optimizations() + shared.state.end() + return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + except AssertionError as msg: + shared.state.end() + return TrainResponse(info = "train embedding error: {error}".format(error = error)) + def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/models.py b/modules/api/models.py index a22bc6b3..c446ce7a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel): class InterrogateResponse(BaseModel): caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") +class TrainResponse(BaseModel): + info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.") + +class CreateResponse(BaseModel): + info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.") + +class PreprocessResponse(BaseModel): + info: str = Field(title="Preprocess info", description="Response string from preprocessing task.") + fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index c06c590c..e7293683 100644 --- a/modules/codeformer/vqgan_arch.py +++ b/modules/codeformer/vqgan_arch.py @@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module): self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) logger.info(f'vqgan is loaded from: {model_path} [params]') else: - raise ValueError(f'Wrong params!') + raise ValueError('Wrong params!') def forward(self, x): @@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module): elif 'params' in chkpt: self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) else: - raise ValueError(f'Wrong params!') + raise ValueError('Wrong params!') def forward(self, x): return self.main(x)
\ No newline at end of file diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c406ffb3..109e8078 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -277,7 +277,7 @@ def load_hypernetwork(filename): print(traceback.format_exc(), file=sys.stderr)
else:
if shared.loaded_hypernetwork is not None:
- print(f"Unloading hypernetwork")
+ print("Unloading hypernetwork")
shared.loaded_hypernetwork = None
@@ -378,6 +378,32 @@ def report_statistics(loss_info:dict): print(e)
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
+ fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
+
+ if type(layer_structure) == str:
+ layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
+
+ hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
+ name=name,
+ enable_sizes=[int(x) for x in enable_sizes],
+ layer_structure=layer_structure,
+ activation_func=activation_func,
+ weight_init=weight_init,
+ add_layer_norm=add_layer_norm,
+ use_dropout=use_dropout,
+ )
+ hypernet.save(fn)
+
+ shared.reload_hypernetworks()
+
+ return fn
+
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
@@ -417,7 +443,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, initial_step = hypernetwork.step or 0
if initial_step >= steps:
- shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index c2d4b51c..e7f9e593 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -3,39 +3,16 @@ import os import re
import gradio as gr
-import modules.textual_inversion.preprocess
-import modules.textual_inversion.textual_inversion
+import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
-from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
- # Remove illegal characters from name.
- name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+ filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
- fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
- if not overwrite_old:
- assert not os.path.exists(fn), f"file {fn} already exists"
-
- if type(layer_structure) == str:
- layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
-
- hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
- name=name,
- enable_sizes=[int(x) for x in enable_sizes],
- layer_structure=layer_structure,
- activation_func=activation_func,
- weight_init=weight_init,
- add_layer_norm=add_layer_norm,
- use_dropout=use_dropout,
- )
- hypernet.save(fn)
-
- shared.reload_hypernetworks()
-
- return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
+ return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
def train_hypernetwork(*args):
diff --git a/modules/images.py b/modules/images.py index 809ad9f7..31d4528d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -599,7 +599,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
- print(f"Error parsing NovelAI image generation parameters:", file=sys.stderr)
+ print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return geninfo, items
diff --git a/modules/interrogate.py b/modules/interrogate.py index 0068b81c..46935210 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -172,7 +172,7 @@ class InterrogateModels: res += ", " + match
except Exception:
- print(f"Error interrogating", file=sys.stderr)
+ print("Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
res += "<error>"
diff --git a/modules/safe.py b/modules/safe.py index 479c8b86..82d44be3 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -103,7 +103,7 @@ def check_pt(filename, extra_handler): def load(filename, *args, **kwargs):
- return load_with_extra(filename, *args, **kwargs)
+ return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@@ -137,19 +137,56 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
+ print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
- print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
+ print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
+ print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
+class Extra:
+ """
+ A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+ (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+ if module == 'torch' and name in ['float64', 'float16']:
+ return getattr(torch, name)
+
+ return None
+
+with safe.Extra(handler):
+ x = torch.load('model.pt')
+```
+ """
+
+ def __init__(self, handler):
+ self.handler = handler
+
+ def __enter__(self):
+ global global_extra_handler
+
+ assert global_extra_handler is None, 'already inside an Extra() block'
+ global_extra_handler = self.handler
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ global global_extra_handler
+
+ global_extra_handler = None
+
+
unsafe_torch_load = torch.load
torch.load = load
+global_extra_handler = None
+
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ca06211..ecdd91c5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -117,13 +117,13 @@ def select_checkpoint(): return checkpoint_info
if len(checkpoints_list) == 0:
- print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
+ print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
- print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+ print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
checkpoint_info = next(iter(checkpoints_list.values()))
@@ -324,7 +324,7 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model)
- print(f"Model loaded.")
+ print("Model loaded.")
return sd_model
@@ -359,5 +359,5 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
- print(f"Weights loaded.")
+ print("Weights loaded.")
return sd_model
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 25638a83..3856418e 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -208,5 +208,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print(f"VAE Weights loaded.") + print("VAE Weights loaded.") return sd_model diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index daf3997b..f6112578 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -263,7 +263,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ initial_step = embedding.step or 0
if initial_step >= steps:
- shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
diff --git a/modules/ui.py b/modules/ui.py index 65af8966..57ee0465 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -270,7 +270,7 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name): def interrogate(image):
- prompt = shared.interrogator.interrogate(image)
+ prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 6e118ddb..e8386ed2 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -140,7 +140,7 @@ class Script(scripts.Script): try:
args = cmdargs(line)
except Exception:
- print(f"Error parsing line [line] as commandline:", file=sys.stderr)
+ print(f"Error parsing line {line} as commandline:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
args = {"prompt": line}
else:
|