
Commit 6e4fc5e

Merge pull request #14871 from v0xie/boft
Support inference with LyCORIS BOFT networks
2 parents 9d5becb + 4eb9496 commit 6e4fc5e
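For orientation (a hedged summary inferred from the diff below, not wording from the PR): BOFT, Butterfly Orthogonal Fine-Tuning, adapts a layer by left-multiplying its weight with a product of butterfly factors, each of which conjugates a block-diagonal orthogonal matrix by a fixed butterfly permutation. Reading the calc_updown loop added below, the merged weight is

$$W_{\text{merged}} = \Big(\prod_{i=m-1}^{0} P_i^{\top}\,\mathrm{blockdiag}(R_{i,1},\dots,R_{i,\text{block\_num}})\,P_i\Big)\,W_{\text{orig}},$$

where $P_i$ is the stride-$2^i \cdot r_b$ butterfly permutation implemented by the rearrange calls, $R_i$ holds the learned orthogonal blocks, and the returned updown delta is $W_{\text{merged}} - W_{\text{orig}}$.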

1 file changed, 48 insertions(+), 10 deletions(-)

extensions-builtin/Lora/network_oft.py

@@ -22,20 +22,28 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.org_module: list[torch.Module] = [self.sd_module]
 
         self.scale = 1.0
+        self.is_kohya = False
+        self.is_boft = False
 
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
             self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
             self.alpha = weights.w["alpha"] # alpha is constraint
             self.dim = self.oft_blocks.shape[0] # lora dim
-        # LyCORIS
+        # LyCORIS OFT
         elif "oft_diag" in weights.w.keys():
-            self.is_kohya = False
             self.oft_blocks = weights.w["oft_diag"]
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
 
+            # LyCORIS BOFT
+            if weights.w["oft_diag"].dim() == 4:
+                self.is_boft = True
+            self.rescale = weights.w.get('rescale', None)
+            if self.rescale is not None:
+                self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
+
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
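A minimal shape sketch of the branching above (hypothetical dimensions, not from the PR): kohya-ss OFT and LyCORIS OFT ship a 3-D tensor holding one block-diagonal factor, while LyCORIS BOFT ships a 4-D oft_diag holding m butterfly factors, which is what the .dim() == 4 test detects; the optional rescale vector is reshaped so it broadcasts over the wrapped module's weight.

```python
# Hypothetical shapes (illustration only; names mirror the attributes set above).
import torch

out_dim = 320

# LyCORIS OFT: a single block-diagonal factor -> 3-D tensor
oft_diag = torch.empty(4, out_dim // 4, out_dim // 4)     # (num_blocks, block_size, block_size)
assert oft_diag.dim() == 3                                # -> plain OFT path

# LyCORIS BOFT: m butterfly factors -> 4-D tensor
m, b = 2, 4
boft_diag = torch.empty(m, out_dim // b, b, b)            # (boft_m, block_num, block_size, block_size)
assert boft_diag.dim() == 4                               # -> BOFT path

# rescale is reshaped to broadcast over the module weight, e.g. a Conv2d kernel
rescale = torch.ones(out_dim)
weight = torch.empty(out_dim, 16, 3, 3)                   # (out, in, kh, kw)
rescale = rescale.reshape(-1, *[1] * (weight.dim() - 1))  # -> (320, 1, 1, 1)
assert (rescale * weight).shape == weight.shape
```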
@@ -51,6 +59,13 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
             self.constraint = self.alpha * self.out_dim
             self.num_blocks = self.dim
             self.block_size = self.out_dim // self.dim
+        elif self.is_boft:
+            self.constraint = None
+            self.boft_m = weights.w["oft_diag"].shape[0]
+            self.block_num = weights.w["oft_diag"].shape[1]
+            self.block_size = weights.w["oft_diag"].shape[2]
+            self.boft_b = self.block_size
+            #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
         else:
             self.constraint = None
             self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
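A hedged sketch of the bookkeeping this branch performs (example dimensions are made up): boft_m counts the butterfly factors, block_num and block_size tile the layer's output dimension, and calc_updown later halves the block size for the butterfly stride, so the block size needs to be even.

```python
import torch

oft_diag = torch.empty(3, 80, 4, 4)    # hypothetical (boft_m, block_num, block_size, block_size)
boft_m, block_num, block_size = oft_diag.shape[0], oft_diag.shape[1], oft_diag.shape[2]
boft_b = block_size

out_dim = block_num * block_size       # the blocks must tile the layer's output dimension
assert out_dim == 320
assert boft_b % 2 == 0                 # calc_updown uses r_b = b // 2 as the butterfly stride
```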
@@ -68,14 +83,37 @@ def calc_updown(self, orig_weight):
 
         R = oft_blocks.to(orig_weight.device)
 
-        # This errors out for MultiheadAttention, might need to be handled up-stream
-        merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
-        merged_weight = torch.einsum(
-            'k n m, k n ... -> k m ...',
-            R,
-            merged_weight
-        )
-        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+        if not self.is_boft:
+            # This errors out for MultiheadAttention, might need to be handled up-stream
+            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+            merged_weight = torch.einsum(
+                'k n m, k n ... -> k m ...',
+                R,
+                merged_weight
+            )
+            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+        else:
+            # TODO: determine correct value for scale
+            scale = 1.0
+            m = self.boft_m
+            b = self.boft_b
+            r_b = b // 2
+            inp = orig_weight
+            for i in range(m):
+                bi = R[i]  # b_num, b_size, b_size
+                if i == 0:
+                    # Apply multiplier/scale and rescale into first weight
+                    bi = bi * scale + (1 - scale) * eye
+                inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
+                inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
+                inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
+                inp = rearrange(inp, "d b ... -> (d b) ...")
+                inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
+            merged_weight = inp
+
+        # Rescale mechanism
+        if self.rescale is not None:
+            merged_weight = self.rescale.to(merged_weight) * merged_weight
 
         updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
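To see what the BOFT branch computes, here is a self-contained sketch (assuming einops is available and each factor is orthogonal, neither of which the hunk above checks) that replays the same rearrange/einsum loop on a random weight and confirms that the chain of permutations and block-orthogonal multiplies preserves the Frobenius norm:

```python
import torch
from einops import rearrange

m, b, out_dim = 3, 4, 32                 # 3 butterfly factors of 4x4 blocks over 32 rows
r_b = b // 2
# Random orthogonal blocks via batched QR; R plays the role of R = oft_blocks above
R = torch.linalg.qr(torch.randn(m, out_dim // b, b, b)).Q

inp = torch.randn(out_dim, 8)            # stand-in for a Linear weight (out_dim, in_dim)
orig_norm = inp.norm()
for i in range(m):
    bi = R[i]                            # (block_num, block_size, block_size)
    inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)  # butterfly permute
    inp = rearrange(inp, "(d b) ... -> d b ...", b=b)                      # cut rows into blocks
    inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)               # block-orthogonal multiply
    inp = rearrange(inp, "d b ... -> (d b) ...")                           # undo the blocking
    inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)  # undo the permutation

torch.testing.assert_close(inp.norm(), orig_norm)  # an orthogonal chain preserves the norm
```

The i == 0 special case in the PR folds the network multiplier into the first factor by blending it with the identity (eye, defined earlier in calc_updown as the block-size identity matrix); with the placeholder scale = 1.0 that blend leaves the blocks unchanged.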
