diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-31 07:38:34 +0300 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-31 07:38:34 +0300 |
commit | 5ef669de080814067961f28357256e8fe27544f4 (patch) | |
tree | 655f4582e692f0fc3667b3b668ad365ac3ab92ae /extensions-builtin/Lora/network_norm.py | |
parent | c9c8485bc1e8720aba70f029d25cba1c4abf2b5c (diff) | |
parent | e7965a5eb804a51e949df07c66c0b7c61ab7fa7b (diff) |
Merge branch 'release_candidate'
Diffstat (limited to 'extensions-builtin/Lora/network_norm.py')
-rw-r--r-- | extensions-builtin/Lora/network_norm.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py new file mode 100644 index 00000000..ce450158 --- /dev/null +++ b/extensions-builtin/Lora/network_norm.py @@ -0,0 +1,28 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) |