aboutsummaryrefslogtreecommitdiff
path: root/extensions-builtin/Lora/lyco_helpers.py
diff options
context:
space:
mode:
authorLeon Feng <523684+leon0707@users.noreply.github.com>2023-07-18 04:24:14 -0400
committerGitHub <noreply@github.com>2023-07-18 04:24:14 -0400
commita3730bd9becd2f1f5d209885b694b0dec178d110 (patch)
tree8ac9948d89606f7519df786f07f6ddb93c3d2720 /extensions-builtin/Lora/lyco_helpers.py
parentd6668347c8b85b11b696ac56777cc396e34ee1f9 (diff)
parent871b8687a82bb2ca907d8a49c87aed7635b8fc33 (diff)
Merge branch 'dev' into fix-11805
Diffstat (limited to 'extensions-builtin/Lora/lyco_helpers.py')
-rw-r--r--extensions-builtin/Lora/lyco_helpers.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
new file mode 100644
index 00000000..279b34bc
--- /dev/null
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -0,0 +1,21 @@
+import torch
+
+
+def make_weight_cp(t, wa, wb):
+ temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
+ return torch.einsum('i j k l, i r -> r j k l', temp, wa)
+
+
+def rebuild_conventional(up, down, shape, dyn_dim=None):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ if dyn_dim is not None:
+ up = up[:, :dyn_dim]
+ down = down[:dyn_dim, :]
+ return (up @ down).reshape(shape)
+
+
+def rebuild_cp_decomposition(up, down, mid):
+ up = up.reshape(up.size(0), -1)
+ down = down.reshape(down.size(0), -1)
+ return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)