From fdecb636855748e03efc40c846a0043800aadfcc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 09:05:06 +0300 Subject: add an ability to merge three checkpoints --- modules/extras.py | 29 +++++++++++++++++++++-------- modules/ui.py | 11 +++++++---- 2 files changed, 28 insertions(+), 12 deletions(-) (limited to 'modules') diff --git a/modules/extras.py b/modules/extras.py index b24d7de3..532d869f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -159,48 +159,61 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name): +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name): # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) - def weighted_sum(theta0, theta1, alpha): + def weighted_sum(theta0, theta1, theta2, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def sigmoid(theta0, theta1, alpha): + def sigmoid(theta0, theta1, theta2, alpha): alpha = alpha * alpha * (3 - (2 * alpha)) return theta0 + ((theta1 - theta0) * alpha) # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def inv_sigmoid(theta0, theta1, alpha): + def inv_sigmoid(theta0, theta1, theta2, alpha): import math alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) return theta0 + ((theta1 - theta0) * alpha) + def add_difference(theta0, theta1, theta2, alpha): + return theta0 + (theta1 - theta2) * (1.0 - alpha) + primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] + teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) print(f"Loading {primary_model_info.filename}...") primary_model = torch.load(primary_model_info.filename, map_location='cpu') + theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) print(f"Loading {secondary_model_info.filename}...") secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') - - theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) + if teritary_model_info is not None: + print(f"Loading {teritary_model_info.filename}...") + teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') + theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) + else: + theta_2 = None + theta_funcs = { "Weighted Sum": weighted_sum, "Sigmoid": sigmoid, "Inverse Sigmoid": inv_sigmoid, + "Add difference": add_difference, } theta_func = theta_funcs[interp_method] print(f"Merging...") + for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint if save_as_half: theta_0[key] = theta_0[key].half() + # I believe this part should be discarded, but I'll leave it for now until I am sure for key in theta_1.keys(): if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] @@ -219,4 +232,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int sd_models.list_models() print(f"Checkpoint saved.") - return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] diff --git a/modules/ui.py b/modules/ui.py index 7446439d..220fb80b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1024,11 +1024,12 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3) + interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') @@ -1473,6 +1474,7 @@ Requested path was: {f} inputs=[ primary_model_name, secondary_model_name, + tertiary_model_name, interp_method, interp_amount, save_as_half, @@ -1482,6 +1484,7 @@ Requested path was: {f} submit_result, primary_model_name, secondary_model_name, + tertiary_model_name, component_dict['sd_model_checkpoint'], ] ) -- cgit v1.2.1