|
3 | 3 | #include "llvm/Support/Debug.h"
|
4 | 4 | #include "llvm/Support/raw_ostream.h"
|
5 | 5 |
|
| 6 | +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" |
6 | 7 | #include "triton/Analysis/AxisInfo.h"
|
7 | 8 | #include "triton/Dialect/Triton/IR/Dialect.h"
|
8 | 9 | #include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
@@ -1019,6 +1020,73 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
|
1019 | 1020 | }
|
1020 | 1021 | };
|
1021 | 1022 |
|
| 1023 | +class ExtractSliceOpInfoVisitor final |
| 1024 | + : public AxisInfoVisitorImpl<triton::amdgpu::ExtractSliceOp> { |
| 1025 | +public: |
| 1026 | + using AxisInfoVisitorImpl< |
| 1027 | + triton::amdgpu::ExtractSliceOp>::AxisInfoVisitorImpl; |
| 1028 | + |
| 1029 | + AxisInfo |
| 1030 | + getAxisInfo(triton::amdgpu::ExtractSliceOp op, |
| 1031 | + ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override { |
| 1032 | + |
| 1033 | + auto srcType = cast<RankedTensorType>(op.getOperand().getType()); |
| 1034 | + auto srcShape = srcType.getShape(); |
| 1035 | + |
| 1036 | + auto dstType = cast<RankedTensorType>(op.getResult().getType()); |
| 1037 | + auto dstShape = dstType.getShape(); |
| 1038 | + |
| 1039 | + auto offsets = op.getStaticOffsets(); |
| 1040 | + |
| 1041 | + AxisInfo opInfo = operands[0]->getValue(); |
| 1042 | + auto origContiguity = opInfo.getContiguity(); |
| 1043 | + auto origDivisibility = opInfo.getDivisibility(); |
| 1044 | + auto origConstancy = opInfo.getConstancy(); |
| 1045 | + |
| 1046 | + auto recompute = [](ArrayRef<int64_t> vec, int64_t c) { |
| 1047 | + auto result = std::numeric_limits<int64_t>::max(); |
| 1048 | + for (auto &v : vec) { |
| 1049 | + // compute the upper bound of `v` based on `contiguity` |
| 1050 | + auto newC = ((v + c - 1) / c) * c - v; |
| 1051 | + // make sure that the new value is not broken because |
| 1052 | + // of the sliced boundaries |
| 1053 | + newC = newC == 0 ? c : newC; |
| 1054 | + |
| 1055 | + // conside the minumal value along each dimension |
| 1056 | + result = result > newC ? newC : result; |
| 1057 | + } |
| 1058 | + assert(vec.size() == 2); |
| 1059 | + const auto dimSize = vec[1] - vec[0]; |
| 1060 | + |
| 1061 | + // make sure that the value doesn't exceed the dimension size |
| 1062 | + return result > dimSize ? dimSize : result; |
| 1063 | + }; |
| 1064 | + |
| 1065 | + SmallVector<int64_t> contiguity(origContiguity.size()); |
| 1066 | + SmallVector<int64_t> divisibility(opInfo.getDivisibility().size()); |
| 1067 | + SmallVector<int64_t> constancy(opInfo.getConstancy().size()); |
| 1068 | + |
| 1069 | + for (size_t dim = 0; dim < opInfo.getRank(); ++dim) { |
| 1070 | + auto start = offsets[dim]; |
| 1071 | + auto end = start + dstShape[dim]; |
| 1072 | + contiguity[dim] = recompute({start, end}, origContiguity[dim]); |
| 1073 | + // note: contiguity cannot increase while slicing a tensor |
| 1074 | + assert(contiguity[dim] <= origContiguity[dim]); |
| 1075 | + |
| 1076 | + constancy[dim] = recompute({start, end}, origConstancy[dim]); |
| 1077 | + |
| 1078 | + // note: preserve divisibility if contiguity stays the same. |
| 1079 | + // otherwise, set it to 1 b/c we are not able to infer this property from |
| 1080 | + // the given data |
| 1081 | + divisibility[dim] = |
| 1082 | + contiguity[dim] == origContiguity[dim] ? origDivisibility[dim] : 1; |
| 1083 | + } |
| 1084 | + |
| 1085 | + return AxisInfo(contiguity, divisibility, constancy, |
| 1086 | + opInfo.getConstantValue()); |
| 1087 | + } |
| 1088 | +}; |
| 1089 | + |
1022 | 1090 | //===----------------------------------------------------------------------===//
|
1023 | 1091 | // AxisInfoAnalysis
|
1024 | 1092 | //===----------------------------------------------------------------------===//
|
@@ -1062,6 +1130,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
|
1062 | 1130 | MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
|
1063 | 1131 | MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
|
1064 | 1132 | visitors.append<LoadOpAxisInfoVisitor>();
|
| 1133 | + visitors.append<ExtractSliceOpInfoVisitor>(); |
1065 | 1134 | }
|
1066 | 1135 |
|
1067 | 1136 | LogicalResult AxisInfoAnalysis::visitOperation(
|
|
0 commit comments