
Commit 02173da

scale and zp as operands

1 parent 35f4fba
2 files changed (+36 −57 lines)
lib/Conversion/TorchToTcp/TcpCustomOp.cpp (+20 −33)
@@ -166,33 +166,28 @@ class ConvertAtenFakeQuantizePerTensorAffineTensorQparamsOp
     // scale should be a [1] tensor.
     if (!scaleElements || scaleElements.getNumElements() != 1)
       return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
-    auto scale = (*scaleElements.begin()).convertToDouble();
-    helper.addDenseFloatArrayAttr("scale", {scale});
+    helper.addOperand("scale", adaptor.getScale());

     // zero_point
     auto zeroPointOp = op.getZeroPoint().getDefiningOp();
-    int64_t zeroPoint;
     if (!zeroPointOp)
       return rewriter.notifyMatchFailure(op, "Missing zero point operation");
-    if (dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) ||
-        dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
-      zeroPoint = 0;
-    } else {
-      auto zeroPointTensor =
-          dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp);
-      if (!zeroPointTensor)
-        return rewriter.notifyMatchFailure(
-            op, "Zero point operation is not ValueTensorLiteralOp or Zero "
-                "operation");
+    if (auto zeroPointTensor =
+            dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
       auto zeroPointElements =
           dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr());
       // zero_point should be a [1] tensor.
       if (!zeroPointElements || zeroPointElements.getNumElements() != 1)
         return rewriter.notifyMatchFailure(
             op, "Unsupported zero point type or size");
-      zeroPoint = (*zeroPointElements.begin()).getSExtValue();
+    } else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
+               !dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
+      // Zeros-like operations are converted through torch-to-tcp.
+      return rewriter.notifyMatchFailure(
+          op, "Zero point operation is not ValueTensorLiteralOp or Zero "
+              "operation");
     }
-    helper.addDenseIntArrayAttr("zero_point", {zeroPoint});
+    helper.addOperand("zero_point", adaptor.getZeroPoint());

     return helper.replace();
   }
@@ -226,37 +221,29 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
     // scale should be a [C] tensor.
     if (!scaleElements || scaleElements.getType().getShape().size() != 1)
       return rewriter.notifyMatchFailure(op, "Unsupported scale type or size");
-    SmallVector<double> scale;
-    for (auto val : scaleElements.getValues<APFloat>())
-      scale.push_back(val.convertToDouble());
-    helper.addDenseFloatArrayAttr("scale", scale);
+    helper.addOperand("scale", adaptor.getScale());

     // zero_point
     auto zeroPointOp = op.getZeroPoint().getDefiningOp();
-    SmallVector<int64_t> zeroPoint;
     if (!zeroPointOp)
       return rewriter.notifyMatchFailure(op, "Missing zero point operation");
-    if (dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) ||
-        dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
-      zeroPoint.assign(scale.size(), 0);
-    } else {
-      auto zeroPointTensor =
-          dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp);
-      if (!zeroPointTensor)
-        return rewriter.notifyMatchFailure(
-            op, "Zero point operation is not ValueTensorLiteralOp or Zero "
-                "operation");
+    if (auto zeroPointTensor =
+            dyn_cast<torch::Torch::ValueTensorLiteralOp>(zeroPointOp)) {
       auto zeroPointElements =
           dyn_cast<DenseIntElementsAttr>(zeroPointTensor.getValueAttr());
       // zero_point should be a [C] tensor.
       if (!zeroPointElements ||
           zeroPointElements.getType().getShape().size() != 1)
         return rewriter.notifyMatchFailure(
             op, "Unsupported zero point type or size");
-      for (auto val : zeroPointElements.getValues<APInt>())
-        zeroPoint.push_back(val.getSExtValue());
+    } else if (!dyn_cast<torch::Torch::AtenZerosOp>(zeroPointOp) &&
+               !dyn_cast<torch::Torch::AtenZerosLikeOp>(zeroPointOp)) {
+      // Zeros-like operations are converted through torch-to-tcp.
+      return rewriter.notifyMatchFailure(
+          op, "Zero point operation is not ValueTensorLiteralOp or Zero "
+              "operation");
     }
-    helper.addDenseIntArrayAttr("zero_point", zeroPoint);
+    helper.addOperand("zero_point", adaptor.getZeroPoint());

     return helper.replace();
   }
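
Net effect of the C++ changes above: the quantization parameters move from dense array attributes baked onto tcp.custom_op to ordinary SSA operands, so scale and zero_point no longer have to be compile-time constants folded by the pattern. A minimal before/after sketch of the emitted IR (the SSA names %self, %scale, %zp are illustrative; the attribute and type values match the test updates below):

  // Before: qparams folded into attributes; only the input is an operand.
  %out = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %self {
      quant_max = 255 : i64, quant_min = 0 : i64,
      scale = array<f64: 0.039370078593492508>, zero_point = array<i64: 2>,
      torch_operand_names = ["self"]} : tensor<1x64x32x32xf32> -> tensor<1x64x32x32xf32>

  // After: scale and zero_point are forwarded as operands.
  %out = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %self, %scale, %zp {
      quant_max = 255 : i64, quant_min = 0 : i64,
      torch_operand_names = ["self", "scale", "zero_point"]} :
      tensor<1x64x32x32xf32>, tensor<1xf32>, tensor<1xi32> -> tensor<1x64x32x32xf32>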

test/Conversion/TorchToTcp/tcp_custom_ops.mlir (+16 −24)
@@ -162,13 +162,11 @@ func.func @torch.aten.fake_quantize_per_tensor_affine(%input: !torch.vtensor<[1,
 // CHECK-LABEL: func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams(
 // CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[1,64,32,32],f32>) -> !torch.vtensor<[1,64,32,32],f32>
 // CHECK:       %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,32,32],f32> -> tensor<1x64x32x32xf32>
-// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %[[T0]] {
+// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %[[T0]], %{{.*}}, %{{.*}} {
 // CHECK-SAME:  quant_max = 255 : i64,
 // CHECK-SAME:  quant_min = 0 : i64,
-// CHECK-SAME:  scale = array<f64: 0.039370078593492508>,
-// CHECK-SAME:  torch_operand_names = ["self"],
-// CHECK-SAME:  zero_point = array<i64: 2>}
-// CHECK-SAME:  tensor<1x64x32x32xf32> -> tensor<1x64x32x32xf32>
+// CHECK-SAME:  torch_operand_names = ["self", "scale", "zero_point"]} :
+// CHECK-SAME:  tensor<1x64x32x32xf32>, tensor<1xf32>, tensor<1xi32> -> tensor<1x64x32x32xf32>
 // CHECK:       %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<1x64x32x32xf32> -> !torch.vtensor<[1,64,32,32],f32>
 // CHECK:       return %[[RES]] : !torch.vtensor<[1,64,32,32],f32>
 func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams(%input: !torch.vtensor<[1,64,32,32],f32>) -> !torch.vtensor<[1,64,32,32],f32> {

@@ -185,13 +183,11 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams(%input: !to
 // CHECK-LABEL: func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams_zero(
 // CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[1,64,32,32],f32>) -> !torch.vtensor<[1,64,32,32],f32>
 // CHECK:       %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,64,32,32],f32> -> tensor<1x64x32x32xf32>
-// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %[[T0]] {
+// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %[[T0]], %{{.*}}, %{{.*}} {
 // CHECK-SAME:  quant_max = 255 : i64,
 // CHECK-SAME:  quant_min = 0 : i64,
-// CHECK-SAME:  scale = array<f64: 0.039370078593492508>,
-// CHECK-SAME:  torch_operand_names = ["self"],
-// CHECK-SAME:  zero_point = array<i64: 0>}
-// CHECK-SAME:  tensor<1x64x32x32xf32> -> tensor<1x64x32x32xf32>
+// CHECK-SAME:  torch_operand_names = ["self", "scale", "zero_point"]} :
+// CHECK-SAME:  tensor<1x64x32x32xf32>, tensor<1xf32>, tensor<1xi32> -> tensor<1x64x32x32xf32>
 // CHECK:       %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<1x64x32x32xf32> -> !torch.vtensor<[1,64,32,32],f32>
 // CHECK:       return %[[RES]] : !torch.vtensor<[1,64,32,32],f32>
 func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams_zero(%input: !torch.vtensor<[1,64,32,32],f32>) -> !torch.vtensor<[1,64,32,32],f32> {

@@ -202,10 +198,10 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams_zero(%input
   %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
   %cuda3A0 = torch.constant.device "cuda:0"
   %false = torch.constant.bool false
-  %zero_point = torch.aten.zeros %5, %int3, %none, %cuda3A0, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.tensor
+  %zero_point = torch.aten.zeros %5, %int3, %none, %cuda3A0, %false : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[1],si32>
   %int0 = torch.constant.int 0
   %int255 = torch.constant.int 255
-  %output = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %input, %scale, %zero_point, %int0, %int255 : !torch.vtensor<[1,64,32,32],f32>, !torch.vtensor<[1],f32>, !torch.tensor, !torch.int, !torch.int -> !torch.vtensor<[1,64,32,32],f32>
+  %output = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %input, %scale, %zero_point, %int0, %int255 : !torch.vtensor<[1,64,32,32],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[1,64,32,32],f32>
   return %output : !torch.vtensor<[1,64,32,32],f32>
 }

@@ -214,14 +210,12 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams_zero(%input
 // CHECK-LABEL: func.func @torch.aten.fake_quantize_per_channel_affine(
 // CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32>
 // CHECK:       %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,3,32,32],f32> -> tensor<1x3x32x32xf32>
-// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_channel_affine") %[[T0]] {
+// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_channel_affine") %[[T0]], %{{.*}}, %{{.*}} {
 // CHECK-SAME:  axis = 1 : i64,
 // CHECK-SAME:  quant_max = 255 : i64,
 // CHECK-SAME:  quant_min = 0 : i64,
-// CHECK-SAME:  scale = array<f64: 0.039370078593492508, 0.039370078593492508, 0.039370078593492508>,
-// CHECK-SAME:  torch_operand_names = ["self"],
-// CHECK-SAME:  zero_point = array<i64: 2, 2, 2>}
-// CHECK-SAME:  tensor<1x3x32x32xf32> -> tensor<1x3x32x32xf32>
+// CHECK-SAME:  torch_operand_names = ["self", "scale", "zero_point"]} :
+// CHECK-SAME:  tensor<1x3x32x32xf32>, tensor<3xf32>, tensor<3xi32> -> tensor<1x3x32x32xf32>
 // CHECK:       %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<1x3x32x32xf32> -> !torch.vtensor<[1,3,32,32],f32>
 // CHECK:       return %[[RES]] : !torch.vtensor<[1,3,32,32],f32>
 func.func @torch.aten.fake_quantize_per_channel_affine(%input: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> {

@@ -239,14 +233,12 @@ func.func @torch.aten.fake_quantize_per_channel_affine(%input: !torch.vtensor<[1
 // CHECK-LABEL: func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(
 // CHECK-SAME:  %[[ARG0:.*]]: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32>
 // CHECK:       %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,3,32,32],f32> -> tensor<1x3x32x32xf32>
-// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_channel_affine") %[[T0]] {
+// CHECK:       %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.fake_quantize_per_channel_affine") %[[T0]], %{{.*}}, %{{.*}} {
 // CHECK-SAME:  axis = 1 : i64,
 // CHECK-SAME:  quant_max = 255 : i64,
 // CHECK-SAME:  quant_min = 0 : i64,
-// CHECK-SAME:  scale = array<f64: 0.039370078593492508, 0.039370078593492508, 0.039370078593492508>,
-// CHECK-SAME:  torch_operand_names = ["self"],
-// CHECK-SAME:  zero_point = array<i64: 0, 0, 0>}
-// CHECK-SAME:  tensor<1x3x32x32xf32> -> tensor<1x3x32x32xf32>
+// CHECK-SAME:  torch_operand_names = ["self", "scale", "zero_point"]} :
+// CHECK-SAME:  tensor<1x3x32x32xf32>, tensor<3xf32>, tensor<3xi32> -> tensor<1x3x32x32xf32>
 // CHECK:       %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<1x3x32x32xf32> -> !torch.vtensor<[1,3,32,32],f32>
 // CHECK:       return %[[RES]] : !torch.vtensor<[1,3,32,32],f32>
 func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> {

@@ -258,7 +250,7 @@ func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.
   %none = torch.constant.none
   %cuda3A0 = torch.constant.device "cuda:0"
   %false = torch.constant.bool false
-  %zero_point = torch.aten.zeros_like %scale, %int3, %none, %cuda3A0, %false, %none : !torch.vtensor<[3],f32>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.tensor
-  %output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.tensor, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
+  %zero_point = torch.aten.zeros_like %scale, %int3, %none, %cuda3A0, %false, %none : !torch.vtensor<[3],f32>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[3],si32>
+  %output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
   return %output : !torch.vtensor<[1,3,32,32],f32>
 }
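
Note on the zeros path: the pattern no longer folds torch.aten.zeros / torch.aten.zeros_like zero points into a constant attribute. Per the comment in the C++ change, those defining ops are left in place, converted by the regular torch-to-tcp patterns, and their result feeds the custom op as the zero_point operand. A hypothetical sketch of the converted IR, assuming the zeros tensor lowers to a tcp.const (that lowering is an assumption, not shown in this diff):

  // Assumed: torch.aten.zeros lowers to a constant zeros tensor.
  %zp = tcp.const {value = dense<0> : tensor<1xi32>} : tensor<1xi32>
  // The custom op then consumes it like any other zero_point operand.
  %out = tcp.custom_op("torch.aten.fake_quantize_per_tensor_affine.tensor_qparams") %self, %scale, %zp {...}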
