diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-10-22 13:58:00 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-22 13:58:00 +0300 |
commit | e80bdcab91df0d91fa268991bee1d0143e81920a (patch) | |
tree | 347f8cbcdf644885fcf3481ed7a2dc55f8942c6e /modules | |
parent | 5aa9525046b7520d39fe8fc8c5c6cc10ab4d5fdb (diff) | |
parent | 1fa53dab2c5a857b9773f904fadf853dac1f1bd6 (diff) |
Merge pull request #3377 from Extraltodeus/cuda-device-id-selection
Implementation of CUDA device id selection (--device-id 0/1/2)
Diffstat (limited to 'modules')
-rw-r--r-- | modules/devices.py | 21 | ||||
-rw-r--r-- | modules/shared.py | 1 |
2 files changed, 19 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py index eb422583..8a159282 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ +import sys, os, shlex import contextlib - import torch - from modules import errors # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") +def extract_device_id(args, name): + for x in range(len(args)): + if name in args[x]: return args[x+1] + return None def get_optimal_device(): if torch.cuda.is_available(): - return torch.device("cuda") + # CUDA device selection support: + if "shared" not in sys.modules: + commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. + sys.argv += shlex.split(commandline_args) + device_id = extract_device_id(sys.argv, '--device-id') + else: + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + return torch.device(cuda_device) + else: + return torch.device("cuda") if has_mps: return torch.device("mps") diff --git a/modules/shared.py b/modules/shared.py index 7d786f07..5d83971e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -79,6 +79,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
cmd_opts = parser.parse_args()
|