diff options
author | JaredTherriault <noirjt@live.com> | 2023-09-04 17:29:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-04 17:29:33 -0700 |
commit | 5e16914a4e157ab3ed96f8b7841e1290a56f4484 (patch) | |
tree | 655f4582e692f0fc3667b3b668ad365ac3ab92ae /extensions-builtin/Lora/lora_patches.py | |
parent | 8f3b02f09535f55d3673aa9ea589396b8614f799 (diff) | |
parent | 5ef669de080814067961f28357256e8fe27544f4 (diff) |
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'extensions-builtin/Lora/lora_patches.py')
-rw-r--r-- | extensions-builtin/Lora/lora_patches.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py new file mode 100644 index 00000000..b394d8e9 --- /dev/null +++ b/extensions-builtin/Lora/lora_patches.py @@ -0,0 +1,31 @@ +import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+ def __init__(self):
+ self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+ self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+ self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+ self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+ self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+ self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+ self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+ self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+ self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+ self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+ def undo(self):
+ self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+ self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+ self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+ self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+ self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+ self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+ self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+ self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+ self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+ self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+
|