@@ -22,20 +22,28 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.org_module: list[torch.Module] = [self.sd_module]
 
         self.scale = 1.0
+        self.is_kohya = False
+        self.is_boft = False
 
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
             self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
             self.alpha = weights.w["alpha"]  # alpha is constraint
             self.dim = self.oft_blocks.shape[0]  # lora dim
-        # LyCORIS
+        # LyCORIS OFT
         elif "oft_diag" in weights.w.keys():
-            self.is_kohya = False
             self.oft_blocks = weights.w["oft_diag"]
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)
 
+            # LyCORIS BOFT
+            if weights.w["oft_diag"].dim() == 4:
+                self.is_boft = True
+            self.rescale = weights.w.get('rescale', None)
+            if self.rescale is not None:
+                self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
+
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]  # unsupported
@@ -51,6 +59,13 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
             self.constraint = self.alpha * self.out_dim
             self.num_blocks = self.dim
             self.block_size = self.out_dim // self.dim
+        elif self.is_boft:
+            self.constraint = None
+            self.boft_m = weights.w["oft_diag"].shape[0]
+            self.block_num = weights.w["oft_diag"].shape[1]
+            self.block_size = weights.w["oft_diag"].shape[2]
+            self.boft_b = self.block_size
+            #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
         else:
             self.constraint = None
             self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
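
For reference, a minimal standalone sketch (not part of this commit) of the shape bookkeeping the new branch assumes: a BOFT "oft_diag" tensor is 4-dimensional, one stack of orthogonal blocks per butterfly factor, and the code's rearrange/einsum below only works if each factor's blocks tile the full output dimension. Sizes here are made up for illustration.

import torch

out_dim = 16
oft_diag = torch.zeros(4, 8, 2, 2)  # (boft_m, block_num, block_size, block_size), illustrative sizes

boft_m, block_num, block_size, _ = oft_diag.shape
assert block_num * block_size == out_dim  # each butterfly factor is block-diagonal over out_dim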
@@ -68,14 +83,37 @@ def calc_updown(self, orig_weight):
 
         R = oft_blocks.to(orig_weight.device)
 
-        # This errors out for MultiheadAttention, might need to be handled up-stream
-        merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
-        merged_weight = torch.einsum(
-            'k n m, k n ... -> k m ...',
-            R,
-            merged_weight
-        )
-        merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+        if not self.is_boft:
+            # This errors out for MultiheadAttention, might need to be handled up-stream
+            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+            merged_weight = torch.einsum(
+                'k n m, k n ... -> k m ...',
+                R,
+                merged_weight
+            )
+            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+        else:
+            # TODO: determine correct value for scale
+            scale = 1.0
+            m = self.boft_m
+            b = self.boft_b
+            r_b = b // 2
+            inp = orig_weight
+            for i in range(m):
+                bi = R[i]  # b_num, b_size, b_size
+                if i == 0:
+                    # Apply multiplier/scale and rescale into first weight
+                    bi = bi * scale + (1 - scale) * eye
+                inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
+                inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
+                inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
+                inp = rearrange(inp, "d b ... -> (d b) ...")
+                inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
+            merged_weight = inp
+
+        # Rescale mechanism
+        if self.rescale is not None:
+            merged_weight = self.rescale.to(merged_weight) * merged_weight
 
         updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
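
As a sanity check on the butterfly branch, here is a small self-contained sketch (not part of this commit; the sizes and the way the orthogonal blocks are generated are assumptions for illustration) showing that the rearrange/einsum loop above applies an orthogonal transform along the first weight dimension, so per-column norms are preserved.

import torch
from einops import rearrange

out_dim, m, b = 8, 3, 2          # butterfly depth m and block size b, illustrative only
r_b = b // 2

# Orthogonal blocks shaped like a BOFT "oft_diag": (m, block_num, block_size, block_size).
skew = torch.randn(m, out_dim // b, b, b)
R = torch.matrix_exp(skew - skew.transpose(-1, -2))  # exp of a skew-symmetric matrix is orthogonal

inp = torch.randn(out_dim, 4)    # stand-in for orig_weight
out = inp
for i in range(m):
    bi = R[i]
    out = rearrange(out, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)  # butterfly permutation
    out = rearrange(out, "(d b) ... -> d b ...", b=b)                      # group rows into blocks
    out = torch.einsum("b i j, b j ... -> b i ...", bi, out)               # rotate each block
    out = rearrange(out, "d b ... -> (d b) ...")
    out = rearrange(out, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)  # undo the permutation

print(torch.allclose(inp.norm(dim=0), out.norm(dim=0), atol=1e-5))  # True: column norms preserved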