From 24129368f1b732be25ef486edb2cf5a6ace66737 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 27 Jun 2023 09:19:04 +0300 Subject: send tensors to the correct device when loading from safetensors file with memmap disabled for #11260 --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0391398a..f65f4e36 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() + if not shared.opts.disable_mmap_load_safetensors: - device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read()) + pl_sd = {k: v.to(device) for k, v in pl_sd.items()} else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.1