Skip to content

Commit 6f11e6d

Browse files
[MPS] Convolution fixes (pytorch#95318)
* [MPS] Convolution cleanup; remove unnecessary contiguous calls (pytorch#95078) - Fixes convolution crashes in backward with weights - Removes unnecessary contiguous calls Pull Request resolved: pytorch#95078 Approved by: https://github.com/kulinseth * [MPS] Fix nn.functional.conv_transpose2d grad (pytorch#94871) - add _mps_convolution_impl that takes optional shape - for conv_transpose2d grad, use the shape from forward pass directly - for conv, calculate the shape from input - remove nn.functional.conv_transpose2d grad from blocklist Pull Request resolved: pytorch#94871 Approved by: https://github.com/kulinseth --------- Co-authored-by: Denis Vieriu <[email protected]> Co-authored-by: Denis Vieriu <[email protected]>
1 parent fcec27f commit 6f11e6d

File tree

2 files changed

+117
-40
lines changed

2 files changed

+117
-40
lines changed

aten/src/ATen/native/mps/operations/Convolution.mm

+32-30
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,15 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
5656
descriptor_.groups = groups;
5757
}
5858

59-
Tensor _mps_convolution(
59+
Tensor _mps_convolution_impl(
6060
const Tensor& input_t,
6161
const Tensor& weight_t,
6262
const c10::optional<Tensor>& bias_opt,
6363
IntArrayRef padding,
6464
IntArrayRef stride,
6565
IntArrayRef dilation,
66-
int64_t groups) {
66+
int64_t groups,
67+
c10::optional<IntArrayRef> input_shape) {
6768
TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS");
6869

6970
namespace native_mps = at::native::mps;
@@ -83,6 +84,8 @@ Tensor _mps_convolution(
8384
auto memory_format = input_t.suggest_memory_format();
8485
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
8586
auto output_t = at::empty(
87+
input_shape.has_value() ?
88+
input_shape.value() :
8689
conv_output_size(input->sizes(), weight->sizes(),
8790
padding, stride, dilation),
8891
input->scalar_type(),
@@ -237,21 +240,29 @@ Tensor _mps_convolution(
237240
return *output;
238241
}
239242

243+
Tensor _mps_convolution(
244+
const Tensor& input_t,
245+
const Tensor& weight_t,
246+
const c10::optional<Tensor>& bias_opt,
247+
IntArrayRef padding,
248+
IntArrayRef stride,
249+
IntArrayRef dilation,
250+
int64_t groups) {
251+
return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
252+
}
253+
240254
Tensor mps_convolution_backward_input(
241-
IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
255+
IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
242256
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
243257
namespace native_mps = at::native::mps;
244258
using namespace mps;
245259
CheckedFrom c = "mps_convolution_backward_input";
246-
TensorArg grad_output{ grad_output_, "grad_output", 1 },
247-
weight{ weight_, "weight", 2 };
260+
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
261+
weight{ weight_t, "weight", 2 };
248262
checkAllSameType(c, {grad_output, weight});
249263
checkAllSameGPU(c, {grad_output, weight});
250-
auto memory_format = grad_output_.suggest_memory_format();
264+
auto memory_format = grad_output_t.suggest_memory_format();
251265
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
252-
Tensor grad_output_t = grad_output_.contiguous(memory_format);
253-
Tensor weight_t = weight_.contiguous(memory_format);
254-
MPSShape* weightShape = getMPSShape(weight_);
255266
auto grad_input_t = at::empty( input_size, grad_output_t.options(), c10::nullopt);
256267

257268
// Avoid "grad_input" when this is being used as transposed convolution
@@ -327,10 +338,10 @@ Tensor mps_convolution_backward_input(
327338
}
328339

329340
MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
330-
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
341+
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
331342

332343
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
333-
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
344+
if (is_channels_last) {
334345
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
335346
}
336347
MPSGraphTensor* gradInputTensor;
@@ -359,7 +370,7 @@ Tensor mps_convolution_backward_input(
359370
}
360371

361372
auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
362-
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
373+
auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
363374
auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);
364375

365376
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@@ -377,17 +388,14 @@ Tensor mps_convolution_backward_input(
377388
}
378389

379390
Tensor mps_convolution_backward_weights(
380-
IntArrayRef weight_size, const Tensor& grad_output_, const Tensor& input_,
391+
IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
381392
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
382393
namespace native_mps = at::native::mps;
383394
using namespace mps;
384395
CheckedFrom c = "mps_convolution_backward_weights";
385-
auto memory_format = input_.suggest_memory_format();
396+
auto memory_format = grad_output_t.suggest_memory_format();
386397
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
387398

388-
auto grad_output_t = grad_output_.to(memory_format);
389-
auto input_t = input_.to(memory_format);
390-
391399
MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);
392400

393401
// For uniformity with everything else, although it seems grad_weight
@@ -475,7 +483,7 @@ Tensor mps_convolution_backward_weights(
475483
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
476484

477485
MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
478-
if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
486+
if (is_channels_last) {
479487
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
480488
}
481489

@@ -525,12 +533,9 @@ Tensor mps_convolution_backward_weights(
525533
}
526534

527535
std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
528-
const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
536+
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
529537
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
530538
std::array<bool,3> output_mask) {
531-
532-
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
533-
534539
Tensor grad_input, grad_weight, grad_bias;
535540
if (input.numel() == 0) {
536541
if (output_mask[0]) {
@@ -576,10 +581,10 @@ Tensor _mps_convolution_transpose(
576581
Tensor mps_convolution_transpose_backward_input(
577582
const Tensor& grad_output_t, const Tensor& weight_t,
578583
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
579-
int64_t groups)
584+
int64_t groups, IntArrayRef input_shape)
580585
{
581-
return at::_mps_convolution(
582-
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups);
586+
return _mps_convolution_impl(
587+
grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
583588
}
584589

585590
Tensor mps_convolution_transpose_backward_weight(
@@ -595,15 +600,12 @@ Tensor mps_convolution_transpose_backward_weight(
595600

596601

597602
std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
598-
const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
603+
const Tensor& input, const Tensor& grad_output, const Tensor& weight,
599604
IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
600605
std::array<bool,2> output_mask) {
601-
602-
Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
603-
604606
Tensor grad_input, grad_weight;
605607
if (output_mask[0]) {
606-
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups);
608+
grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
607609
}
608610
if (output_mask[1]) {
609611
grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);

test/test_mps.py

+85-10
Original file line numberDiff line numberDiff line change
@@ -7736,7 +7736,8 @@ def test_conv_transpose_1d_nn_functional(self):
77367736
def test_conv_backward_1d_channels_last(self):
77377737
def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
77387738
# https://github.com/pytorch/pytorch/issues/84511
7739-
conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups)
7739+
conv_cpu = torch.nn.Conv1d(
7740+
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).requires_grad_()
77407741
conv_mps = torch.nn.Conv1d(
77417742
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
77427743
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
@@ -7776,15 +7777,89 @@ def test_conv1d_contiguous(self):
77767777

77777778
def test_conv2d_all_strides_paddings(self):
77787779
# https://github.com/pytorch/pytorch/issues/83180
7779-
y_cpu = torch.randn(2, 2, 3, 6)
7780-
y_gpu = y_cpu.to(device='mps')
7781-
for strideX in range(1, 4):
7782-
for strideY in range(1, 4):
7783-
conv_cpu = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=(strideX, strideY))
7784-
conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
7785-
x_cpu = conv_cpu(y_cpu)
7786-
x_gpu = conv_gpu(y_gpu)
7787-
self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
7780+
def helper(N, C, H, W, groups, input_mem_format, weight_mem_format, permute_data):
7781+
x_cpu = torch.randn(N, C, H, W).to(memory_format=input_mem_format).requires_grad_()
7782+
x_mps = x_cpu.detach().clone().to(device='mps').requires_grad_()
7783+
7784+
if permute_data:
7785+
x_cpu.permute(0, 2, 3, 1)
7786+
x_mps.permute(0, 2, 3, 1)
7787+
7788+
for strideX in range(1, 4):
7789+
for strideY in range(1, 4):
7790+
conv_cpu = torch.nn.Conv2d(
7791+
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY)).requires_grad_()
7792+
conv_cpu.weight.data = conv_cpu.weight.to(memory_format=weight_mem_format).requires_grad_()
7793+
7794+
conv_mps = torch.nn.Conv2d(
7795+
in_channels=N, out_channels=C, kernel_size=H, groups=groups, stride=(strideX, strideY), device="mps")
7796+
conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_()
7797+
conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_()
7798+
7799+
res_cpu = conv_cpu(x_cpu)
7800+
res_mps = conv_mps(x_mps)
7801+
self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
7802+
7803+
res_cpu = res_cpu.sum().backward()
7804+
res_mps = res_mps.sum().backward()
7805+
self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)
7806+
self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
7807+
self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
7808+
self.assertEqual(x_cpu.grad, x_mps.grad)
7809+
7810+
for mem_format_input in [torch.contiguous_format, torch.channels_last]:
7811+
for mem_format_weight in [torch.contiguous_format, torch.channels_last]:
7812+
for permute_data in [True, False]:
7813+
helper(2, 2, 3, 6, 1, mem_format_input, mem_format_weight, permute_data)
7814+
helper(10, 10, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
7815+
helper(32, 32, 4, 6, 2, mem_format_input, mem_format_weight, permute_data)
7816+
7817+
def test_conv_transpose_2d_strided(self):
7818+
def helper(m_cpu, memory_format):
7819+
m_mps = copy.deepcopy(m_cpu).requires_grad_()
7820+
m_mps.weight.data = m_cpu.weight.data.detach().clone().to("mps").requires_grad_()
7821+
m_mps.bias.data = m_cpu.bias.data.detach().clone().to("mps").requires_grad_()
7822+
7823+
input_cpu = torch.randn(20, 16, 50, 100).to(memory_format=memory_format).requires_grad_()
7824+
input_mps = input_cpu.detach().clone().to("mps")
7825+
7826+
output_cpu = m_cpu(input_cpu)
7827+
output_mps = m_mps(input_mps)
7828+
self.assertEqual(output_cpu, output_mps)
7829+
7830+
for mem_format_input in [torch.contiguous_format, torch.channels_last]:
7831+
# With square kernels and equal stride
7832+
helper(nn.ConvTranspose2d(16, 33, 3, stride=2).requires_grad_(), mem_format_input)
7833+
7834+
# non-square kernels and unequal stride and with padding
7835+
helper(nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2)).requires_grad_(), mem_format_input)
7836+
7837+
def test_conv_transpose_2d_specified_output(self):
7838+
input_cpu = torch.randn(1, 16, 12, 12)
7839+
input_mps = input_cpu.detach().clone().to("mps")
7840+
7841+
downsample_cpu = nn.Conv2d(16, 16, 3, stride=2, padding=1)
7842+
downsample_mps = nn.Conv2d(16, 16, 3, stride=2, padding=1, device="mps")
7843+
downsample_mps.weight.data = downsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
7844+
downsample_mps.bias.data = downsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
7845+
7846+
upsample_cpu = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
7847+
upsample_mps = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1, device="mps")
7848+
upsample_mps.weight.data = upsample_cpu.weight.data.detach().clone().to("mps").requires_grad_()
7849+
upsample_mps.bias.data = upsample_cpu.bias.data.detach().clone().to("mps").requires_grad_()
7850+
7851+
h_cpu = downsample_cpu(input_cpu)
7852+
h_mps = downsample_mps(input_mps)
7853+
self.assertEqual(h_cpu, h_mps)
7854+
7855+
size_cpu = h_cpu.size()
7856+
size_mps = h_mps.size()
7857+
self.assertEqual(size_cpu, size_mps)
7858+
7859+
output_cpu = upsample_cpu(h_cpu, output_size=input_cpu.size())
7860+
output_mps = upsample_mps(h_mps, output_size=input_mps.size())
7861+
self.assertEqual(output_cpu, output_mps)
7862+
self.assertEqual(output_cpu.size(), output_mps.size())
77887863

77897864
def test_conv2d_single_stride(self):
77907865
y_cpu = torch.randn(2, 2, 3, 6)

0 commit comments

Comments
 (0)