diff options
author | JaredTherriault <noirjt@live.com> | 2023-09-04 17:29:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-04 17:29:33 -0700 |
commit | 5e16914a4e157ab3ed96f8b7841e1290a56f4484 (patch) | |
tree | 655f4582e692f0fc3667b3b668ad365ac3ab92ae /modules/patches.py | |
parent | 8f3b02f09535f55d3673aa9ea589396b8614f799 (diff) | |
parent | 5ef669de080814067961f28357256e8fe27544f4 (diff) |
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/patches.py')
-rw-r--r-- | modules/patches.py | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/modules/patches.py b/modules/patches.py new file mode 100644 index 00000000..348235e7 --- /dev/null +++ b/modules/patches.py @@ -0,0 +1,64 @@ +from collections import defaultdict
+
+
+def patch(key, obj, field, replacement):
+ """Replaces a function in a module or a class.
+
+ Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
+ If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+ replacement: the new function
+
+ Returns:
+ the original function
+ """
+
+ patch_key = (obj, field)
+ if patch_key in originals[key]:
+ raise RuntimeError(f"patch for {field} is already applied")
+
+ original_func = getattr(obj, field)
+ originals[key][patch_key] = original_func
+
+ setattr(obj, field, replacement)
+
+ return original_func
+
+
+def undo(key, obj, field):
+ """Undoes the peplacement by the patch().
+
+ If the function is not replaced, raises an exception.
+
+ Arguments:
+ key: identifying information for who is doing the replacement. You can use __name__.
+ obj: the module or the class
+ field: name of the function as a string
+
+ Returns:
+ Always None
+ """
+
+ patch_key = (obj, field)
+
+ if patch_key not in originals[key]:
+ raise RuntimeError(f"there is no patch for {field} to undo")
+
+ original_func = originals[key].pop(patch_key)
+ setattr(obj, field, original_func)
+
+ return None
+
+
+def original(key, obj, field):
+ """Returns the original function for the patch created by the patch() function"""
+ patch_key = (obj, field)
+
+ return originals[key].get(patch_key, None)
+
+
+originals = defaultdict(dict)
|