@@ -56,14 +56,15 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
  descriptor_.groups = groups;
}

-Tensor _mps_convolution(
+Tensor _mps_convolution_impl(
    const Tensor& input_t,
    const Tensor& weight_t,
    const c10::optional<Tensor>& bias_opt,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
-    int64_t groups) {
+    int64_t groups,
+    c10::optional<IntArrayRef> input_shape) {
  TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS");

  namespace native_mps = at::native::mps;
@@ -83,6 +84,8 @@ Tensor _mps_convolution(
  auto memory_format = input_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
  auto output_t = at::empty(
+                  input_shape.has_value() ?
+                    input_shape.value() :
                  conv_output_size(input->sizes(), weight->sizes(),
                                   padding, stride, dilation),
                  input->scalar_type(),
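
Note on the two hunks above: the forward path normally derives its output size from conv_output_size(), but that mapping is not invertible, so a backward pass cannot recover the original input size from grad_output alone. That is what the new optional input_shape parameter is for; the transposed-convolution backward-input path later in this diff uses it to pin the result size. A minimal sketch of the ambiguity, using a hypothetical conv_out_dim helper that mirrors the one-dimensional output-size formula:

```cpp
#include <cstdint>
#include <cstdio>

// One-dimensional convolution output size (the usual floor formula).
int64_t conv_out_dim(int64_t in, int64_t kernel, int64_t pad, int64_t stride, int64_t dilation) {
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // kernel 3, padding 1, stride 2, dilation 1:
  std::printf("%lld\n", static_cast<long long>(conv_out_dim(7, 3, 1, 2, 1)));  // prints 4
  std::printf("%lld\n", static_cast<long long>(conv_out_dim(8, 3, 1, 2, 1)));  // also prints 4
  // Two input heights collapse onto one output height, so the size must be
  // passed in explicitly; it cannot be inferred from grad_output.
  return 0;
}
```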
@@ -237,21 +240,29 @@ Tensor _mps_convolution(
  return *output;
}

+Tensor _mps_convolution(
+    const Tensor& input_t,
+    const Tensor& weight_t,
+    const c10::optional<Tensor>& bias_opt,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    IntArrayRef dilation,
+    int64_t groups) {
+  return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, c10::nullopt);
+}
+
Tensor mps_convolution_backward_input(
-    IntArrayRef input_size, const Tensor& grad_output_, const Tensor& weight_,
+    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
  namespace native_mps = at::native::mps;
  using namespace mps;
  CheckedFrom c = "mps_convolution_backward_input";
-  TensorArg grad_output{ grad_output_, "grad_output", 1 },
-            weight{ weight_, "weight", 2 };
+  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
+            weight{ weight_t, "weight", 2 };
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});
-  auto memory_format = grad_output_.suggest_memory_format();
+  auto memory_format = grad_output_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
-  Tensor grad_output_t = grad_output_.contiguous(memory_format);
-  Tensor weight_t = weight_.contiguous(memory_format);
-  MPSShape* weightShape = getMPSShape(weight_);
  auto grad_input_t = at::empty(input_size, grad_output_t.options(), c10::nullopt);

  // Avoid "grad_input" when this is being used as transposed convolution
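
The hunk above also restores the public entry point: _mps_convolution keeps its original seven-argument signature and forwards c10::nullopt, so only internal callers ever pin the output shape. A self-contained sketch of that wrapper-plus-optional pattern, with hypothetical names and std::optional standing in for c10::optional:

```cpp
#include <cstdio>
#include <optional>
#include <vector>

using Shape = std::vector<long>;

// Internal implementation: the optional parameter overrides shape inference.
void conv_impl(const std::optional<Shape>& output_shape) {
  Shape shape = output_shape.has_value() ? *output_shape : Shape{1, 3, 8, 8};
  std::printf("rank %zu, height %ld\n", shape.size(), shape[2]);
}

// Public entry point: signature unchanged, shape always inferred.
void conv() {
  conv_impl(std::nullopt);
}

int main() {
  conv();                        // height 8 (inferred)
  conv_impl(Shape{1, 3, 7, 7});  // height 7 (pinned by an internal caller)
  return 0;
}
```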
@@ -327,10 +338,10 @@ Tensor mps_convolution_backward_input(
    }

    MPSGraphTensor* gradOutputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(grad_output_t.scalar_type()), gradOutputShape);
-    MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(weight_t.scalar_type()), weightShape);
+    MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

    MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
-    if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
+    if (is_channels_last) {
      gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
    }
    MPSGraphTensor* gradInputTensor;
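
With the eager .contiguous() copies removed, the NHWC-to-NCHW conversion above now fires whenever the layout is channels-last, rather than only for contiguous non-view tensors. At the shape level the conversion is a (0, 3, 1, 2) axis permutation; a small sketch of the index bookkeeping, independent of MPSGraph:

```cpp
#include <array>
#include <cstdio>

int main() {
  const std::array<long, 4> nhwc = {2, 8, 8, 3};  // N, H, W, C
  const std::array<int, 4> perm = {0, 3, 1, 2};   // gather order: N, C, H, W
  std::array<long, 4> nchw{};
  for (int i = 0; i < 4; ++i) {
    nchw[i] = nhwc[perm[i]];
  }
  std::printf("%ld %ld %ld %ld\n", nchw[0], nchw[1], nchw[2], nchw[3]);  // 2 3 8 8
  return 0;
}
```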
@@ -359,7 +370,7 @@ Tensor mps_convolution_backward_input(
    }

    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
-    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t, weightShape);
+    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
    auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);

    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@@ -377,17 +388,14 @@ Tensor mps_convolution_backward_input(
}

Tensor mps_convolution_backward_weights(
-    IntArrayRef weight_size, const Tensor& grad_output_, const Tensor& input_,
+    IntArrayRef weight_size, const Tensor& grad_output_t, const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
  namespace native_mps = at::native::mps;
  using namespace mps;
  CheckedFrom c = "mps_convolution_backward_weights";
-  auto memory_format = input_.suggest_memory_format();
+  auto memory_format = grad_output_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);

-  auto grad_output_t = grad_output_.to(memory_format);
-  auto input_t = input_.to(memory_format);
-
  MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);

  // For uniformity with everything else, although it seems grad_weight
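
In the hunk above the memory format is now read off grad_output_t instead of input_, and the two eager .to(memory_format) copies are gone, so the graph consumes both tensors in their native layout. For reference, suggest_memory_format() reports ChannelsLast for an NHWC-strided 4-D tensor; a quick libtorch check (runs on CPU as well as MPS):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto t = torch::randn({2, 3, 8, 8})
               .contiguous(at::MemoryFormat::ChannelsLast);
  // An NHWC-strided tensor reports ChannelsLast; this is the signal the
  // backward-weights kernel now takes from grad_output rather than input.
  std::cout << (t.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
            << std::endl;  // prints 1
  return 0;
}
```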
@@ -475,7 +483,7 @@ Tensor mps_convolution_backward_weights(
    MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);

    MPSGraphTensor *gradOutputTensorTranspose = gradOutputTensor;
-    if (is_channels_last && grad_output_t.is_contiguous() && !grad_output_t.is_view()) {
+    if (is_channels_last) {
      gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
    }

@@ -525,12 +533,9 @@ Tensor mps_convolution_backward_weights(
}

std::tuple<at::Tensor,at::Tensor,at::Tensor> mps_convolution_backward(
-    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
+    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    std::array<bool,3> output_mask) {
-
-  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
-
  Tensor grad_input, grad_weight, grad_bias;
  if (input.numel() == 0) {
    if (output_mask[0]) {
@@ -576,10 +581,10 @@ Tensor _mps_convolution_transpose(
Tensor mps_convolution_transpose_backward_input(
    const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
-    int64_t groups)
+    int64_t groups, IntArrayRef input_shape)
{
-  return at::_mps_convolution(
-      grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups);
+  return _mps_convolution_impl(
+      grad_output_t, weight_t, c10::nullopt, padding, stride, dilation, groups, input_shape);
}

Tensor mps_convolution_transpose_backward_weight(
@@ -595,15 +600,12 @@ Tensor mps_convolution_transpose_backward_weight(


std::tuple<Tensor,Tensor> mps_convolution_transpose_backward(
-    const Tensor& input, const Tensor& grad_output_t, const Tensor& weight,
+    const Tensor& input, const Tensor& grad_output, const Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    std::array<bool,2> output_mask) {
-
-  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());
-
  Tensor grad_input, grad_weight;
  if (output_mask[0]) {
-    grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups);
+    grad_input = mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
  }
  if (output_mask[1]) {
    grad_weight = mps_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups);
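
A shape-level sanity check of the input.sizes() plumbing above, with hypothetical numbers: the input gradient of a transposed convolution is computed as a forward convolution over grad_output, and the explicit shape guarantees the result lands exactly on the original input size rather than being re-derived from the output-size formula:

```cpp
#include <cstdio>

// Forward-conv and transposed-conv output sizes for one spatial dimension
// (dilation fixed at 1 to keep the sketch short).
long conv_out(long in, long k, long p, long s) {
  return (in + 2 * p - (k - 1) - 1) / s + 1;
}
long tconv_out(long in, long k, long p, long s, long out_pad) {
  return (in - 1) * s - 2 * p + (k - 1) + out_pad + 1;
}

int main() {
  const long h = 7;
  const long g = tconv_out(h, 3, 1, 2, 1);     // transposed conv: 7 -> 14
  std::printf("%ld\n", conv_out(g, 3, 1, 2));  // grad-input conv: 14 -> 7
  return 0;
}
```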