diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-19 08:36:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-19 08:36:20 +0300 |
commit | 448d6bef372079cbd6d5a3acd8fbfd6f03799ee3 (patch) | |
tree | 46741eb48d9588db94f56d134f8af952b71bd514 /modules/sd_models.py | |
parent | 7056fdf2bee50e5952cc0bac2047e96de336a36a (diff) | |
parent | 0dc74545c0b5510911757ed9f2be703aab58f014 (diff) |
Merge pull request #12599 from AUTOMATIC1111/ram_optim
RAM optimization round 2
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..685585b1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -343,7 +343,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
- if not shared.cmd_opts.no_half:
+ if shared.cmd_opts.no_half:
+ model.float()
+ timer.record("apply float()")
+ else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
@@ -518,6 +521,13 @@ def send_model_to_cpu(m): devices.torch_gc()
+def model_target_device():
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ return devices.cpu
+ else:
+ return devices.device
+
+
def send_model_to_device(m):
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +589,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
|