@@ -192,8 +192,8 @@ class AutoDiffBroadcastInDimRev
           AutoDiffBroadcastInDimRev, BroadcastInDimOp> {
 public:
   LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
-                                        MGradientUtilsReverse *gutils,
-                                        SmallVector<Value> caches) const {
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
     auto op = cast<BroadcastInDimOp>(orig);
     auto inTy = op.getOperand().getType();
     auto outTy = op.getType();
@@ -205,16 +205,16 @@ class AutoDiffBroadcastInDimRev
 
     SmallVector<int64_t> newDims;
     for (auto en : llvm::enumerate(outTy.getShape())) {
-      if (llvm::is_contained(bcastDims, en.index())) continue;
+      if (llvm::is_contained(bcastDims, en.index()))
+        continue;
       newDims.push_back(en.index());
     }
 
     Value zero = gutils->getShadowType(inTy)
-                    .cast<AutoDiffTypeInterface>()
-                    .createNullValue(builder, op.getLoc());
+                     .cast<AutoDiffTypeInterface>()
+                     .createNullValue(builder, op.getLoc());
 
-    auto red = builder.create<ReduceOp>(op.getLoc(),
-                                        TypeRange(zero.getType()),
+    auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
                                         inDiffe, zero, newDims);
     red.getBody().push_back(new Block());
     Block &body = red.getBody().front();
@@ -228,9 +228,9 @@ class AutoDiffBroadcastInDimRev
     bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(add));
 
     Value res = red->getResult(0);
-     Type resTy = gutils->getShadowType(op.getOperand().getType());
+    Type resTy = gutils->getShadowType(op.getOperand().getType());
     if (res.getType() != resTy)
-       res = builder.create<ReshapeOp>(op.getLoc(), resTy, res);
+      res = builder.create<ReshapeOp>(op.getLoc(), resTy, res);
 
     gutils->addToDiffe(op.getOperand(), res, builder);
     return success();
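Note on the hunks above: the reverse rule for `BroadcastInDimOp` builds an add-`ReduceOp` of the result cotangent over every output dimension the broadcast introduced (the dims not in `bcastDims`), then reshapes back to the input type if a size-1 dimension was expanded. A minimal standalone sketch of that identity on plain arrays (not the Enzyme API; all names here are illustrative):

```cpp
// Forward: y[i][j] = x[j] broadcasts x over a new leading dimension.
// Reverse: xbar[j] = sum_i ybar[i][j], i.e. reduce over the introduced dim.
#include <array>
#include <cstdio>

int main() {
  std::array<std::array<double, 3>, 2> ybar = {{{1, 2, 3}, {4, 5, 6}}};
  std::array<double, 3> xbar = {};
  for (int i = 0; i < 2; ++i)   // the dimension broadcast_in_dim introduced
    for (int j = 0; j < 3; ++j) // dimensions carried over from the input
      xbar[j] += ybar[i][j];
  std::printf("%g %g %g\n", xbar[0], xbar[1], xbar[2]); // prints: 5 7 9
}
```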
@@ -250,8 +250,8 @@ class AutoDiffSliceRev
           SliceOp> {
 public:
   LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
-                                        MGradientUtilsReverse *gutils,
-                                        SmallVector<Value> caches) const {
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
     auto op = cast<SliceOp>(orig);
     auto inTy = op.getOperand().getType();
     auto outTy = op.getType();
@@ -263,21 +263,25 @@ class AutoDiffSliceRev
     SmallVector<int64_t> starts;
     SmallVector<int64_t> edge_padding_high;
     SmallVector<int64_t> interior_padding;
-    for (auto &&[start, limit, stride, dim] : llvm::zip(
-        op.getStartIndices(), op.getLimitIndices(), op.getStrides(), inTy.getShape())) {
+    for (auto &&[start, limit, stride, dim] :
+         llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(),
+                   inTy.getShape())) {
       starts.push_back(start);
       edge_padding_high.push_back(dim - limit);
       interior_padding.push_back(stride - 1);
     }
 
-
-    auto zeroPad = RankedTensorType::get({}, inTy.getElementType()).cast<AutoDiffTypeInterface>().createNullValue(builder,
-        op.getLoc());
-    auto red = builder.create<stablehlo::PadOp>(op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts), builder.getDenseI64ArrayAttr(edge_padding_high), builder.getDenseI64ArrayAttr(interior_padding));
+    auto zeroPad = RankedTensorType::get({}, inTy.getElementType())
+                       .cast<AutoDiffTypeInterface>()
+                       .createNullValue(builder, op.getLoc());
+    auto red = builder.create<stablehlo::PadOp>(
+        op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts),
+        builder.getDenseI64ArrayAttr(edge_padding_high),
+        builder.getDenseI64ArrayAttr(interior_padding));
 
     gutils->addToDiffe(op.getOperand(), red->getResult(0), builder);
     return success();
-  #if 0
+#if 0
 
     Value idxs;
     {
@@ -351,7 +355,7 @@ class AutoDiffSliceRev
     // gutils->setDiffe(op.getOperand(), red->getResult(0), builder);
 
     return success();
-  #endif
+#endif
   }
 
   SmallVector<Value> cacheValues(Operation *orig,
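For `SliceOp`, the adjoint assembled above is a `stablehlo.PadOp` of the cotangent against a zero scalar: low padding = the slice start, high padding = dim - limit, interior padding = stride - 1, which scatters each sliced cotangent element back to the input position it was read from. A runnable sketch of the same scatter with made-up constants (assuming a tight limit, i.e. limit = last read index + 1):

```cpp
// slice(x[10], start=2, limit=7, stride=2) -> y[3] reads x[2], x[4], x[6].
// Reverse: pad ybar with zeros (low=start=2, high=dim-limit=3,
// interior=stride-1=1) so ybar[k] lands at position start + k*stride.
#include <array>
#include <cstdio>

int main() {
  constexpr int start = 2, stride = 2, dim = 10;
  std::array<double, 3> ybar = {1, 2, 3}; // cotangent of the slice result
  std::array<double, dim> xbar = {};      // zero background, as PadOp provides
  for (int k = 0; k < 3; ++k)
    xbar[start + k * stride] = ybar[k];
  for (double v : xbar)
    std::printf("%g ", v); // prints: 0 0 1 0 2 0 3 0 0 0
  std::printf("\n");
}
```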
@@ -368,26 +372,27 @@ class AutoDiffReduceRev
           ReduceOp> {
 public:
   LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
-                                        MGradientUtilsReverse *gutils,
-                                        SmallVector<Value> caches) const {
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
     auto op = cast<ReduceOp>(orig);
     if (!isEligibleForCompactPrint(op)) {
-      orig->emitError() << "Unsupported operation in reduction rev autodiff(1): "
-                        << *orig << "\n";
+      orig->emitError()
+          << "Unsupported operation in reduction rev autodiff(1): " << *orig
+          << "\n";
       return failure();
     }
 
     Operation &innerOp = op.getBody().front().front();
-     
+
     auto inTy = op->getOperand(0).getType().cast<RankedTensorType>();
     auto zero = inTy.cast<AutoDiffTypeInterface>().createNullValue(builder,
                                                                    op.getLoc());
     auto inDiffe = gutils->diffe(op->getResult(0), builder);
     gutils->zeroDiffe(op->getResult(0), builder);
-     
-     SmallVector<int64_t> toBroadcast;
-     {
-      size_t idx= 0;
+
+    SmallVector<int64_t> toBroadcast;
+    {
+      size_t idx = 0;
       for (auto en : llvm::enumerate(inTy.getShape())) {
         if (llvm::is_contained(op.getDimensions(), en.index())) {
           // reduced op
@@ -396,56 +401,62 @@ class AutoDiffReduceRev
         toBroadcast.push_back(idx);
         idx++;
       }
-      }
+    }
 
     if (isa<AddOp>(innerOp)) {
-        if (!gutils->isConstantValue(op.getInputs()[0])) {
+      if (!gutils->isConstantValue(op.getInputs()[0])) {
         Value bcast;
 
-
-        bcast = builder.create<BroadcastInDimOp>(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr(toBroadcast));
+        bcast = builder.create<BroadcastInDimOp>(
+            op.getLoc(), gutils->getShadowType(inTy), inDiffe,
+            builder.getDenseI64ArrayAttr(toBroadcast));
 
         gutils->addToDiffe(op.getInputs()[0], bcast, builder);
-        }
-        if (!gutils->isConstantValue(op.getInitValues()[0])) {
+      }
+      if (!gutils->isConstantValue(op.getInitValues()[0])) {
         gutils->addToDiffe(op.getInitValues()[0], inDiffe, builder);
-        }
-        return success();
+      }
+      return success();
     }
 
     if (isa<MaxOp>(innerOp) || isa<MinOp>(innerOp)) {
-      // TODO: technically we should invert the order here to pick the last value (or divide by count) if multiple are the same as the
-      // result
+      // TODO: technically we should invert the order here to pick the last
+      // value (or divide by count) if multiple are the same as the result
      auto ores = gutils->getNewFromOriginal(op->getResult(0));
 
      if (!gutils->isConstantValue(op.getInputs()[0])) {
        auto oprev = gutils->getNewFromOriginal(op.getInputs()[0]);
        auto attr = builder.getDenseI64ArrayAttr(toBroadcast);
-        auto bc = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(), ores, attr);
+        auto bc = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(),
+                                                   ores, attr);
 
-        auto cmp = builder.create<CompareOp>(op.getLoc(), bc, oprev, ComparisonDirection::EQ);
+        auto cmp = builder.create<CompareOp>(op.getLoc(), bc, oprev,
+                                             ComparisonDirection::EQ);
 
-        auto bc2 = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(), inDiffe, attr);
+        auto bc2 = builder.create<BroadcastInDimOp>(
+            op.getLoc(), oprev.getType(), inDiffe, attr);
 
        auto res = builder.create<SelectOp>(op.getLoc(), cmp, bc2, zero);
        gutils->addToDiffe(op.getInputs()[0], res, builder);
      }
      if (!gutils->isConstantValue(op.getInitValues()[0])) {
        auto oprev = gutils->getNewFromOriginal(op.getInitValues()[0]);
 
-        auto zeroI = inDiffe.getType().cast<AutoDiffTypeInterface>().createNullValue(builder,
-            op.getLoc());
+        auto zeroI =
+            inDiffe.getType().cast<AutoDiffTypeInterface>().createNullValue(
+                builder, op.getLoc());
 
-        auto cmp = builder.create<CompareOp>(op.getLoc(), ores, oprev, ComparisonDirection::EQ);
+        auto cmp = builder.create<CompareOp>(op.getLoc(), ores, oprev,
+                                             ComparisonDirection::EQ);
 
        auto res = builder.create<SelectOp>(op.getLoc(), cmp, inDiffe, zeroI);
        gutils->addToDiffe(op.getInitValues()[0], res, builder);
      }
      return success();
    }
-     
+
    orig->emitError() << "Unsupported operation in reduction rev autodiff(1): "
-                       << *orig << "\n";
+                      << *orig << "\n";
    return failure();
  }
 
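The two branches above implement distinct adjoints for `ReduceOp`. For an add-reduction the cotangent is broadcast back over the reduced dimensions (and flows unchanged to the init value); for a max/min-reduction it is routed only to input positions that compare equal to the reduced result, selected against zero, with the TODO noting that ties currently all receive the full cotangent. A standalone sketch of both rules on made-up data (not the Enzyme API):

```cpp
#include <algorithm>
#include <array>
#include <cstdio>

int main() {
  std::array<std::array<double, 3>, 2> x = {{{1, 5, 2}, {4, 5, 0}}};
  std::array<double, 3> ybar = {10, 20, 30}; // cotangent of the reduced result

  // Add-reduce over dim 0: xbar[i][j] = ybar[j], a plain broadcast_in_dim.
  std::array<std::array<double, 3>, 2> xbarAdd;
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      xbarAdd[i][j] = ybar[j];

  // Max-reduce over dim 0: select(x == broadcast(max), broadcast(ybar), 0).
  // Column 1 has a tie (5 in both rows), so both rows receive 20: exactly
  // the double-counting the patch's TODO comment warns about.
  std::array<std::array<double, 3>, 2> xbarMax;
  for (int j = 0; j < 3; ++j) {
    double m = std::max(x[0][j], x[1][j]);
    for (int i = 0; i < 2; ++i)
      xbarMax[i][j] = (x[i][j] == m) ? ybar[j] : 0;
  }
  std::printf("row0: %g %g %g\n", xbarMax[0][0], xbarMax[0][1],
              xbarMax[0][2]); // prints: row0: 0 20 30
}
```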
@@ -463,40 +474,43 @@ class AutoDiffConcatenateRev
           ConcatenateOp> {
 public:
   LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
-                                        MGradientUtilsReverse *gutils,
-                                        SmallVector<Value> caches) const {
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
     auto op = cast<ConcatenateOp>(orig);
 
     auto inDiffe = gutils->diffe(op->getResult(0), builder);
     gutils->zeroDiffe(op->getResult(0), builder);
 
     auto dim = op.getDimension();
     size_t startDim = 0;
-      for (auto &ope : op->getOpOperands()) {
-        auto op = ope.get();
-        auto inTy = gutils->getShadowType(op.getType());
-        SmallVector<int64_t> start;
-        SmallVector<int64_t> limit;
-        SmallVector<int64_t> strides;
-        SmallVector<int64_t> tys;
-        auto RT = inTy.cast<RankedTensorType>();
-        for (auto i=0; i<RT.getShape().size(); i++) {
-          tys.push_back(RT.getShape()[i]);
-          if (i == dim) {
-            start.push_back(startDim);
-            limit.push_back(startDim + RT.getShape()[i]);
-            startDim += RT.getShape()[i];
-            strides.push_back(1);
-            continue;
-          }
-          start.push_back(0);
-          limit.push_back(RT.getShape()[i]);
-          strides.push_back(1);
+    for (auto &ope : op->getOpOperands()) {
+      auto op = ope.get();
+      auto inTy = gutils->getShadowType(op.getType());
+      SmallVector<int64_t> start;
+      SmallVector<int64_t> limit;
+      SmallVector<int64_t> strides;
+      SmallVector<int64_t> tys;
+      auto RT = inTy.cast<RankedTensorType>();
+      for (auto i = 0; i < RT.getShape().size(); i++) {
+        tys.push_back(RT.getShape()[i]);
+        if (i == dim) {
+          start.push_back(startDim);
+          limit.push_back(startDim + RT.getShape()[i]);
+          startDim += RT.getShape()[i];
+          strides.push_back(1);
+          continue;
        }
-        if (gutils->isConstantValue(op)) continue;
-        auto res = builder.create<SliceOp>(op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, start, limit, strides);
-        auto res2 = builder.create<ReshapeOp>(op.getLoc(), inTy, res);
-        gutils->addToDiffe(op, res2, builder);
+        start.push_back(0);
+        limit.push_back(RT.getShape()[i]);
+        strides.push_back(1);
+      }
+      if (gutils->isConstantValue(op))
+        continue;
+      auto res = builder.create<SliceOp>(
+          op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe,
+          start, limit, strides);
+      auto res2 = builder.create<ReshapeOp>(op.getLoc(), inTy, res);
+      gutils->addToDiffe(op, res2, builder);
     }
     return success();
   }
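Finally, the `ConcatenateOp` adjoint above slices the result cotangent back apart along the concatenation dimension, advancing `startDim` by each operand's extent so every operand receives exactly its own span. The same bookkeeping on a 1-D example (illustrative values, not the patch's API):

```cpp
// concat(a[2], b[3]) -> y[5]; reverse: abar = ybar[0:2], bbar = ybar[2:5].
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> ybar = {1, 2, 3, 4, 5}; // cotangent of the concat result
  std::vector<int> sizes = {2, 3};            // operand extents on dim 0
  int startDim = 0; // same running offset the patch maintains
  for (int s : sizes) {
    std::printf("operand slice [%d, %d):", startDim, startDim + s);
    for (int i = startDim; i < startDim + s; ++i)
      std::printf(" %g", ybar[i]);
    std::printf("\n");
    startDim += s; // advance past this operand's span
  }
}
```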