Skip to content

Commit cca7a31

Browse files
committed
[BE] Delete trailing whitespaces
In MPS ops. Also use nested namespaces, as since pytorch-2.4 this should be a C++17 project
1 parent 342eb92 commit cca7a31

File tree

7 files changed

+39
-58
lines changed

7 files changed

+39
-58
lines changed

torchvision/csrc/ops/mps/deform_conv2d_kernel.mm

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include <ATen/native/mps/OperationUtils.h>
44
#include "mps_kernels.h"
55

6-
namespace vision {
7-
namespace ops {
6+
namespace vision::ops {
87

98
namespace {
109

@@ -61,25 +60,25 @@
6160
uint32_t dilation_w_u = static_cast<uint32_t>(dilation_w);
6261

6362
TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
64-
"Input channels (", in_channels,
63+
"Input channels (", in_channels,
6564
") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")");
6665
TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0,
67-
"Weight tensor's out channels (", weight_c.size(0),
66+
"Weight tensor's out channels (", weight_c.size(0),
6867
") must be divisible by n_weight_grps (", n_weight_grps, ")");
6968
TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w,
70-
"Offset tensor shape[1] is invalid: got ", offset_c.size(1),
69+
"Offset tensor shape[1] is invalid: got ", offset_c.size(1),
7170
", expected ", n_offset_grps * 2 * weight_h * weight_w);
7271
TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w,
73-
"Mask tensor shape[1] is invalid: got ", mask_c.size(1),
72+
"Mask tensor shape[1] is invalid: got ", mask_c.size(1),
7473
", expected ", n_offset_grps * weight_h * weight_w);
7574
TORCH_CHECK(in_channels % n_offset_grps == 0,
76-
"Input tensor channels (", in_channels,
75+
"Input tensor channels (", in_channels,
7776
") must be divisible by n_offset_grps (", n_offset_grps, ")");
7877
TORCH_CHECK(offset_c.size(0) == batch,
7978
"Offset tensor batch size (", offset_c.size(0),
8079
") must match input tensor batch size (", batch, ")");
8180
TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w,
82-
"Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3),
81+
"Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3),
8382
") must match calculated output dimensions (", out_h, ", ", out_w, ")");
8483
TORCH_CHECK(!use_mask || mask_c.size(0) == batch,
8584
"Mask tensor batch size (", mask_c.size(0),
@@ -145,5 +144,4 @@
145144
TORCH_FN(deform_conv2d_forward_kernel));
146145
}
147146

148-
} // namespace ops
149-
} // namespace vision
147+
} // namespace vision::ops

torchvision/csrc/ops/mps/mps_kernels.h

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#include <ATen/native/mps/OperationUtils.h>
22

