diff options
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 21 |
1 files changed, 5 insertions, 16 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 7f8502f5..3b6cdea1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,7 +1,6 @@ import collections
import os.path
import sys
-import gc
import threading
import torch
@@ -357,12 +356,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
- model.load_state_dict(state_dict, strict=False)
- timer.record("apply weights to model")
-
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = state_dict
+ checkpoints_loaded[checkpoint_info] = state_dict.copy()
+
+ model.load_state_dict(state_dict, strict=False)
+ timer.record("apply weights to model")
del state_dict
@@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
- timer = Timer()
-
- if model_data.sd_model:
- model_data.sd_model.to(devices.cpu)
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
- model_data.sd_model = None
- sd_model = None
- gc.collect()
- devices.torch_gc()
-
- print(f"Unloaded weights {timer.summary()}.")
+ send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
|