From abfa22c16fb3d9b1ed8d049c7b68e94d1cca5b82 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:25:43 -0500 Subject: Revert "MPS Upscalers Fix" This reverts commit 768b95394a8500da639b947508f78296524f1836. --- modules/devices.py | 9 --------- 1 file changed, 9 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..a87d0d4c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -94,12 +94,3 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor - - -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) -- cgit v1.2.1 From e247b7400a592c0a19c197cd080aeec38ee02b68 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 17 Nov 2022 03:52:17 -0500 Subject: Add fixes for PyTorch 1.12.1 Fix typo "MasOS" -> "macOS" If MPS is available and PyTorch is an earlier version than 1.13: * Monkey patch torch.Tensor.to to ensure all tensors sent to MPS are contiguous * Monkey patch torch.nn.functional.layer_norm to ensure input tensor is contiguous (required for this program to work with MPS on unmodified PyTorch 1.12.1) --- modules/devices.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index a87d0d4c..6e8277e5 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,9 +2,10 @@ import sys, os, shlex import contextlib import torch from modules import errors +from packaging import version -# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: if not getattr(torch, 'has_mps', False): @@ -94,3 +95,28 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 +orig_tensor_to = torch.Tensor.to +def tensor_to_fix(self, *args, **kwargs): + if self.device.type != 'mps' and \ + ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ + (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): + self = self.contiguous() + return orig_tensor_to(self, *args, **kwargs) + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 +orig_layer_norm = torch.nn.functional.layer_norm +def layer_norm_fix(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': + args = list(args) + args[0] = args[0].contiguous() + return orig_layer_norm(*args, **kwargs) + + +# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working +if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix -- cgit v1.2.1