From ee71eee1818f6f6eba9895c93ba25e0cad27e069 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Apr 2023 12:36:50 +0300 Subject: stuff related to torch version change --- modules/safe.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index 82d44be3..dadf319c 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -1,6 +1,5 @@ # this code is adapted from the script contributed by anon from /h/ -import io import pickle import collections import sys @@ -12,11 +11,9 @@ import _codecs import zipfile import re - # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage - def encode(*args): out = _codecs.encode(*args) return out @@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return TypedStorage() + return TypedStorage(_internal=True) def find_class(self, module, name): if self.extra_handler is not None: -- cgit v1.2.1 From 5ab7f213bec2f816f9c5644becb32eb72c8ffb89 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 2 May 2023 09:20:35 +0300 Subject: fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle --- modules/safe.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index dadf319c..e6c2f2c0 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -24,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return TypedStorage(_internal=True) + + try: + return TypedStorage(_internal=True) + except TypeError: + return TypedStorage() # PyTorch before 2.0 does not have the _internal argument def find_class(self, module, name): if self.extra_handler is not None: -- cgit v1.2.1 From 550256db1ce18778a9d56ff343d844c61b9f9b83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:19:16 +0300 Subject: ruff manual fixes --- modules/safe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index e6c2f2c0..2d5b972f 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -104,7 +104,7 @@ def check_pt(filename, extra_handler): def load(filename, *args, **kwargs): - return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) + return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs) def load_with_extra(filename, extra_handler=None, *args, **kwargs): -- cgit v1.2.1 From a5121e7a0623db328a9462d340d389ed6737374a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:37:18 +0300 Subject: fixes for B007 --- modules/safe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index 2d5b972f..1e791c5b 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -95,11 +95,11 @@ def check_pt(filename, extra_handler): except zipfile.BadZipfile: - # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle + # if it's not a zip file, it's an old pytorch format, with five objects written to pickle with open(filename, "rb") as file: unpickler = RestrictedUnpickler(file) unpickler.extra_handler = extra_handler - for i in range(5): + for _ in range(5): unpickler.load() -- cgit v1.2.1 From cb5f61281a95be72fc812b7d350b6ec23e2f9bdd Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 13 May 2023 11:04:26 -0400 Subject: Allow bf16 in safe unpickler --- modules/safe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/safe.py') diff --git a/modules/safe.py b/modules/safe.py index 1e791c5b..e8f50774 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -40,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler): return getattr(collections, name) if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: return getattr(torch._utils, name) - if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']: return getattr(torch, name) if module == 'torch.nn.modules.container' and name in ['ParameterDict']: return getattr(torch.nn.modules.container, name) -- cgit v1.2.1