Skip to content

Commit 91ac21d

Browse files
committed
[AMD] Extended AxisAnalysis to account ExtractSliceOps
1 parent d5d09d0 commit 91ac21d

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "llvm/Support/Debug.h"
44
#include "llvm/Support/raw_ostream.h"
55

6+
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
67
#include "triton/Analysis/AxisInfo.h"
78
#include "triton/Dialect/Triton/IR/Dialect.h"
89
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -1019,6 +1020,79 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
10191020
}
10201021
};
10211022

1023+
class ExtractSliceOpInfoVisitor final
1024+
: public AxisInfoVisitorImpl<triton::amdgpu::ExtractSliceOp> {
1025+
public:
1026+
using AxisInfoVisitorImpl<
1027+
triton::amdgpu::ExtractSliceOp>::AxisInfoVisitorImpl;
1028+
1029+
AxisInfo
1030+
getAxisInfo(triton::amdgpu::ExtractSliceOp op,
1031+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
1032+
1033+
auto srcType = cast<RankedTensorType>(op.getOperand().getType());
1034+
auto srcShape = srcType.getShape();
1035+
1036+
auto dstType = cast<RankedTensorType>(op.getResult().getType());
1037+
auto dstShape = dstType.getShape();
1038+
1039+
auto offsets = op.getStaticOffsets();
1040+
1041+
AxisInfo opInfo = operands[0]->getValue();
1042+
auto origContiguity = opInfo.getContiguity();
1043+
auto origDivisibility = opInfo.getDivisibility();
1044+
auto origConstancy = opInfo.getConstancy();
1045+
1046+
auto recompute = [](ArrayRef<int64_t> vec, int64_t c) {
1047+
auto result = std::numeric_limits<int64_t>::max();
1048+
for (auto &v : vec) {
1049+
// compute the upper bound of `v` based on `contiguity`
1050+
auto newC = ((v + c - 1) / c) * c - v;
1051+
// make sure that the new value is not broken because
1052+
// of the sliced boundaries
1053+
newC = newC == 0 ? c : newC;
1054+
1055+
// conside the minumal value along each dimension
1056+
result = result > newC ? newC : result;
1057+
}
1058+
assert(vec.size() == 2);
1059+
const auto dimSize = vec[1] - vec[0];
1060+
1061+
// make sure that the value doesn't exceed the dimension size
1062+
return result > dimSize ? dimSize : result;
1063+
};
1064+
1065+
SmallVector<int64_t> contiguity(origContiguity.size());
1066+
SmallVector<int64_t> divisibility(opInfo.getDivisibility().size());
1067+
SmallVector<int64_t> constancy(opInfo.getConstancy().size());
1068+
1069+
for (size_t dim = 0; dim < opInfo.getRank(); ++dim) {
1070+
auto start = offsets[dim];
1071+
auto end = start + dstShape[dim];
1072+
contiguity[dim] = recompute({start, end}, origContiguity[dim]);
1073+
// note: contiguity cannot increase while slicing a tensor
1074+
assert(contiguity[dim] <= origContiguity[dim]);
1075+
1076+
constancy[dim] = recompute({start, end}, origConstancy[dim]);
1077+
1078+
divisibility[dim] = origDivisibility[dim];
1079+
if (contiguity[dim] != origContiguity[dim]) {
1080+
// note: assume n is the largest power of two that divides `x` and `x +
1081+
// c`
1082+
// 1. x % n = 0 and 2. (x + c) % n = 0
1083+
// reminder of a sum can be calculated as: 3. (x + c) % n = (x % n + c %
1084+
// n) % n = 0 becuase of 1. one can write 4. (c % n) % n or 5. c % n = 0
1085+
divisibility[dim] = std::min(
1086+
origDivisibility[dim],
1087+
int64_t(log2Int(highestPowOf2Divisor<int64_t>(contiguity[dim]))));
1088+
}
1089+
}
1090+
1091+
return AxisInfo(contiguity, divisibility, constancy,
1092+
opInfo.getConstantValue());
1093+
}
1094+
};
1095+
10221096
//===----------------------------------------------------------------------===//
10231097
// AxisInfoAnalysis
10241098
//===----------------------------------------------------------------------===//
@@ -1062,6 +1136,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10621136
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10631137
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10641138
visitors.append<LoadOpAxisInfoVisitor>();
1139+
visitors.append<ExtractSliceOpInfoVisitor>();
10651140
}
10661141

10671142
LogicalResult AxisInfoAnalysis::visitOperation(

lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@ add_triton_library(TritonAnalysis
1717
TritonIR
1818
TritonGPUIR
1919
TritonNvidiaGPUIR
20+
TritonAMDGPUIR
2021
)

0 commit comments

Comments
 (0)