aboutsummaryrefslogtreecommitdiff
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2022-10-25 02:01:57 -0400
committerAUTOMATIC1111 <16777216c@gmail.com>2022-10-25 09:42:53 +0300
commitfaed465a0b1a7d19669568738c93e04907c10415 (patch)
tree912274ef626bababc846ee9849bc4b390e968c44 /modules/esrgan_model.py
parent4c24347e45776d505937856ab280548d9298f0a8 (diff)
MPS Upscalers Fix
Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a49e2258..a13cf6ac 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- img = img.unsqueeze(0).to(devices.device_esrgan)
+ img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()