Skip to content

Commit 4f5080d

Browse files
Merge branch 'main' into LpNormalization
2 parents 4022098 + c1c8638 commit 4f5080d

File tree

3 files changed

+67
-33
lines changed

3 files changed

+67
-33
lines changed

src/Accelerators/NNPA/Dialect/ZLow/ZLow.td

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ def ZLowInvSqrtOp:ZLow_Op<"invsqrt", [MemRefsNormalizable]> {
132132
let description = [{
133133
ZLow operation to perform a invsqrt.
134134
}];
135-
let arguments = (ins ZMemRef:$X,
136-
MemRefOf<[I64]>:$shape,
137-
ZMemRef:$Out,
135+
let arguments = (ins Arg<ZMemRef, "", [MemRead]>:$X,
136+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
137+
Arg<ZMemRef, "", [MemWrite]>:$Out,
138138
StrAttr:$layout);
139139
}
140140

@@ -170,9 +170,9 @@ def ZLowLeakyReluOp:ZLow_Op<"leakyrelu", [MemRefsNormalizable]> {
170170
let description = [{
171171
ZLow operation to perform a leakyrelu.
172172
}];
173-
let arguments = (ins ZMemRef:$X,
174-
MemRefOf<[I64]>:$shape,
175-
ZMemRef:$Out,
173+
let arguments = (ins Arg<ZMemRef, "", [MemRead]>:$X,
174+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
175+
Arg<ZMemRef, "", [MemWrite]>:$Out,
176176
DefaultValuedAttr<F32Attr, "0.01">:$alpha,
177177
StrAttr:$layout);
178178
}
@@ -194,9 +194,9 @@ def ZLowGeluOp:ZLow_Op<"gelu", [MemRefsNormalizable]> {
194194
let description = [{
195195
ZLow operation to perform a gelu.
196196
}];
197-
let arguments = (ins ZMemRef:$X,
198-
MemRefOf<[I64]>:$shape,
199-
ZMemRef:$Out,
197+
let arguments = (ins Arg<ZMemRef, "", [MemRead]>:$X,
198+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
199+
Arg<ZMemRef, "", [MemWrite]>:$Out,
200200
StrAttr:$layout);
201201
}
202202

@@ -244,9 +244,9 @@ def ZLowSqrtOp:ZLow_Op<"sqrt", [MemRefsNormalizable]> {
244244
let description = [{
245245
ZLow operation to perform a sqrt.
246246
}];
247-
let arguments = (ins ZMemRef:$X,
248-
MemRefOf<[I64]>:$shape,
249-
ZMemRef:$Out,
247+
let arguments = (ins Arg<ZMemRef, "", [MemRead]>:$X,
248+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
249+
Arg<ZMemRef, "", [MemWrite]>:$Out,
250250
StrAttr:$layout);
251251
}
252252

@@ -255,10 +255,10 @@ def ZLowReduceMaxOp:ZLow_Op<"reducemax", [MemRefsNormalizable]> {
255255
let description = [{
256256
ZLow operation to perform a reducemax.
257257
}];
258-
let arguments = (ins ZMemRef:$X,
259-
MemRefOf<[I8]>:$work_area,
260-
MemRefOf<[I64]>:$shape,
261-
ZMemRef:$Out,
258+
let arguments = (ins Arg<ZMemRef, "", [MemRead]>:$X,
259+
Arg<MemRefOf<[I8]>, "", [MemRead]>:$work_area,
260+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
261+
Arg<ZMemRef, "", [MemWrite]>:$Out,
262262
StrAttr:$layout);
263263
}
264264

@@ -267,10 +267,10 @@ def ZLowReduceMinOp:ZLow_Op<"reducemin", [MemRefsNormalizable]> {
267267
let description = [{
268268
ZLow operation to perform a reducemin.
269269
}];
270-
let arguments = (ins ZMemRef:$X,
271-
MemRefOf<[I8]>:$work_area,
272-
MemRefOf<[I64]>:$shape,
273-
ZMemRef:$Out,
270+
let arguments = (ins Arg<ZMemRef, "", [MemRead]>:$X,
271+
Arg<MemRefOf<[I8]>, "", [MemRead]>:$work_area,
272+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
273+
Arg<ZMemRef, "", [MemWrite]>:$Out,
274274
StrAttr:$layout);
275275
}
276276

@@ -335,12 +335,20 @@ def ZLowQuantizedMatMulOp:ZLow_Op<"quantizedMatmul", [MemRefsNormalizable]> {
335335
Values for `q_type` are "DLFLOAT16", "INT8", "WEIGHTS", "UNDEFINED".
336336

337337
}];
338-
let arguments = (ins ZQMemRef:$X, ODMemRefF32:$x_rec_scale, ODMemRefF32:$x_offset,
339-
ZQMemRef:$Y, ODMemRefF32:$y_rec_scale, ODMemRefF32:$y_offset,
340-
ZQMemRef:$Bias, ODMemRefF32:$bias_rec_scale, ODMemRefF32:$bias_offset,
341-
AnyTypeOf<[ZQMemRef, NoneType]>:$work_area,
342-
MemRefOf<[I64]>:$shape,
343-
ZQMemRef:$Out, ODMemRefF32:$out_rec_scale, ODMemRefF32:$out_offset,
338+
let arguments = (ins Arg<ZQMemRef, "", [MemRead]>:$X,
339+
Arg<ODMemRefF32, "", [MemRead]>:$x_rec_scale,
340+
Arg<ODMemRefF32, "", [MemRead]>:$x_offset,
341+
Arg<ZQMemRef, "", [MemRead]>:$Y,
342+
Arg<ODMemRefF32, "", [MemRead]>:$y_rec_scale,
343+
Arg<ODMemRefF32, "", [MemRead]>:$y_offset,
344+
Arg<ZQMemRef, "", [MemRead]>:$Bias,
345+
Arg<ODMemRefF32, "", [MemRead]>:$bias_rec_scale,
346+
Arg<ODMemRefF32, "", [MemRead]>:$bias_offset,
347+
Arg<AnyTypeOf<[ZQMemRef, NoneType]>, "", [MemRead, MemWrite]>:$work_area,
348+
Arg<MemRefOf<[I64]>, "", [MemRead]>:$shape,
349+
Arg<ZQMemRef, "", [MemWrite]>:$Out,
350+
Arg<ODMemRefF32, "", [MemWrite]>:$out_rec_scale,
351+
Arg<ODMemRefF32, "", [MemWrite]>:$out_offset,
344352
StrAttr:$x_q_type,
345353
StrAttr:$y_q_type,
346354
StrAttr:$bias_q_type,
@@ -476,10 +484,10 @@ def ZLowQuantizedStickOp:ZLow_Op<"quantizedStick", [MemRefsNormalizable]> {
476484
"ZLow operation to perform a quantization stick."
477485
"Type is one of values: dlfloat16, int8, and weights."
478486
}];
479-
let arguments = (ins MemRefOf<[I8, F32]>:$X,
480-
MemRefRankOf<[F32], [0]>:$rec_scale,
481-
MemRefRankOf<[F32], [0]>:$offset,
482-
ZQMemRef:$out,
487+
let arguments = (ins Arg<MemRefOf<[I8, F32]>, "", [MemRead]>:$X,
488+
Arg<MemRefRankOf<[F32], [0]>, "", [MemRead]>:$rec_scale,
489+
Arg<MemRefRankOf<[F32], [0]>, "", [MemRead]>:$offset,
490+
Arg<ZQMemRef, "", [MemWrite]>:$out,
483491
StrAttr:$layout,
484492
StrAttr:$q_type);
485493
let hasVerifier = 1;

src/Dialect/Krnl/Krnl.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,7 @@ def KrnlRandomNormalOp : Op<Krnl_Dialect, "random_normal",
13821382
Operation that generates a random normally distributed tensor.
13831383
}];
13841384

1385-
let arguments = (ins AnyTypeOf<[AnyMemRef]>:$output,
1385+
let arguments = (ins Arg<AnyTypeOf<[AnyMemRef]>, "output of the random tensor", [MemWrite]>:$output,
13861386
Index:$numberOfValues,
13871387
AnyFloat:$mean,
13881388
AnyFloat:$scale,
@@ -1423,7 +1423,7 @@ def KrnlNoneOp : Op<Krnl_Dialect, "noValue"> {
14231423
}]>];
14241424
}
14251425

1426-
def KrnlPrintTensorOp : Op<Krnl_Dialect, "print_tensor", [MemRefsNormalizable]> {
1426+
def KrnlPrintTensorOp : Op<Krnl_Dialect, "print_tensor", [MemRefsNormalizable, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
14271427
let summary = "Print a tensor.";
14281428
let description = [{
14291429
This operation can be used to generate a call to a runtime function which prints a tensor.
@@ -1440,7 +1440,7 @@ def KrnlPrintTensorOp : Op<Krnl_Dialect, "print_tensor", [MemRefsNormalizable]>
14401440
let arguments = (ins StrAttr:$msg, AnyMemRef:$input);
14411441
}
14421442

1443-
def KrnlPrintOp : Op<Krnl_Dialect, "print", [MemRefsNormalizable]> {
1443+
def KrnlPrintOp : Op<Krnl_Dialect, "print", [MemRefsNormalizable, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
14441444
let summary = "Print a value.";
14451445
let description = [{
14461446
This operation can be used to print the input value. The user needs to provide a

src/Dialect/Krnl/KrnlOps.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,32 @@ void KrnlInstrumentOp::getEffects(
665665
MemoryEffects::Write::get(), SideEffects::DefaultResource::get());
666666
}
667667

668+
//===----------------------------------------------------------------------===//
669+
// KrnlPrintTensorOp
670+
//===----------------------------------------------------------------------===//
671+
672+
void KrnlPrintTensorOp::getEffects(
673+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
674+
&effects) {
675+
676+
effects.emplace_back(MemoryEffects::Read::get(), &getInputMutable());
677+
678+
effects.emplace_back(
679+
MemoryEffects::Write::get(), SideEffects::DefaultResource::get());
680+
}
681+
682+
//===----------------------------------------------------------------------===//
683+
// KrnlPrintOp
684+
//===----------------------------------------------------------------------===//
685+
686+
void KrnlPrintOp::getEffects(
687+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
688+
&effects) {
689+
690+
effects.emplace_back(
691+
MemoryEffects::Write::get(), SideEffects::DefaultResource::get());
692+
}
693+
668694
//===----------------------------------------------------------------------===//
669695
// KrnlBlockOp
670696
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)