diff options
-rw-r--r-- | modules/extras.py | 14 | ||||
-rw-r--r-- | modules/sd_hijack.py | 7 | ||||
-rw-r--r-- | modules/sd_models.py | 15 | ||||
-rw-r--r-- | modules/ui.py | 16 |
4 files changed, 38 insertions, 14 deletions
diff --git a/modules/extras.py b/modules/extras.py index d03f976e..1218f88f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -275,7 +275,7 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename)
-chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
def to_half(tensor, enable):
@@ -303,7 +303,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
- def filename_weighed_sum():
+ def filename_weighted_sum():
a = primary_model_info.model_name
b = secondary_model_info.model_name
Ma = round(1 - multiplier, 2)
@@ -311,7 +311,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ return f"{Ma}({a}) + {Mb}({b})"
- def filename_add_differnece():
+ def filename_add_difference():
a = primary_model_info.model_name
b = secondary_model_info.model_name
c = tertiary_model_info.model_name
@@ -323,8 +323,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ return primary_model_info.model_name
theta_funcs = {
- "Weighted sum": (filename_weighed_sum, None, weighted_sum),
- "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "Weighted sum": (filename_weighted_sum, None, weighted_sum),
+ "Add difference": (filename_add_difference, get_difference, add_difference),
"No interpolation": (filename_nothing, None, None),
}
filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
@@ -362,7 +362,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
if 'model' in key:
@@ -387,7 +387,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in tqdm.tqdm(theta_0.keys()):
if theta_1 and 'model' in key and key in theta_1:
- if key in chckpoint_dict_skip_on_merge:
+ if key in checkpoint_dict_skip_on_merge:
continue
a = theta_0[key]
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 870eba88..f9652d21 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,6 +69,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+def fix_checkpoint():
+ """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
+ checkpoints to be added when not training (there's a warning)"""
+
+ pass
+
+
class StableDiffusionModelHijack:
fixes = None
comments = []
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef..12083848 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"):
name = name[1:]
- self.title = name
+ self.name = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.shorthash = self.sha256[0:10] if self.sha256 else None
- self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else [])
+ self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+
+ self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self
def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title)
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256]
self.register()
+ self.title = f'{self.name} [{self.shorthash}]'
+
return self.shorthash
@@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo):
+ title = checkpoint_info.title
sd_model_hash = checkpoint_info.calculate_shorthash()
+ if checkpoint_info.title != title:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
cache_enabled = shared.opts.sd_checkpoint_cache > 0
diff --git a/modules/ui.py b/modules/ui.py index af416d5f..13d80ae2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -439,7 +439,7 @@ def apply_setting(key, value): opts.data_labels[key].onchange()
opts.save(shared.config_filename)
- return value
+ return getattr(opts, key)
def update_generation_info(generation_info, html_info, img_index):
@@ -597,6 +597,16 @@ def ordered_ui_categories(): yield category
+def get_value_for_setting(key):
+ value = getattr(opts, key)
+
+ info = opts.data_labels[key]
+ args = info.component_args() if callable(info.component_args) else info.component_args or {}
+ args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+ return gr.update(value=value, **args)
+
+
def create_ui():
import modules.img2img
import modules.txt2img
@@ -1600,7 +1610,7 @@ def create_ui(): opts.save(shared.config_filename)
- return gr.update(value=value), opts.dumpjson()
+ return get_value_for_setting(key), opts.dumpjson()
with gr.Blocks(analytics_enabled=False) as settings_interface:
with gr.Row():
@@ -1772,7 +1782,7 @@ def create_ui(): component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
def get_settings_values():
- return [getattr(opts, key) for key in component_keys]
+ return [get_value_for_setting(key) for key in component_keys]
demo.load(
fn=get_settings_values,
|