From abfa22c16fb3d9b1ed8d049c7b68e94d1cca5b82 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:25:43 -0500 Subject: Revert "MPS Upscalers Fix" This reverts commit 768b95394a8500da639b947508f78296524f1836. --- modules/devices.py | 9 --------- 1 file changed, 9 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..a87d0d4c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -94,12 +94,3 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor - - -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) -- cgit v1.2.1 From e247b7400a592c0a19c197cd080aeec38ee02b68 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 17 Nov 2022 03:52:17 -0500 Subject: Add fixes for PyTorch 1.12.1 Fix typo "MasOS" -> "macOS" If MPS is available and PyTorch is an earlier version than 1.13: * Monkey patch torch.Tensor.to to ensure all tensors sent to MPS are contiguous * Monkey patch torch.nn.functional.layer_norm to ensure input tensor is contiguous (required for this program to work with MPS on unmodified PyTorch 1.12.1) --- modules/devices.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index a87d0d4c..6e8277e5 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,9 +2,10 @@ import sys, os, shlex import contextlib import torch from modules import errors +from packaging import version -# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: if not getattr(torch, 'has_mps', False): @@ -94,3 +95,28 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 +orig_tensor_to = torch.Tensor.to +def tensor_to_fix(self, *args, **kwargs): + if self.device.type != 'mps' and \ + ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ + (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): + self = self.contiguous() + return orig_tensor_to(self, *args, **kwargs) + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 +orig_layer_norm = torch.nn.functional.layer_norm +def layer_norm_fix(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': + args = list(args) + args[0] = args[0].contiguous() + return orig_layer_norm(*args, **kwargs) + + +# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working +if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix -- cgit v1.2.1 From c67c40f983997594f76b2312f92c3761e8d83715 Mon Sep 17 00:00:00 2001 From: Matthew McGoogan Date: Sat, 26 Nov 2022 23:25:16 +0000 Subject: torch.cuda.empty_cache() defaults to cuda:0 device unless explicitly set otherwise first. Updating torch_gc() to use the device set by --device-id if specified to avoid OOM edge cases on multi-GPU systems. --- modules/devices.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..93d82bbc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -44,8 +44,18 @@ def get_optimal_device(): def torch_gc(): if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + from modules import shared + + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + else: + cuda_device = "cuda" + + with torch.cuda.device(cuda_device): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def enable_tf32(): -- cgit v1.2.1 From 5b2c316890b7b8af95f0d0334d1fd34b9a687b99 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 13:08:54 +0300 Subject: eliminate duplicated code from #5095 --- modules/devices.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 93d82bbc..dd50fe24 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -24,17 +24,18 @@ def extract_device_id(args, name): return None -def get_optimal_device(): - if torch.cuda.is_available(): - from modules import shared +def get_cuda_device_string(): + from modules import shared + + if shared.cmd_opts.device_id is not None: + return f"cuda:{shared.cmd_opts.device_id}" - device_id = shared.cmd_opts.device_id + return "cuda" - if device_id is not None: - cuda_device = f"cuda:{device_id}" - return torch.device(cuda_device) - else: - return torch.device("cuda") + +def get_optimal_device(): + if torch.cuda.is_available(): + return torch.device(get_cuda_device_string()) if has_mps(): return torch.device("mps") @@ -44,16 +45,7 @@ def get_optimal_device(): def torch_gc(): if torch.cuda.is_available(): - from modules import shared - - device_id = shared.cmd_opts.device_id - - if device_id is not None: - cuda_device = f"cuda:{device_id}" - else: - cuda_device = "cuda" - - with torch.cuda.device(cuda_device): + with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() -- cgit v1.2.1