diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/api/api.py | 15 | ||||
-rw-r--r-- | modules/artists.py | 25 | ||||
-rw-r--r-- | modules/extra_networks.py | 147 | ||||
-rw-r--r-- | modules/extra_networks_hypernet.py | 21 | ||||
-rw-r--r-- | modules/extras.py | 14 | ||||
-rw-r--r-- | modules/generation_parameters_copypaste.py | 12 | ||||
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 107 | ||||
-rw-r--r-- | modules/hypernetworks/ui.py | 5 | ||||
-rw-r--r-- | modules/interrogate.py | 55 | ||||
-rw-r--r-- | modules/processing.py | 26 | ||||
-rw-r--r-- | modules/script_callbacks.py | 15 | ||||
-rw-r--r-- | modules/sd_hijack.py | 7 | ||||
-rw-r--r-- | modules/sd_hijack_optimizations.py | 10 | ||||
-rw-r--r-- | modules/sd_models.py | 15 | ||||
-rw-r--r-- | modules/shared.py | 28 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 2 | ||||
-rw-r--r-- | modules/ui.py | 166 | ||||
-rw-r--r-- | modules/ui_components.py | 10 | ||||
-rw-r--r-- | modules/ui_extra_networks.py | 155 | ||||
-rw-r--r-- | modules/ui_extra_networks_hypernets.py | 34 | ||||
-rw-r--r-- | modules/ui_extra_networks_textual_inversion.py | 32 |
21 files changed, 697 insertions, 204 deletions
diff --git a/modules/api/api.py b/modules/api/api.py index 9814bbc2..f2e9e884 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -126,8 +126,6 @@ class Api: self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) - self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) - self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) @@ -390,12 +388,6 @@ class Api: return styleList - def get_artists_categories(self): - return shared.artist_db.cats - - def get_artists(self): - return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] - def get_embeddings(self): db = sd_hijack.model_hijack.embedding_db @@ -480,7 +472,7 @@ class Api: def train_hypernetwork(self, args: dict): try: shared.state.begin() - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -491,16 +483,15 @@ class Api: except Exception as e: error = e finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) except AssertionError as msg: shared.state.end() - return TrainResponse(info = "train embedding error: {error}".format(error = error)) + return TrainResponse(info="train embedding error: {error}".format(error=error)) def get_memory(self): try: diff --git a/modules/artists.py b/modules/artists.py deleted file mode 100644 index 3612758b..00000000 --- a/modules/artists.py +++ /dev/null @@ -1,25 +0,0 @@ -import os.path
-import csv
-from collections import namedtuple
-
-Artist = namedtuple("Artist", ['name', 'weight', 'category'])
-
-
-class ArtistsDatabase:
- def __init__(self, filename):
- self.cats = set()
- self.artists = []
-
- if not os.path.exists(filename):
- return
-
- with open(filename, "r", newline='', encoding="utf8") as file:
- reader = csv.DictReader(file)
-
- for row in reader:
- artist = Artist(row["artist"], float(row["score"]), row["category"])
- self.artists.append(artist)
- self.cats.add(artist.category)
-
- def categories(self):
- return sorted(self.cats)
diff --git a/modules/extra_networks.py b/modules/extra_networks.py new file mode 100644 index 00000000..1978673d --- /dev/null +++ b/modules/extra_networks.py @@ -0,0 +1,147 @@ +import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+ extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+ extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+ def __init__(self, items=None):
+ self.items = items or []
+
+
+class ExtraNetwork:
+ def __init__(self, name):
+ self.name = name
+
+ def activate(self, p, params_list):
+ """
+ Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+ Passes arguments related to this extra network in params_list.
+ User passes arguments by specifying this in his prompt:
+
+ <name:arg1:arg2:arg3>
+
+ Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+ separated by colon.
+
+ Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
+ in this case, all effects of this extra networks should be disabled.
+
+ Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+ For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+ > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"
+
+ params_list will be:
+
+ [
+ ExtraNetworkParams(items=["agm", "1.1"]),
+ ExtraNetworkParams(items=["ray"])
+ ]
+
+ """
+ raise NotImplementedError
+
+ def deactivate(self, p):
+ """
+ Called at the end of processing for housekeeping. No need to do anything here.
+ """
+
+ raise NotImplementedError
+
+
+def activate(p, extra_network_data):
+ """call activate for extra networks in extra_network_data in specified order, then call
+ activate for all remaining registered networks with an empty argument list"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ print(f"Skipping unknown extra network: {extra_network_name}")
+ continue
+
+ try:
+ extra_network.activate(p, extra_network_args)
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.activate(p, [])
+ except Exception as e:
+ errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+ """call deactivate for extra networks in extra_network_data in specified order, then call
+ deactivate for all remaining registered networks"""
+
+ for extra_network_name, extra_network_args in extra_network_data.items():
+ extra_network = extra_network_registry.get(extra_network_name, None)
+ if extra_network is None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating extra network {extra_network_name}")
+
+ for extra_network_name, extra_network in extra_network_registry.items():
+ args = extra_network_data.get(extra_network_name, None)
+ if args is not None:
+ continue
+
+ try:
+ extra_network.deactivate(p)
+ except Exception as e:
+ errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+ res = defaultdict(list)
+
+ def found(m):
+ name = m.group(1)
+ args = m.group(2)
+
+ res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+ return ""
+
+ prompt = re.sub(re_extra_net, found, prompt)
+
+ return prompt, res
+
+
+def parse_prompts(prompts):
+ res = []
+ extra_data = None
+
+ for prompt in prompts:
+ updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+ if extra_data is None:
+ extra_data = parsed_extra_data
+
+ res.append(updated_prompt)
+
+ return res, extra_data
+
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000..ff279a1f --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,21 @@ +from modules import extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+ def __init__(self):
+ super().__init__('hypernet')
+
+ def activate(self, p, params_list):
+ names = []
+ multipliers = []
+ for params in params_list:
+ assert len(params.items) > 0
+
+ names.append(params.items[0])
+ multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+ hypernetwork.load_hypernetworks(names, multipliers)
+
+ def deactivate(self, p):
+ pass
diff --git a/modules/extras.py b/modules/extras.py index d03f976e..1218f88f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -275,7 +275,7 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename)
-chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
def to_half(tensor, enable):
@@ -303,7 +303,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- def filename_weighed_sum():
+ def filename_weighted_sum():
a = primary_model_info.model_name
b = secondary_model_info.model_name
Ma = round(1 - multiplier, 2)
@@ -311,7 +311,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ return f"{Ma}({a}) + {Mb}({b})"
- def filename_add_differnece():
+ def filename_add_difference():
a = primary_model_info.model_name
b = secondary_model_info.model_name
c = tertiary_model_info.model_name
@@ -323,8 +323,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (filename_weighed_sum, None, weighted_sum),
- "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
+ "Add difference": (filename_add_difference, get_difference, add_difference),
"No interpolation": (filename_nothing, None, None),
}
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
@@ -362,7 +362,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
if 'model' in key:
@@ -387,7 +387,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in tqdm.tqdm(theta_0.keys()):
if theta_1 and 'model' in key and key in theta_1:
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
a = theta_0[key]
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a381ff59..46e12dc6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): from modules import ui
settings_map = {
- 'sd_hypernetwork': 'Hypernet',
- 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Clip skip" not in res:
res["Clip skip"] = "1"
- if "Hypernet strength" not in res:
- res["Hypernet strength"] = "1"
-
- if "Hypernet" in res:
- hypernet_name = res["Hypernet"]
- hypernet_hash = res.get("Hypernet hash", None)
- res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+ hypernet = res.get("Hypernet", None)
+ if hypernet is not None:
+ res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
if "Hires resize-1" not in res:
res["Hires resize-1"] = 0
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74e78582..80a47c79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -25,7 +25,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
- multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module): add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
+ self.multiplier = 1.0
+
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x
def forward(self, x):
- return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+ return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module): return layer_structure
-def apply_strength(value=None):
- HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters():
param.requires_grad = mode
+ def to(self, device):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.to(device)
+
+ return self
+
+ def set_multiplier(self, multiplier):
+ for k, layers in self.layers.items():
+ for layer in layers:
+ layer.multiplier = multiplier
+
+ return self
+
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
- print("Loaded existing optimizer from checkpoint")
- print(f"Optimizer name is {self.optimizer_name}")
+ if shared.opts.print_hypernet_extra:
+ print("Loaded existing optimizer from checkpoint")
+ print(f"Optimizer name is {self.optimizer_name}")
else:
self.optimizer_name = "AdamW"
- print("No saved optimizer exists in checkpoint")
+ if shared.opts.print_hypernet_extra:
+ print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path): return res
-def load_hypernetwork(filename):
- path = shared.hypernetworks.get(filename, None)
- # Prevent any file named "None.pt" from being loaded.
- if path is not None and filename != "None":
- print(f"Loading hypernetwork {filename}")
- try:
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+ path = shared.hypernetworks.get(name, None)
- except Exception:
- print(f"Error loading hypernetwork {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if shared.loaded_hypernetwork is not None:
- print("Unloading hypernetwork")
+ if path is None:
+ return None
+
+ hypernetwork = Hypernetwork()
+
+ try:
+ hypernetwork.load(path)
+ except Exception:
+ print(f"Error loading hypernetwork {path}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+ already_loaded = {}
+
+ for hypernetwork in shared.loaded_hypernetworks:
+ if hypernetwork.name in names:
+ already_loaded[hypernetwork.name] = hypernetwork
- shared.loaded_hypernetwork = None
+ shared.loaded_hypernetworks.clear()
+
+ for i, name in enumerate(names):
+ hypernetwork = already_loaded.get(name, None)
+ if hypernetwork is None:
+ hypernetwork = load_hypernetwork(name)
+
+ if hypernetwork is None:
+ continue
+
+ hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+ shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0]
-def apply_hypernetwork(hypernetwork, context, layer=None):
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
- return context, context
+ return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
- context_k = hypernetwork_layers[0](context)
- context_v = hypernetwork_layers[1](context)
+ context_k = hypernetwork_layers[0](context_k)
+ context_v = hypernetwork_layers[1](context_v)
+ return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+ context_k = context
+ context_v = context
+ for hypernetwork in hypernetworks:
+ context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
return context_k, context_v
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+ context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
- shared.loaded_hypernetwork = Hypernetwork()
- shared.loaded_hypernetwork.load(path)
+ hypernetwork = Hypernetwork()
+ hypernetwork.load(path)
+ shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else:
images_dir = None
- hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0
diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
@@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args):
-
- initial_hypernetwork = shared.loaded_hypernetwork
+ shared.loaded_hypernetworks = []
assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'
@@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception:
raise
finally:
- shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
sd_hijack.apply_optimizations()
diff --git a/modules/interrogate.py b/modules/interrogate.py index 738d8ff7..19938cbb 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -5,12 +5,13 @@ from collections import namedtuple import re
import torch
+import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, lowvram, modelloader, errors
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.")
+def download_default_clip_interrogate_categories(content_dir):
+ print("Downloading CLIP categories...")
+
+ tmpdir = content_dir + "_tmp"
+ try:
+ os.makedirs(tmpdir)
+
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
+ torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
+
+ os.rename(tmpdir, content_dir)
+
+ except Exception as e:
+ errors.display(e, "downloading default CLIP interrogate categories")
+ finally:
+ if os.path.exists(tmpdir):
+ os.remove(tmpdir)
+
+
class InterrogateModels:
blip_model = None
clip_model = None
clip_preprocess = None
- categories = None
dtype = None
running_on_cpu = None
def __init__(self, content_dir):
- self.categories = []
+ self.loaded_categories = None
+ self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
- if os.path.exists(content_dir):
- for filename in os.listdir(content_dir):
+ def categories(self):
+ if self.loaded_categories is not None:
+ return self.loaded_categories
+
+ self.loaded_categories = []
+
+ if not os.path.exists(self.content_dir):
+ download_default_clip_interrogate_categories(self.content_dir)
+
+ if os.path.exists(self.content_dir):
+ for filename in os.listdir(self.content_dir):
m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1))
- with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+ with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()]
- self.categories.append(Category(name=filename, topn=topn, items=lines))
+ self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+
+ return self.loaded_categories
def load_blip_model(self):
import models.blip
@@ -139,7 +172,6 @@ class InterrogateModels: shared.state.begin()
shared.state.job = 'interrogate'
try:
-
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels: image_features /= image_features.norm(dim=-1, keepdim=True)
- if shared.opts.interrogate_use_builtin_artists:
- artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
- res += ", " + artist[0]
-
- for name, topn, items in self.categories:
+ for name, topn, items in self.categories():
matches = self.rank(image_features, items, top_count=topn)
for match, score in matches:
if shared.opts.interrogate_return_ranks:
diff --git a/modules/processing.py b/modules/processing.py index a3e9f709..6e6371a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
- "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
- "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()),
- "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try:
for k, v in p.override_settings.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork':
- shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint':
- sd_models.reload_model_weights() # make onchange call for changing SD model
+ sd_models.reload_model_weights()
if k == 'sd_vae':
- sd_vae.reload_vae_weights() # make onchange call for changing VAE
+ sd_vae.reload_vae_weights()
res = process_images_inner(p)
@@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
- if k == 'sd_hypernetwork': shared.reload_hypernetworks()
- if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
- if k == 'sd_vae': sd_vae.reload_vae_weights()
+ if k == 'sd_model_checkpoint':
+ sd_models.reload_model_weights()
+
+ if k == 'sd_vae':
+ sd_vae.reload_vae_weights()
return res
@@ -535,6 +532,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
+ _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1])
+
if p.scripts is not None:
p.scripts.process(p)
@@ -568,6 +567,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
+ extra_networks.activate(p, extra_network_data)
+
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
@@ -592,6 +593,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(prompts) == 0:
break
+ prompts, _ = extra_networks.parse_prompts(prompts)
+
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
@@ -681,6 +684,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ extra_networks.deactivate(p, extra_network_data)
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a9e19236..4bb45ec7 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -73,6 +73,7 @@ callback_map = dict( callbacks_image_grid=[],
callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
+ callbacks_before_ui=[],
)
@@ -189,6 +190,14 @@ def script_unloaded_callback(): report_exception(c, 'script_unloaded')
+def before_ui_callback():
+ for c in reversed(callback_map['callbacks_before_ui']):
+ try:
+ c.callback()
+ except Exception:
+ report_exception(c, 'before_ui')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -313,3 +322,9 @@ def on_script_unloaded(callback): the script did should be reverted here"""
add_callback(callback_map['callbacks_script_unloaded'], callback)
+
+
+def on_before_ui(callback):
+ """register a function to be called before the UI is created."""
+
+ add_callback(callback_map['callbacks_before_ui'], callback)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 870eba88..f9652d21 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,6 +69,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+def fix_checkpoint():
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
+ checkpoints to be added when not training (there's a warning)"""
+
+ pass
+
+
class StableDiffusionModelHijack:
fixes = None
comments = []
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..4fa54329 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
del context, context_k, context_v, x
@@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
@@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
@@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x)
context = default(context, x)
- context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef..12083848 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"):
name = name[1:]
- self.title = name
+ self.name = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.shorthash = self.sha256[0:10] if self.sha256 else None
- self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+ self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+
+ self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256]
self.register()
+ self.title = f'{self.name} [{self.shorthash}]'
+
return self.shorthash
@@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo):
+ title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
+ if checkpoint_info.title != title:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
cache_enabled = shared.opts.sd_checkpoint_cache > 0
diff --git a/modules/shared.py b/modules/shared.py index 2f366454..52bbb807 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -9,7 +9,6 @@ from PIL import Image import gradio as gr
import tqdm
-import modules.artists
import modules.interrogate
import modules.memmon
import modules.styles
@@ -23,6 +22,7 @@ demo = None sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
+
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
@@ -100,6 +100,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
+parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
+
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
@@ -145,7 +147,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = {}
-loaded_hypernetwork = None
+loaded_hypernetworks = []
def reload_hypernetworks():
@@ -153,8 +155,6 @@ def reload_hypernetworks(): global hypernetworks
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
- hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
-
class State:
@@ -255,8 +255,6 @@ class State: state = State()
state.server_start = time.time()
-artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
-
styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)
@@ -399,8 +397,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
- "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
- "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
@@ -411,7 +407,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
- "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -422,7 +417,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
- "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
@@ -661,3 +655,17 @@ mem_mon.start() def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
+
+
+def html_path(filename):
+ return os.path.join(script_path, "html", filename)
+
+
+def html(filename):
+ path = html_path(filename)
+
+ if os.path.exists(path):
+ with open(path, encoding="utf8") as file:
+ return file.read()
+
+ return ""
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5a7be422..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.optimizer_state_dict = None
+ self.filename = None
def save(self, filename):
embedding_data = {
@@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1]
+ embedding.filename = path
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
diff --git a/modules/ui.py b/modules/ui.py index 0c5ba358..daebbc9f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
@@ -75,6 +75,7 @@ css_hide_progressbar = """ .wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
+.wrap.cover-bg .z-20::before { content:"" }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
@@ -90,6 +91,7 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
+extra_networks_symbol = '\U0001F3B4' # 🎴
def plaintext_to_html(text):
@@ -227,17 +229,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
- return [gr_show(True), None]
+ return [gr.update(), None]
def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB"))
- return gr_show(True) if prompt is None else prompt
+ return gr.update() if prompt is None else prompt
def interrogate_deepbooru(image):
prompt = deepbooru.model.tag(image)
- return gr_show(True) if prompt is None else prompt
+ return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface):
@@ -324,6 +326,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps):
try:
+ text, _ = extra_networks.parse_prompt(text)
+
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
@@ -335,43 +339,23 @@ def update_token_counter(text, steps): flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
- style_class = ' class="red"' if (token_count > max_length) else ""
- return f"<span {style_class}>{token_count}/{max_length}</span>"
+ return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
def create_toprow(is_img2img):
id_part = "img2img" if is_img2img else "txt2img"
- with gr.Row(elem_id="toprow"):
- with gr.Column(scale=6):
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
- placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
+ prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
- placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
- )
-
- with gr.Column(scale=1, elem_id="roll_col"):
- paste = gr.Button(value=paste_symbol, elem_id="paste")
- save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
- prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
- clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
+ negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
button_interrogate = None
button_deepbooru = None
@@ -380,7 +364,7 @@ def create_toprow(is_img2img): button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
- with gr.Column(scale=1):
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
with gr.Row(elem_id=f"{id_part}_generate_box"):
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
@@ -398,11 +382,30 @@ def create_toprow(is_img2img): outputs=[],
)
- with gr.Row():
+ with gr.Row(elem_id=f"{id_part}_tools"):
+ paste = ToolButton(value=paste_symbol, elem_id="paste")
+ clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+ prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
+ save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
+
+ token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+ token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
+ negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+ clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[prompt, negative_prompt],
+ outputs=[prompt, negative_prompt],
+ )
+
+ with gr.Row(elem_id=f"{id_part}_styles_row"):
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+ return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
def setup_progressbar(*args, **kwargs):
@@ -439,7 +442,7 @@ def apply_setting(key, value): opts.data_labels[key].onchange()
opts.save(shared.config_filename)
- return value
+ return getattr(opts, key)
def update_generation_info(generation_info, html_info, img_index):
@@ -532,7 +535,7 @@ Requested path was: {f} generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
- _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }",
+ _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info],
)
@@ -597,6 +600,16 @@ def ordered_ui_categories(): yield category
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -609,11 +622,15 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
+ txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
+ with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
for category in ordered_ui_categories():
@@ -785,15 +802,22 @@ def create_ui(): ]
token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
+ img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
+ with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
+ from modules import ui_extra_networks
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
copy_image_buttons = []
@@ -1015,19 +1039,18 @@ def create_ui(): init_img_inpaint,
],
outputs=[img2img_prompt, dummy_component],
- show_progress=False,
)
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
img2img_interrogate.click(
- fn=lambda *args : process_interrogate(interrogate, *args),
+ fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args,
)
img2img_deepbooru.click(
- fn=lambda *args : process_interrogate(interrogate_deepbooru, *args),
+ fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
**interrogate_args,
)
@@ -1054,6 +1077,9 @@ def create_ui(): )
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
+ negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+
+ ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [
(img2img_prompt, "Prompt"),
@@ -1180,10 +1206,19 @@ def create_ui(): outputs=[html, generation_info, html2],
)
+ def update_interp_description(value):
+ interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+ interp_descriptions = {
+ "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+ "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+ "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+ }
+ return interp_descriptions[value]
+
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact'):
- gr.HTML(value="<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
+ interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
with FormRow(elem_id="modelmerger_models"):
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
@@ -1198,6 +1233,7 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+ interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
with FormRow():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
@@ -1600,7 +1636,7 @@ def create_ui(): opts.save(shared.config_filename)
- return gr.update(value=value), opts.dumpjson()
+ return get_value_for_setting(key), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
@@ -1657,10 +1693,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
- if os.path.exists("html/licenses.html"):
- with open("html/licenses.html", encoding="utf8") as file:
- with gr.TabItem("Licenses"):
- gr.HTML(file.read(), elem_id="licenses")
+ with gr.TabItem("Licenses"):
+ gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
@@ -1747,11 +1781,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")):
audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
- if os.path.exists("html/footer.html"):
- with open("html/footer.html", encoding="utf8") as file:
- footer = file.read()
- footer = footer.format(versions=versions_html())
- gr.HTML(footer, elem_id="footer")
+ footer = shared.html("footer.html")
+ footer = footer.format(versions=versions_html())
+ gr.HTML(footer, elem_id="footer")
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click(
@@ -1771,15 +1803,6 @@ def create_ui(): component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
- def get_value_for_setting(key):
- value = getattr(opts, key)
-
- info = opts.data_labels[key]
- args = info.component_args() if callable(info.component_args) else info.component_args or {}
- args = {k: v for k, v in args.items() if k not in {'precision'}}
-
- return gr.update(value=value, **args)
-
def get_settings_values():
return [get_value_for_setting(key) for key in component_keys]
@@ -1885,7 +1908,7 @@ def create_ui(): if type(x) == gr.Dropdown:
def check_dropdown(val):
- if x.multiselect:
+ if getattr(x, 'multiselect', False):
return all([value in x.choices for value in val])
else:
return val in x.choices
@@ -1902,28 +1925,27 @@ def create_ui(): with open(ui_config_file, "w", encoding="utf8") as file:
json.dump(ui_settings, file, indent=4)
+ # Required as a workaround for change() event not triggering when loading values from ui-config.json
+ interp_description.value = update_interp_description(interp_method.value)
+
return demo
def reload_javascript():
- with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
- javascript = f'<script>{jsfile.read()}</script>'
-
- scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
- for basedir, filename, path in scripts_list:
- with open(path, "r", encoding="utf8") as jsfile:
- javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
+ head = f'<script type="text/javascript" src="file={os.path.abspath("script.js")}"></script>\n'
+ inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
- javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
+ inline += f"set_theme('{cmd_opts.theme}');"
+
+ head += f'<script type="text/javascript">{inline}</script>\n'
- javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+ for script in modules.scripts.list_scripts("javascript", ".js"):
+ head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
- res.body = res.body.replace(
- b'</head>', f'{javascript}</head>'.encode("utf8"))
+ res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
res.init_headers()
return res
diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..46324425 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button"
+class ToolButtonTop(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool-top", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
class FormRow(gr.Row, gr.components.FormComponent):
"""Same as gr.Row but fits inside gradio forms"""
diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..e2e060c8 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,155 @@ +import os.path
+
+from modules import shared
+import gradio as gr
+import json
+
+from modules.generation_parameters_copypaste import image_from_url_text
+
+extra_pages = []
+
+
+def register_page(page):
+ """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
+
+ extra_pages.append(page)
+
+
+class ExtraNetworksPage:
+ def __init__(self, title):
+ self.title = title
+ self.name = title.lower()
+ self.card_page = shared.html("extra-networks-card.html")
+ self.allow_negative_prompt = False
+
+ def refresh(self):
+ pass
+
+ def create_html(self, tabname):
+ items_html = ''
+
+ for item in self.list_items():
+ items_html += self.create_html_for_item(item, tabname)
+
+ if items_html == '':
+ dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
+ items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+ res = f"""
+<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
+{items_html}
+</div>
+"""
+
+ return res
+
+ def list_items(self):
+ raise NotImplementedError()
+
+ def allowed_directories_for_previews(self):
+ return []
+
+ def create_html_for_item(self, item, tabname):
+ preview = item.get("preview", None)
+
+ args = {
+ "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+ "prompt": json.dumps(item["prompt"]),
+ "tabname": json.dumps(tabname),
+ "local_preview": json.dumps(item["local_preview"]),
+ "name": item["name"],
+ "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+ }
+
+ return self.card_page.format(**args)
+
+
+def intialize():
+ extra_pages.clear()
+
+
+class ExtraNetworksUi:
+ def __init__(self):
+ self.pages = None
+ self.stored_extra_pages = None
+
+ self.button_save_preview = None
+ self.preview_target_filename = None
+
+ self.tabname = None
+
+
+def create_ui(container, button, tabname):
+ ui = ExtraNetworksUi()
+ ui.pages = []
+ ui.stored_extra_pages = extra_pages.copy()
+ ui.tabname = tabname
+
+ with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ for page in ui.stored_extra_pages:
+ with gr.Tab(page.title):
+ page_elem = gr.HTML(page.create_html(ui.tabname))
+ ui.pages.append(page_elem)
+
+ filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+ button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
+ ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+ ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+ button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
+ button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+
+ def refresh():
+ res = []
+
+ for pg in ui.stored_extra_pages:
+ pg.refresh()
+ res.append(pg.create_html(ui.tabname))
+
+ return res
+
+ button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+ return ui
+
+
+def path_is_parent(parent_path, child_path):
+ parent_path = os.path.abspath(parent_path)
+ child_path = os.path.abspath(child_path)
+
+ return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+
+
+def setup_ui(ui, gallery):
+ def save_preview(index, images, filename):
+ if len(images) == 0:
+ print("There is no image in gallery to save as a preview.")
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ index = int(index)
+ index = 0 if index < 0 else index
+ index = len(images) - 1 if index >= len(images) else index
+
+ img_info = images[index if index >= 0 else 0]
+ image = image_from_url_text(img_info)
+
+ is_allowed = False
+ for extra_page in ui.stored_extra_pages:
+ if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ is_allowed = True
+ break
+
+ assert is_allowed, f'writing to {filename} is not allowed'
+
+ image.save(filename)
+
+ return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+ ui.button_save_preview.click(
+ fn=save_preview,
+ _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+ inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
+ outputs=[*ui.pages]
+ )
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000..312dbaf0 --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,34 @@ +import os
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Hypernetworks')
+
+ def refresh(self):
+ shared.reload_hypernetworks()
+
+ def list_items(self):
+ for name, path in shared.hypernetworks.items():
+ path, ext = os.path.splitext(path)
+ previews = [path + ".png", path + ".preview.png"]
+
+ preview = None
+ for file in previews:
+ if os.path.isfile(file):
+ preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+ break
+
+ yield {
+ "name": name,
+ "filename": path,
+ "preview": preview,
+ "prompt": f"<hypernet:{name}:1.0>",
+ "local_preview": path + ".png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return [shared.cmd_opts.hypernetwork_dir]
+
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000..e4a6e3bf --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,32 @@ +import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+ def __init__(self):
+ super().__init__('Textual Inversion')
+ self.allow_negative_prompt = True
+
+ def refresh(self):
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+ def list_items(self):
+ for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+ path, ext = os.path.splitext(embedding.filename)
+ preview_file = path + ".preview.png"
+
+ preview = None
+ if os.path.isfile(preview_file):
+ preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+
+ yield {
+ "name": embedding.name,
+ "filename": embedding.filename,
+ "preview": preview,
+ "prompt": embedding.name,
+ "local_preview": path + ".preview.png",
+ }
+
+ def allowed_directories_for_previews(self):
+ return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|