Skip to content

Commit 1399769

Browse files
committed
fix: Revert previous fix for unfold method in torch frontend and push a better fix for shape mismatch issues with the native fw
1 parent 4396c35 commit 1399769

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

ivy/functional/frontends/torch/tensor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,17 @@ def unfold(self, dimension, size, step):
767767
slicing[dimension] = slice(i, i + size)
768768
slices.append(self.ivy_array[tuple(slicing)])
769769
stacked = torch_frontend.stack(slices, dim=dimension)
770+
new_shape = list(self.shape)
771+
num_slices = (self.shape[dimension] - size) // step + 1
772+
new_shape[dimension] = num_slices
773+
if dimension == -1:
774+
new_shape.insert(dimension, size)
775+
else:
776+
new_shape.insert(dimension + 1, size)
777+
reshaped = stacked.reshape(new_shape)
770778
dims = list(range(len(stacked.shape)))
771779
dims[-2], dims[-1] = dims[-1], dims[-2]
772-
return stacked.permute(*dims)
780+
return reshaped.permute(*dims)
773781

774782
def long(self, memory_format=None):
775783
self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)

0 commit comments

Comments
 (0)