From ac90cf38c6b55d57d37923aa1fe86c7374e32d0b Mon Sep 17 00:00:00 2001
From: Tim Patton <38817597+pattontim@users.noreply.github.com>
Date: Tue, 22 Nov 2022 10:13:07 -0500
Subject: safetensors optional for now

---
 modules/sd_models.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 2bbb3bf5..75f7ab09 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -4,7 +4,6 @@ import sys
 import gc
 from collections import namedtuple
 import torch
-from safetensors.torch import load_file, save_file
 import re
 
 from omegaconf import OmegaConf
@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None):
         # safely load weights
         # TODO: safetensors supports zero copy fast load to gpu, see issue #684.
         # GPU only for now, see https://github.com/huggingface/safetensors/issues/95
+        try:
+            from safetensors.torch import load_file
+        except ImportError as e:
+            raise ImportError(f"The model is in safetensors format but the safetensors package is not installed; install it with `pip install safetensors`: {e}")
         return load_file(model_filename, device='cuda')
     else:
         return torch.load(model_filename, map_location=map_override)
@@ -157,6 +160,10 @@ def torch_save(model, output_filename):
     basename, exttype = os.path.splitext(output_filename)
     if(checkpoint_types[exttype] == 'safetensors'):
         # [===== >] Reticulating brines...
+        try:
+            from safetensors.torch import save_file
+        except ImportError as e:
+            raise ImportError(f"Export as safetensors was selected, but the safetensors package is not installed; install it with `pip install safetensors`: {e}")
         save_file(model, output_filename, metadata={"format": "pt"})
     else:
         torch.save(model, output_filename)
--
cgit v1.2.1
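
The pattern the patch applies is a deferred (lazy) import: `safetensors.torch` is only imported inside the branch that actually needs it, so users who never read or write `.safetensors` checkpoints can run without the package installed. Below is a minimal standalone sketch of the same idea, not the repository's actual code: the `checkpoint_types` mapping and the simplified function signature (no `model_info`) are illustrative assumptions, while the lazy import and the hard-coded `device='cuda'` mirror the diff and its GPU-only note.

import os

import torch

# Assumed extension -> format mapping; the real checkpoint_types in
# modules/sd_models.py may differ.
checkpoint_types = {
    '.ckpt': 'torch',
    '.safetensors': 'safetensors',
}


def torch_load(model_filename, map_override=None):
    _, ext = os.path.splitext(model_filename)
    if checkpoint_types.get(ext) == 'safetensors':
        # Deferred import: only required when a .safetensors checkpoint is loaded.
        try:
            from safetensors.torch import load_file
        except ImportError as e:
            raise ImportError(
                'The model is in safetensors format but the safetensors package '
                f'is not installed; install it with `pip install safetensors`: {e}'
            ) from e
        # GPU-only for now, matching the diff (see safetensors issue #95).
        return load_file(model_filename, device='cuda')
    return torch.load(model_filename, map_location=map_override)

With this shape, torch_load('model.ckpt') never touches safetensors at all, and torch_load('model.safetensors') fails with an actionable installation hint rather than a NameError on the now-unimported load_file.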