aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
authorZac Liu <liuguang@baai.ac.cn>2022-12-06 09:16:15 +0800
committerGitHub <noreply@github.com>2022-12-06 09:16:15 +0800
commit3ebf977a6e4f478ab918e44506974beee32da276 (patch)
treef68456207e5cd78718ec1e9c588ecdc22d568d81 /modules/sd_hijack.py
parent231fb72872191ffa8c446af1577c9003b3d19d4f (diff)
parent44c46f0ed395967cd3830dd481a2db759fda5b3b (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index edb8b420..eb679ef9 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -17,6 +17,7 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
+import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -189,11 +190,7 @@ def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != devices.device:
-
- if devices.has_mps():
- attr = attr.to(device="mps", dtype=torch.float32)
- else:
- attr = attr.to(devices.device)
+ attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)