@@ -56,7 +56,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
         self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
 
     def calc_updown(self, orig_weight):
-        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        oft_blocks = self.oft_blocks.to(orig_weight.device)
         eye = torch.eye(self.block_size, device=self.oft_blocks.device)
 
         if self.is_kohya:
@@ -66,7 +66,7 @@ def calc_updown(self, orig_weight):
             block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
 
-        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        R = oft_blocks.to(orig_weight.device)
 
         # This errors out for MultiheadAttention, might need to be handled up-stream
         merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -77,6 +77,6 @@ def calc_updown(self, orig_weight):
         )
         merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
 
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
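Note on the change: dropping the dtype=orig_weight.dtype casts keeps the OFT blocks in whatever precision they were stored in, instead of forcing them to the weight's (often fp16) dtype before the Cayley transform R = (I + Q)(I - Q)^-1, and the final line upcasts orig_weight to merged_weight.dtype so the subtraction also runs at that higher precision, presumably leaving any final downcast to finalize_updown. A minimal standalone sketch of the effect; the names here (cayley, R_full, R_early) are illustrative, not part of the extension:

import torch

torch.manual_seed(0)

# One hypothetical 8x8 OFT block; entries kept small, as after the norm clamp above.
A = torch.randn(8, 8, dtype=torch.float32) * 0.02
Q = A - A.T  # skew-symmetric generator, as in the is_kohya branch

def cayley(q):
    # R = (I + Q) @ (I - Q)^-1, matching the matmul/inverse in the diff
    eye = torch.eye(q.shape[-1])
    return (eye + q) @ torch.linalg.inv(eye - q)

R_full = cayley(Q)                  # patched path: blocks stay float32
R_early = cayley(Q.half().float())  # old path: blocks quantized to fp16 first

# The early cast discards Q's low-order bits before the transform, so the two
# rotations differ at fp16 resolution rather than fp32 resolution.
print((R_full - R_early).abs().max())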