aboutsummaryrefslogtreecommitdiff
path: root/modules/devices.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-11-27 12:56:02 +0300
committerGitHub <noreply@github.com>2022-11-27 12:56:02 +0300
commit997ac57020b734894dd9fb19301e80bc52d7de72 (patch)
treeb37af34db84b8192790a412461142c34409354b5 /modules/devices.py
parentd860b56c219a45b274c2147521b2a5823ee90a15 (diff)
parentc67c40f983997594f76b2312f92c3761e8d83715 (diff)
Merge pull request #5095 from mlmcgoogan/master
torch.cuda.empty_cache() defaults to cuda:0 device unless explicitly …
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 67165bf6..93d82bbc 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -44,8 +44,18 @@ def get_optimal_device():
def torch_gc():
if torch.cuda.is_available():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+ from modules import shared
+
+ device_id = shared.cmd_opts.device_id
+
+ if device_id is not None:
+ cuda_device = f"cuda:{device_id}"
+ else:
+ cuda_device = "cuda"
+
+ with torch.cuda.device(cuda_device):
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
def enable_tf32():