3-
namespace vision {
4-
namespace ops {
5-
6-
namespace mps {
3+
namespace vision::ops::mps {
74

85
static at::native::mps::MetalShaderLibrary lib(R"VISION_METAL(
96
@@ -115,15 +112,15 @@ inline T bilinear_interpolate_deformable_conv2d(
115112
T v1 = 0;
116113
if (y_low >= 0 && x_low >= 0)
117114
v1 = input[y_low * width + x_low];
118-
115+
119116
T v2 = 0;
120117
if (y_low >= 0 && x_high <= width - 1)
121118
v2 = input[y_low * width + x_high];
122-
119+
123120
T v3 = 0;
124121
if (y_high <= height - 1 && x_low >= 0)
125122
v3 = input[y_high * width + x_low];
126-
123+
127124
T v4 = 0;
128125
if (y_high <= height - 1 && x_high <= width - 1)
129126
v4 = input[y_high * width + x_high];
@@ -228,7 +225,7 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]],
228225
constant float & iou_threshold [[buffer(3)]],
229226
uint2 tgid [[threadgroup_position_in_grid]],
230227
uint2 tid2 [[thread_position_in_threadgroup]]) {
231-
228+
232229
const uint row_start = tgid.y;
233230
const uint col_start = tgid.x;
234231
const uint tid = tid2.x;
@@ -245,7 +242,7 @@ kernel void nms(constant T * dev_boxes [[buffer(0)]],
245242
const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid;
246243
uint64_t t = 0;
247244
uint start = 0;
248-
245+
249246
if (row_start == col_start) {
250247
start = tid + 1;
251248
}
@@ -309,48 +306,48 @@ kernel void deformable_im2col_kernel(
309306
int out_b = (tid / (out_w * out_h)) % batch_size;
310307
int in_c = tid / (out_w * out_h * batch_size);
311308
int out_c = in_c * weight_h * weight_w;
312-
309+
313310
int c_per_offset_grp = n_in_channels / n_offset_grps;
314311
int grp_idx = in_c / c_per_offset_grp;
315-
312+
316313
int col_offset = out_c * (batch_size * out_h * out_w)
317314
+ out_b * (out_h * out_w)
318315
+ out_y * out_w + out_x;
319316
device T* local_columns_ptr = columns_ptr + col_offset;
320-
317+
321318
int input_offset = out_b * (n_in_channels * height * width)
322319
+ in_c * (height * width);
323320
constant T* local_input_ptr = input_ptr + input_offset;
324-
321+
325322
int offset_offset = (out_b * n_offset_grps + grp_idx) * 2 * weight_h * weight_w * out_h * out_w;
326323
constant T* local_offset_ptr = offset_ptr + offset_offset;
327-
324+
328325
constant T* local_mask_ptr = nullptr;
329326
if (use_mask) {
330327
int mask_offset = (out_b * n_offset_grps + grp_idx) * weight_h * weight_w * out_h * out_w;
331328
local_mask_ptr = mask_ptr + mask_offset;
332329
}
333-
330+
334331
for (int i = 0; i < weight_h; ++i) {
335332
for (int j = 0; j < weight_w; ++j) {
336333
int mask_index = i * weight_w + j;
337334
int offset_index = 2 * mask_index;
338-
335+
339336
T mask_value = 1;
340337
if (use_mask) {
341338
mask_value = local_mask_ptr[mask_index * (out_h * out_w) + out_y * out_w + out_x];
342339
}
343-
340+
344341
T offset_h_val = local_offset_ptr[offset_index * (out_h * out_w) + out_y * out_w + out_x];
345342
T offset_w_val = local_offset_ptr[(offset_index + 1) * (out_h * out_w) + out_y * out_w + out_x];
346-
343+
347344
T y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h_val;
348345
T x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w_val;
349-
346+
350347
T interp = bilinear_interpolate_deformable_conv2d(local_input_ptr, height, width, y, x, tid);
351-
348+
352349
*local_columns_ptr = mask_value * interp;
353-
350+
354351
local_columns_ptr += batch_size * out_h * out_w;
355352
}
356353
}
@@ -584,7 +581,7 @@ kernel void roi_align_backward(
584581
atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast<T>(g2));
585582
atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast<T>(g3));
586583
atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast<T>(g4));
587-
584+
588585
} // if
589586
} // ix
590587
} // iy
@@ -742,7 +739,6 @@ kernel void roi_pool_backward(
742739
if (argmax != -1) {
743740
atomic_add_float(grad_input + offset + argmax, static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]));
744741
}
745-
746742
} // MPS_1D_KERNEL_LOOP
747743
}
748744
@@ -1139,7 +1135,6 @@ kernel void ps_roi_pool_backward(
11391135
atomic_add_float(grad_input + offset + grad_input_index, diff_val);
11401136
}
11411137
}
1142-
11431138
} // MPS_1D_KERNEL_LOOP
11441139
}
11451140
@@ -1157,7 +1152,7 @@ kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
11571152
constant int64_t & width [[buffer(7)]], \
11581153
constant int64_t & pooled_height [[buffer(8)]], \
11591154
constant int64_t & pooled_width [[buffer(9)]], \
1160-
constant int64_t & channels_out [[buffer(10)]], \
1155+
constant int64_t & channels_out [[buffer(10)]], \
11611156
constant float & spatial_scale [[buffer(11)]], \
11621157
uint2 tgid [[threadgroup_position_in_grid]], \
11631158
uint2 tptg [[threads_per_threadgroup]], \
@@ -1192,6 +1187,4 @@ static id<MTLComputePipelineState> visionPipelineState(
11921187
return lib.getPipelineStateForFunc(kernel);
11931188
}
11941189

1195-
} // namespace mps
1196-
} // namespace ops
1197-
} // namespace vision
1190+
} // namespace vision::ops::mps

torchvision/csrc/ops/mps/nms_kernel.mm

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#include <ATen/native/mps/OperationUtils.h>
33
#include "mps_kernels.h"
44

5-
namespace vision {
6-
namespace ops {
5+
namespace vision::ops {
76

87
namespace {
98

@@ -105,5 +104,4 @@
105104
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
106105
}
107106

108-
} // namespace ops
109-
} // namespace vision
107+
} // namespace vision::ops

torchvision/csrc/ops/mps/ps_roi_align_kernel.mm

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include "mps_helpers.h"
44
#include "mps_kernels.h"
55

6-
namespace vision {
7-
namespace ops {
6+
namespace vision::ops {
87

98
namespace {
109

@@ -201,5 +200,4 @@
201200
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel));
202201
}
203202

204-
} // namespace ops
205-
} // namespace vision
203+
} // namespace vision::ops

torchvision/csrc/ops/mps/ps_roi_pool_kernel.mm

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include "mps_helpers.h"
44
#include "mps_kernels.h"
55

6-
namespace vision {
7-
namespace ops {
6+
namespace vision::ops {
87

98
namespace {
109

@@ -195,5 +194,4 @@
195194
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel));
196195
}
197196

198-
} // namespace ops
199-
} // namespace vision
197+
} // namespace vision::ops

torchvision/csrc/ops/mps/roi_align_kernel.mm

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include "mps_helpers.h"
44
#include "mps_kernels.h"
55

6-
namespace vision {
7-
namespace ops {
6+
namespace vision::ops {
87

98
namespace {
109

@@ -193,5 +192,4 @@
193192
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel));
194193
}
195194

196-
} // namespace ops
197-
} // namespace vision
195+
} // namespace vision::ops

torchvision/csrc/ops/mps/roi_pool_kernel.mm

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
#include "mps_helpers.h"
44
#include "mps_kernels.h"
55

6-
namespace vision {
7-
namespace ops {
6+
namespace vision::ops {
87

98
namespace {
109

@@ -192,5 +191,4 @@
192191
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
193192
}
194193

195-
} // namespace ops
196-
} // namespace vision
194+
} // namespace vision::ops

0 commit comments

Comments
 (0)