Skip to content

Commit 8b6848c

Browse files
Merge pull request #14546 from AUTOMATIC1111/fix-oft-dtype
Fix dtype casting in OFT module
2 parents a4ee640 + f8f38c7 commit 8b6848c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

extensions-builtin/Lora/network_oft.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
5656
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
5757

5858
def calc_updown(self, orig_weight):
59-
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
59+
oft_blocks = self.oft_blocks.to(orig_weight.device)
6060
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
6161

6262
if self.is_kohya:
@@ -66,7 +66,7 @@ def calc_updown(self, orig_weight):
6666
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
6767
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
6868

69-
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
69+
R = oft_blocks.to(orig_weight.device)
7070

7171
# This errors out for MultiheadAttention, might need to be handled up-stream
7272
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -77,6 +77,6 @@ def calc_updown(self, orig_weight):
7777
)
7878
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
7979

80-
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
80+
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
8181
output_shape = orig_weight.shape
8282
return self.finalize_updown(updown, orig_weight, output_shape)

0 commit comments

Comments
 (0)