Diffstat (limited to 'modules')
-rw-r--r--   modules/devices.py                            60
-rw-r--r--   modules/generation_parameters_copypaste.py     6
-rw-r--r--   modules/initialize_util.py                      2
-rw-r--r--   modules/processing.py                           2
-rw-r--r--   modules/sd_models.py                           49
-rw-r--r--   modules/sd_models_xl.py                         2
-rw-r--r--   modules/shared_options.py                       2
7 files changed, 117 insertions, 6 deletions
diff --git a/modules/devices.py b/modules/devices.py
index ea1f712f..c956207f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -23,6 +23,23 @@ def has_mps() -> bool:
         return mac_specific.has_mps
 
 
+def cuda_no_autocast(device_id=None) -> bool:
+    if device_id is None:
+        device_id = get_cuda_device_id()
+    return (
+        torch.cuda.get_device_capability(device_id) == (7, 5)
+        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+    )
+
+
+def get_cuda_device_id():
+    return (
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        else 0
+    ) or torch.cuda.current_device()
+
+
 def get_cuda_device_string():
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -73,8 +90,7 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
         # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
             torch.backends.cudnn.benchmark = True
 
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -84,6 +100,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -104,12 +121,51 @@ def cond_cast_float(input):
 
 
 nv_rng = None
+patch_module_list = [
+    torch.nn.Linear,
+    torch.nn.Conv2d,
+    torch.nn.MultiheadAttention,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(self, *args, **kwargs):
+    org_dtype = next(self.parameters()).dtype
+    self.to(dtype)
+    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+    result = self.org_forward(*args, **kwargs)
+    self.to(org_dtype)
+    return result
+
+
+@contextlib.contextmanager
+def manual_cast():
+    for module_type in patch_module_list:
+        org_forward = module_type.forward
+        module_type.forward = manual_cast_forward
+        module_type.org_forward = org_forward
+    try:
+        yield None
+    finally:
+        for module_type in patch_module_list:
+            module_type.forward = module_type.org_forward
 
 
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
 
+    if fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
+        return manual_cast()
+
+    if has_mps() and shared.cmd_opts.precision != "full":
+        return manual_cast()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
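The new manual_cast() context manager is what makes FP8-stored weights (and cards such as the GTX 16xx series, which cannot use CUDA autocast) usable: it monkey-patches forward() on common layer types so the module's parameters and its tensor arguments are moved to the compute dtype just for the call, then restored. Below is a minimal standalone sketch of the same idea, patching only torch.nn.Linear; names such as compute_dtype and manual_cast_linear are illustrative and not part of the patch, and torch.float8_e4m3fn assumes PyTorch 2.1 or newer.

import contextlib
import torch

compute_dtype = torch.float32  # stand-in for devices.dtype

def casted_forward(self, x):
    stored_dtype = self.weight.dtype            # e.g. torch.float8_e4m3fn
    self.to(compute_dtype)                      # upcast parameters for the matmul
    result = self.org_forward(x.to(compute_dtype))
    self.to(stored_dtype)                       # drop back to the compact storage dtype
    return result

@contextlib.contextmanager
def manual_cast_linear():
    torch.nn.Linear.org_forward = torch.nn.Linear.forward
    torch.nn.Linear.forward = casted_forward
    try:
        yield
    finally:
        torch.nn.Linear.forward = torch.nn.Linear.org_forward

layer = torch.nn.Linear(4, 4).to(torch.float8_e4m3fn)   # weights stored in FP8
with manual_cast_linear():
    y = layer(torch.randn(1, 4))                         # computed in float32
print(y.dtype, layer.weight.dtype)                       # torch.float32 torch.float8_e4m3fn

The real manual_cast_forward is more general: it patches several module types at once, casts every tensor in args and kwargs, and uses the module-level devices.dtype as the compute dtype.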
res["VAE Decoder"] = "Full"
+ if "FP8 weight" not in res:
+ res["FP8 weight"] = "Disable"
+
+ if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+ res["Cache FP16 weight for LoRA"] = False
+
skip = set(shared.opts.infotext_skip_pasting)
res = {k: v for k, v in res.items() if k not in skip}
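These defaults keep pasted parameters reproducible: an infotext from an image made before this feature carries no FP8 keys, so the UI fields are explicitly reset instead of being left at whatever the user currently has selected. A small illustration with a hypothetical parsed res dict:

# illustrative only: `res` mimics the dict parsed from an older image's infotext
res = {"Steps": "20", "Sampler": "Euler a", "Size": "512x512"}

if "FP8 weight" not in res:
    res["FP8 weight"] = "Disable"

if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
    res["Cache FP16 weight for LoRA"] = False

print(res["FP8 weight"])                          # Disable
print("Cache FP16 weight for LoRA" in res)        # False: only set when FP8 was actually on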
diff --git a/modules/initialize_util.py b/modules/initialize_util.py
index 2e9b6d89..b6767138 100644
--- a/modules/initialize_util.py
+++ b/modules/initialize_util.py
@@ -177,6 +177,8 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
     startup_timer.record("opts onchange")
diff --git a/modules/processing.py b/modules/processing.py
index bea01ec6..179f2c0f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
         "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
         "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 9355f1e1..d0046f88 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -348,10 +348,28 @@ class SkipWritingToConfig:
         SkipWritingToConfig.skip = self.previous
 
 
+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
+    if devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -404,6 +422,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
             devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
+    for module in model.modules():
+        if hasattr(module, 'fp16_weight'):
+            del module.fp16_weight
+        if hasattr(module, 'fp16_bias'):
+            del module.fp16_bias
+
+    if check_fp8(model):
+        devices.fp8 = True
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+                if shared.opts.cache_fp16_weight:
+                    module.fp16_weight = module.weight.data.clone().cpu().half()
+                    if module.bias is not None:
+                        module.fp16_bias = module.bias.data.clone().cpu().half()
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
+        timer.record("apply fp8")
+    else:
+        devices.fp8 = False
+
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
     model.first_stage_model.to(devices.dtype_vae)
@@ -746,7 +786,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
     return None
 
 
-def reload_model_weights(sd_model=None, info=None):
+def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     checkpoint_info = info or select_checkpoint()
 
     timer = Timer()
@@ -758,11 +798,14 @@ def reload_model_weights(sd_model=None, info=None):
         current_checkpoint_info = None
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
-        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+        if check_fp8(sd_model) != devices.fp8:
+            # load from state dict again to prevent extra numerical errors
+            forced_reload = True
+        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
             return sd_model
 
     sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
-    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
         return sd_model
 
     if sd_model is not None:
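This is the heart of the feature: after the usual half() pass, every Linear and Conv2d weight is re-stored as torch.float8_e4m3fn (the VAE is temporarily detached so it keeps its own dtype), optionally caching fp16 copies in system RAM so LoRA weights can later be merged against a full-precision base rather than already-rounded FP8 values. A self-contained sketch of that transformation on a toy model follows; the names are illustrative, this is not the webui loader itself, and it assumes a PyTorch build with float8_e4m3fn (2.1 or newer).

import torch

def apply_fp8_storage(net: torch.nn.Module, cache_fp16: bool = True) -> torch.nn.Module:
    for module in net.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            if cache_fp16:
                # CPU-side fp16 copies, analogous to module.fp16_weight / fp16_bias above
                module.fp16_weight = module.weight.data.clone().cpu().half()
                if module.bias is not None:
                    module.fp16_bias = module.bias.data.clone().cpu().half()
            module.to(torch.float8_e4m3fn)   # weights now take one byte per element
    return net

net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Linear(8, 4))
apply_fp8_storage(net)
print(net[0].weight.dtype, net[1].weight.dtype)   # torch.float8_e4m3fn torch.float8_e4m3fn

The fp16 cache exists because merging LoRA deltas directly into FP8 weights (only 3 mantissa bits) accumulates rounding error; keeping a full-precision copy in system RAM trades memory for LoRA quality, which is what the "Cache FP16 weight for LoRA" option describes.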
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 01123321..11259a36 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -93,7 +93,7 @@ def extend_sdxl(model):
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
 
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
-    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)
 
     model.conditioner.wrapped = torch.nn.Module()
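Forcing alphas_cumprod to float32 protects the noise schedule now that the model-derived dtype used before can be fp16 or FP8: the schedule values sit just below 1.0, where low-precision formats cannot separate adjacent steps. A quick demonstration (assumes a torch build with float8_e4m3fn; the values only resemble the start of a DDPM schedule):

import torch

a = torch.tensor([0.99915, 0.99741])          # approx. first steps of alphas_cumprod
print(a.to(torch.float8_e4m3fn).float())      # tensor([1., 1.]) -- FP8 collapses both to 1.0
print(a.to(torch.float16).float())            # rounded, but the two steps stay distinct
print(a)                                      # float32 keeps the schedule intact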
diff --git a/modules/shared_options.py b/modules/shared_options.py
index d2e86ff1..d470eb8f 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -206,6 +206,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+    "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {