From 76ab31e18898d4c2aacb9725cfbe25b230bff974 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 12 Nov 2022 11:02:40 +0800 Subject: Fix wrong mps selection below MasOS 12.3 --- modules/devices.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 7511e1dc..9a3d29d7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,8 +3,15 @@ import contextlib import torch from modules import errors -# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility -has_mps = getattr(torch, 'has_mps', False) +# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# check `getattr` and try it for compatibility +def has_mps() -> bool: + if getattr(torch, 'has_mps', False): return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False cpu = torch.device("cpu") @@ -25,7 +32,7 @@ def get_optimal_device(): else: return torch.device("cuda") - if has_mps: + if has_mps(): return torch.device("mps") return cpu -- cgit v1.2.1 From 1130d5df669911a5c67696be90bccca3ecf5f487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 12 Nov 2022 11:09:28 +0800 Subject: Update devices.py --- modules/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 9a3d29d7..bd3e4ffb 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -6,7 +6,7 @@ from modules import errors # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: - if getattr(torch, 'has_mps', False): return False + if not getattr(torch, 'has_mps', False): return False try: torch.zeros(1).to(torch.device("mps")) return True -- cgit v1.2.1 From 0ab0a50f9ae14bd7ce7ec518323ebd31c7971155 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:00:49 +0300 Subject: change formatting to match the main program in devices.py --- modules/devices.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index bd3e4ffb..67165bf6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,23 +3,27 @@ import contextlib import torch from modules import errors + # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: - if not getattr(torch, 'has_mps', False): return False + if not getattr(torch, 'has_mps', False): + return False try: torch.zeros(1).to(torch.device("mps")) return True except Exception: return False -cpu = torch.device("cpu") def extract_device_id(args, name): for x in range(len(args)): - if name in args[x]: return args[x+1] + if name in args[x]: + return args[x + 1] + return None + def get_optimal_device(): if torch.cuda.is_available(): from modules import shared @@ -52,10 +56,12 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") +cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 + def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. if device.type == 'mps': @@ -89,6 +95,11 @@ def autocast(disable=False): return torch.autocast("cuda") + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor -def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) +def mps_contiguous(input_tensor, device): + return input_tensor.contiguous() if device.type == 'mps' else input_tensor + + +def mps_contiguous_to(input_tensor, device): + return mps_contiguous(input_tensor, device).to(device) -- cgit v1.2.1