diff options
author | Zac Liu <liuguang@baai.ac.cn> | 2022-12-06 09:16:15 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-06 09:16:15 +0800 |
commit | 3ebf977a6e4f478ab918e44506974beee32da276 (patch) | |
tree | f68456207e5cd78718ec1e9c588ecdc22d568d81 /modules/sd_hijack.py | |
parent | 231fb72872191ffa8c446af1577c9003b3d19d4f (diff) | |
parent | 44c46f0ed395967cd3830dd481a2db759fda5b3b (diff) |
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index edb8b420..eb679ef9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -17,6 +17,7 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -189,11 +190,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor:
if attr.device != devices.device:
-
- if devices.has_mps():
- attr = attr.to(device="mps", dtype=torch.float32)
- else:
- attr = attr.to(devices.device)
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)
|