Diffstat (limited to 'modules/devices.py')
-rw-r--r--   modules/devices.py   78
1 file changed, 67 insertions, 11 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 0158b11f..f00079c6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,20 +1,44 @@
+import sys, os, shlex
 import contextlib
-
 import torch
-
 from modules import errors
+from packaging import version
 
 
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def has_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
+
+
+def extract_device_id(args, name):
+    for x in range(len(args)):
+        if name in args[x]:
+            return args[x + 1]
+
+    return None
+
+
+def get_cuda_device_string():
+    from modules import shared
+
+    if shared.cmd_opts.device_id is not None:
+        return f"cuda:{shared.cmd_opts.device_id}"
+
+    return "cuda"
 
 
 def get_optimal_device():
     if torch.cuda.is_available():
-        return torch.device("cuda")
+        return torch.device(get_cuda_device_string())
 
-    if has_mps:
+    if has_mps():
         return torch.device("mps")
 
     return cpu
@@ -22,8 +46,9 @@ def get_optimal_device():
 
 def torch_gc():
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+        with torch.cuda.device(get_cuda_device_string()):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 
 def enable_tf32():
@@ -34,8 +59,11 @@ def enable_tf32():
 
 errors.run(enable_tf32, "Enabling TF32")
 
-device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
 dtype = torch.float16
+dtype_vae = torch.float16
+
 
 def randn(seed, shape):
     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@@ -59,10 +87,38 @@ def randn_without_seed(shape):
     return torch.randn(shape, device=device)
 
 
-def autocast():
+def autocast(disable=False):
     from modules import shared
 
+    if disable:
+        return contextlib.nullcontext()
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
 
     return torch.autocast("cuda")
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+orig_tensor_to = torch.Tensor.to
+def tensor_to_fix(self, *args, **kwargs):
+    if self.device.type != 'mps' and \
+       ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
+       (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
+        self = self.contiguous()
+    return orig_tensor_to(self, *args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+orig_layer_norm = torch.nn.functional.layer_norm
+def layer_norm_fix(*args, **kwargs):
+    if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
+        args = list(args)
+        args[0] = args[0].contiguous()
+    return orig_layer_norm(*args, **kwargs)
+
+
+# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
+    torch.Tensor.to = tensor_to_fix
+    torch.nn.functional.layer_norm = layer_norm_fix
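
For readers skimming the diff, the heart of the change is the device-selection flow. Below is a minimal, self-contained sketch of that logic under stated assumptions: pick_device and its device_id parameter are hypothetical stand-ins for the new get_optimal_device() / get_cuda_device_string() pair, which in the webui additionally read the id from shared.cmd_opts.device_id; this is an illustration, not part of the commit.

import torch


def pick_device(device_id=None):
    # Prefer CUDA; an explicit device id (the --device-id flag handled in this
    # commit) becomes "cuda:<n>", otherwise the default "cuda" device is used.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{device_id}" if device_id is not None else "cuda")

    # Mirror has_mps(): the attribute check alone is not enough, so probe the
    # backend by actually moving a tensor onto it and fall back on failure.
    if getattr(torch, 'has_mps', False):
        try:
            torch.zeros(1).to(torch.device("mps"))
            return torch.device("mps")
        except Exception:
            pass

    return torch.device("cpu")


if __name__ == "__main__":
    print(pick_device())    # cuda on an NVIDIA box, mps on Apple Silicon, else cpu
    print(pick_device(1))   # "cuda:1" when a second GPU is present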