aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2022-10-25 02:01:57 -0400
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-25 09:42:53 +0300
commitfaed465a0b1a7d19669568738c93e04907c10415 (patch)
tree912274ef626bababc846ee9849bc4b390e968c44 /modules/devices.py
parent4c24347e45776d505937856ab280548d9298f0a8 (diff)
MPS Upscalers Fix
Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 033a42d5..7511e1dc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -81,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext()
return torch.autocast("cuda")
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
+def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)