
Commit bedeb1f

ezyang authored and pytorchmergebot committed
Add torch.empty_permuted (pytorch#95069)
torch.empty_permuted is a generalized version of torch.empty(memory_format=...), where you can pass an arbitrary physical layout as a tuple of dims, allowing you to set up dense, non-overlapping tensors with non-standard memory formats. Check the docblock for a full description of the semantics.

The initial motivation for this PR is guard-less unbacked SymInts. Traditionally, the way we allocate dense tensors with arbitrary layout is with `empty_strided`. However, `empty_strided` does not know that the given strides are actually contiguous, and must test this manually to find out whether that is the case. With `empty_permuted`, this is known statically to be the case, which helps us skip some 0/1 guards.

However, I also think torch.empty_permuted is a useful API in its own right. It is technically possible to simulate this with an empty and a permute; however, there are some downsides:

* The manual incantation is tricky to work out. To allocate an NHWC tensor, the invocation is `torch.empty(N, H, W, C).permute(0, 3, 1, 2)`; the permute call has to take NHWC to NCHW, and is the *inverse* of the permutation people typically think of when they talk about NHWC (0, 2, 3, 1). Instead, torch.empty_permuted lets you say `torch.empty_permuted((N, C, H, W), (0, 2, 3, 1))`, letting you provide the intuitive permutation; it can literally be read off as NHWC if you assign N=0, C=1, H=2, W=3.
* An empty(requires_grad=True).permute() is no longer a leaf tensor. You can force it to be a leaf with a detach(), but it is more straightforward and less error prone to allow directly allocating a tensor with the correct permutation.

It is also technically possible to simulate this with empty_strided. However, this requires the user to manually compute the contiguous output strides, and it is bad from a reduction-of-guards perspective. For what it's worth, this is one of the more common uses of as_strided in the wild, and it would be nice to get rid of it.

A nice enhancement of this feature would be to accept `physical_layout` anywhere `memory_format` is accepted. However, that would be a pretty involved change, so I'm doing the easy thing instead.

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#95069
Approved by: https://github.com/malfet, https://github.com/ngimel, https://github.com/albanD, https://github.com/dagitses
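To make the incantation difference concrete, here is a small illustrative sketch (the sizes N, C, H, W are arbitrary example values; this assumes a build that includes this PR):

```python
import torch

N, C, H, W = 2, 3, 5, 7

# Manual incantation: allocate in NHWC physical order, then apply the
# *inverse* permutation (0, 3, 1, 2) to recover logical NCHW.
a = torch.empty(N, H, W, C).permute(0, 3, 1, 2)

# empty_permuted: give the logical size and the intuitive NHWC ordering directly.
b = torch.empty_permuted((N, C, H, W), (0, 2, 3, 1))

print(a.shape, a.stride())  # torch.Size([2, 3, 5, 7]) (105, 1, 21, 3)
print(b.shape, b.stride())  # torch.Size([2, 3, 5, 7]) (105, 1, 21, 3)

# The leaf-tensor difference described above: the permute() variant is not a leaf.
print(torch.empty(N, H, W, C, requires_grad=True).permute(0, 3, 1, 2).is_leaf)        # False
print(torch.empty_permuted((N, C, H, W), (0, 2, 3, 1), requires_grad=True).is_leaf)   # True
```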
1 parent 50ec4dd commit bedeb1f

12 files changed: +254 -0 lines changed

aten/src/ATen/native/TensorFactories.cpp

+40
@@ -46,6 +46,7 @@
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/empty_like_native.h>
 #include <ATen/ops/empty_native.h>
+#include <ATen/ops/empty_permuted_native.h>
 #include <ATen/ops/empty_strided.h>
 #include <ATen/ops/empty_strided_native.h>
 #include <ATen/ops/eye.h>
@@ -278,6 +279,45 @@ Tensor empty_names(
   return result;
 }

+Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c10::optional<ScalarType> dtype_opt,
+    c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt
+) {
+  // size is logical; aka, the output size you'll get from the operation overall
+  //
+  // physical_layout follows NCHW/NHWC convention:
+  //   contiguous is [0,1,2,3], channels last is [0,2,3,1]
+  //
+  // this means if i is physical index, physical_layout[i] is logical index;
+  // e.g., to find what is innermost physical dim (3), query NHWC[3] == 1
+  // (aka it is channels)
+  int64_t dim = static_cast<int64_t>(size.size());
+  SymDimVector phys_size(dim);
+  TORCH_CHECK(physical_layout.size() == dim,
+      "Number of dimensions in size does not match the "
+      "length of the physical_layout; i.e. len(size) = ", dim,
+      " is not equal to len(physical_layout) = ", physical_layout.size());
+  std::vector<bool> seen_dims(dim);
+  for (const auto i : c10::irange(dim)) {
+    TORCH_CHECK(physical_layout[i] >= 0 && physical_layout[i] < dim,
+        "Dimension out of range (expected to be between 0 and ", dim - 1, ", but got ",
+        physical_layout[i], " at index ", i, "). NB: negative dims "
+        "not currently supported; file an issue if you want it.");
+    TORCH_CHECK(!seen_dims[physical_layout[i]], "Duplicate dim not allowed");
+    phys_size[i] = size[physical_layout[i]];
+    seen_dims[physical_layout[i]] = true;
+  }
+  // do a contiguous allocation
+  Tensor phys_tensor = at::empty_symint(phys_size, dtype_opt, layout_opt, device_opt, pin_memory_opt, c10::nullopt);
+  SymIntArrayRef phys_strides = phys_tensor.sym_strides();
+  // permute the strides (inverse permutation! This is why this is
+  // empty_permute*d*, not empty_permute; it's not an empty + permute)
+  SymDimVector strides(dim);
+  for (const auto i : c10::irange(dim)) {
+    strides[physical_layout[i]] = phys_strides[i];
+  }
+  return phys_tensor.as_strided_symint(size, strides);
+}
+
 Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
     c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
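For readers who don't want to trace the ATen code, the following is a rough Python mirror of the stride computation in `empty_permuted_symint` above; the helper name `empty_permuted_sketch` is invented for illustration and is not part of the PR:

```python
import torch

def empty_permuted_sketch(size, physical_layout):
    # Physical size: dims listed in the order they are laid out in memory.
    dim = len(size)
    assert sorted(physical_layout) == list(range(dim)), "physical_layout must be a permutation of range(dim)"
    phys_size = [size[l] for l in physical_layout]
    phys = torch.empty(phys_size)  # contiguous allocation
    phys_strides = phys.stride()
    # Inverse permutation of the physical strides back onto the logical dims.
    strides = [0] * dim
    for i, l in enumerate(physical_layout):
        strides[l] = phys_strides[i]
    return phys.as_strided(size, strides)

print(empty_permuted_sketch((2, 3, 5, 7), (0, 2, 3, 1)).stride())  # (105, 1, 21, 3), i.e. channels last
```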

aten/src/ATen/native/native_functions.yaml

+5
@@ -2241,6 +2241,11 @@
   SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
   QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized

+- func: empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: empty_permuted_symint
+  autogen: empty_permuted.out
+
 # We do not make new_empty a composite that calls into new_empty_strided, as the strided version
 # is significantly more difficult to implement by different backends
 - func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

test/expect/HasDecompTest.test_has_decomposition.expect

+2
@@ -719,6 +719,8 @@ aten::embedding_renorm_
 aten::empty.memory_format
 aten::empty.names
 aten::empty.names_out
+aten::empty_permuted
+aten::empty_permuted.out
 aten::empty_quantized
 aten::empty_quantized.out
 aten::equal

test/inductor/test_torchinductor_opinfo.py

+1
@@ -429,6 +429,7 @@ def wrapper_set_seed(op, *args, **kwargs):
 inductor_override_kwargs = {
     # the return value of empty is undefined
     "empty": {"assert_equal": False},
+    "empty_permuted": {"assert_equal": False},
     "empty_like": {"assert_equal": False},
     "new_empty": {"assert_equal": False},
     "new_empty_strided": {"assert_equal": False},

test/test_proxy_tensor.py

+1
@@ -1153,6 +1153,7 @@ def f(a, b, c, d, e):
     skip('new_empty'),
     skip('empty_like'),
     skip('empty'),
+    skip('empty_permuted'),
     # flaky
     skip('linalg.lstsq', 'grad_oriented'),
     skip('nn.functional.max_unpool1d', '', device_type='cpu'),

torch/_inductor/decomposition.py

+12
@@ -61,6 +61,18 @@ def floordiv(a, b):
     return aten.div.Tensor_mode(a, b, rounding_mode="floor")


+# Not really sure how to put this into the main library. PrimTorch wants
+# empty_permuted to go to the prim, and typically users don't really want
+# to decompose to empty_strided (but inductor is OK with it, because we are
+# cool with strides and everything goes to empty_strided)
+@register_decomposition([aten.empty_permuted.default])
+def empty_permuted(size, physical_layout, **kwargs):
+    perm = [0] * len(size)
+    for p, l in enumerate(physical_layout):
+        perm[l] = p
+    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
+
+
 def get_alignment_size(x):
     if x.dtype == torch.float16 or x.dtype == torch.half or x.dtype == torch.bfloat16:
         return 8
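As a sanity check on this decomposition, a standalone sketch (plain torch calls, no decomposition registration; the helper name is ours) that reproduces the same empty + inverse-permute computation:

```python
import torch

def empty_permuted_via_permute(size, physical_layout, **kwargs):
    # Same computation as the inductor decomposition above: allocate the
    # physical (contiguous) tensor, then permute it back into logical dim order.
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p  # inverse permutation
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)

t = empty_permuted_via_permute((2, 3, 5, 7), (0, 2, 3, 1))
print(t.shape, t.stride())  # torch.Size([2, 3, 5, 7]) (105, 1, 21, 3)
# Unlike the eager op, this produces a view of a contiguous buffer, which is
# fine inside inductor since everything lowers to strides anyway.
```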

torch/_prims/__init__.py

+56
@@ -193,6 +193,7 @@
     # Tensor Creation Prims
     #
     "empty_strided",
+    "empty_permuted",
     "scalar_tensor",
     "iota",
     #
@@ -2466,6 +2467,61 @@ def _empty_strided_meta(
     )


+def _empty_permuted_meta(
+    shape: ShapeType,
+    physical_layout: DimsSequenceType,
+    *,
+    dtype: torch.dtype,
+    device: torch.device,
+    requires_grad: bool,
+) -> TensorLikeType:
+    p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
+    dim = len(shape)
+    utils.check(
+        len(physical_layout) == dim,
+        lambda: (
+            "Number of dimensions in the tensor input does not match the "
+            f"length of the physical layout; i.e. len(size) = {dim} "
+            f"is not equal to len(physical_layout) = {len(physical_layout)}"
+        ),
+    )
+    strides = [0] * len(shape)
+    seen_dims = set()
+    for p, l in enumerate(physical_layout):
+        utils.check(
+            0 <= l < dim,
+            lambda: (
+                f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
+                f"{l} at index {p}). NB: negative dims "
+                "not currently supported; file an issue if you want it."
+            ),
+        )
+        utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed")
+        strides[l] = p_strides[p]
+        seen_dims.add(l)
+    return TensorMeta(
+        shape=shape,
+        strides=strides,
+        dtype=dtype,
+        device=device,
+    )
+
+
+_empty_permuted_doc = """
+Creates a tensor with uninitialized values according to some physical layout,
+that is guaranteed to be non-overlapping and dense.
+"""
+
+# TODO: add layout, pin_memory
+empty_permuted = _make_prim(
+    schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
+    return_type=RETURN_TYPE.NEW,
+    meta=_empty_permuted_meta,
+    impl_aten=torch.empty_permuted,
+    doc=_empty_permuted_doc,
+)
+
+
 def _full_meta(
     shape: ShapeType,
     fill_value: NumberType,

torch/_refs/__init__.py

+21
@@ -4042,6 +4042,27 @@ def empty(
     )


+@out_wrapper()
+def empty_permuted(
+    shape,
+    physical_layout,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    device: Optional[torch.device] = None,
+    requires_grad: bool = False,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    return prims.empty_permuted(
+        shape,
+        physical_layout,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        requires_grad=requires_grad,
+    )
+
+
 @register_decomposition(aten.new_empty)
 def new_empty(
     a: TensorLikeType,

torch/_torch_docs.py

+45
@@ -12353,6 +12353,51 @@ def merge_dicts(*dicts):
     ),
 )

+add_docstr(
+    torch.empty_permuted,
+    r"""
+empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor
+
+Creates an uninitialized, non-overlapping and dense tensor with the
+specified :attr:`size`, with :attr:`physical_layout` specifying how the
+dimensions are physically laid out in memory (each logical dimension is listed
+from outermost to innermost). :attr:`physical_layout` is a generalization
+of NCHW/NHWC notation: if each dimension is assigned a number according to
+what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)``
+while NHWC is ``(0, 2, 3, 1)``. Equivalently, the strides of the output
+tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]``
+(notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``).
+
+Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense
+tensor with no overlaps. If possible, prefer using this function over
+:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
+
+Args:
+    size (tuple of int): the shape of the output tensor
+    physical_layout (tuple of int): the ordering of dimensions physically in memory
+
+Keyword args:
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+    {pin_memory}
+
+Examples:
+
+    >>> torch.empty((2, 3, 5, 7)).stride()
+    (105, 35, 7, 1)
+    >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride()
+    (105, 35, 7, 1)
+    >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride()
+    (105, 1, 21, 3)
+    >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride()
+    (105, 1, 21, 3)
+""".format(
+        **factory_common_args
+    ),
+)
+
 add_docstr(
     torch.full,
     r"""
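To spell out the docstring's stride invariant ``t.stride(physical_layout[i]) == contiguous_strides[i]``, here is a small hedged check; the ``contiguous_strides`` helper below is written out for illustration only (it is not a torch API) and refers to the contiguous strides of the physically ordered shape:

```python
import torch

def contiguous_strides(shape):
    # Row-major (contiguous) strides for a shape; illustrative helper only.
    strides, acc = [], 1
    for s in reversed(shape):
        strides.append(acc)
        acc *= max(s, 1)
    return list(reversed(strides))

size, physical_layout = (2, 3, 5, 7), (0, 2, 3, 1)
t = torch.empty_permuted(size, physical_layout)
# Strides of the physical (allocation-order) shape [2, 5, 7, 3] -> (105, 21, 3, 1)
phys = contiguous_strides([size[l] for l in physical_layout])
assert all(t.stride(physical_layout[i]) == phys[i] for i in range(len(size)))
```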

torch/overrides.py

+1
@@ -144,6 +144,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.cudnn_grid_sampler,
         torch.cudnn_is_acceptable,
         torch.empty,
+        torch.empty_permuted,
         torch.empty_strided,
         torch.empty_quantized,
         torch.eye,

torch/testing/_internal/common_methods_invocations.py

+69
@@ -1567,6 +1567,33 @@ def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs):
     for case in cases:
         yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad)

+def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs):
+    # shape
+    cases = (
+        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
+    )
+
+    for case in cases:
+        for layout in itertools.permutations(range(len(case))):
+            yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad)
+
+def error_inputs_empty_permuted(op_info, device, **kwargs):
+    yield ErrorInput(
+        SampleInput((2,), args=((0, 1),)),
+        error_type=RuntimeError,
+        error_regex="Number of dimensions in size does not match the length of the physical_layout"
+    )
+    yield ErrorInput(
+        SampleInput((2,), args=((3,),)),
+        error_type=RuntimeError,
+        error_regex="Dimension out of range"
+    )
+    yield ErrorInput(
+        SampleInput((2, 3), args=((0, 0),)),
+        error_type=RuntimeError,
+        error_regex="Duplicate dim not allowed"
+    )
+
 def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs):
     # Not including a scalar tensor in vals because meta tests start failing due to
     # lack of meta support for _local_scalar_dense
@@ -15751,6 +15778,48 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
            # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
        )),
+    OpInfo('empty_permuted',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_empty_permuted,
+           error_inputs_func=error_inputs_empty_permuted,
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance',
+                            'test_operator'),
+               # requires_grad doesn't exist in the jit schema
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon',
+                            'test_out'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon',
+                            'test_out_warning'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestLazyOpInfo'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
     OpInfo('scalar_tensor',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
            sample_inputs_func=sample_inputs_scalar_tensor,

torch/utils/_device.py

+1
@@ -8,6 +8,7 @@ def _device_constructors():
     return {
         # standard ones
        torch.empty,
+       torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.ones,
