Skip to content

Commit 677fdbc

Browse files
committed
cleanup
1 parent 1f7c32e commit 677fdbc

File tree

8 files changed

+396
-299
lines changed

8 files changed

+396
-299
lines changed

WORKSPACE

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,16 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen
6060

6161
pip_install_dependencies()
6262

63-
ENZYME_COMMIT = "97066352a40b3c66f9a1f41ec1802af255216c0c"
64-
ENZYME_SHA256 = ""
63+
ENZYME_COMMIT = "0a129ae7e45114a08f281e50632b9f967fae8396"
64+
ENZYME_SHA256 = "715982efd0a0ef8038e8ad35047e9c1941eb3f9cb038883342969b0bcc8915ad"
6565

66-
local_repository(
66+
http_archive(
6767
name = "enzyme",
68-
path = "../Enzyme/enzyme"
68+
sha256 = ENZYME_SHA256,
69+
strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
70+
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
6971
)
7072

71-
# http_archive(
72-
# name = "enzyme",
73-
# sha256 = ENZYME_SHA256,
74-
# strip_prefix = "Enzyme-" + ENZYME_COMMIT + "/enzyme",
75-
# urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
76-
# )
77-
7873
JAX_COMMIT = "9a098e922aff62a3b49bd673b9518d97ee599248"
7974
JAX_SHA256 = ""
8075

src/enzyme_ad/jax/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ cc_library(
145145
"@stablehlo//:reference_ops",
146146
"@llvm-project//mlir:ArithDialect",
147147
"@llvm-project//mlir:FuncDialect",
148+
"@llvm-project//mlir:TensorDialect",
148149
"@llvm-project//mlir:IR",
149150
"@llvm-project//mlir:FunctionInterfaces",
150151
"@llvm-project//mlir:ControlFlowInterfaces",
@@ -225,6 +226,7 @@ pybind_library(
225226
"@llvm-project//mlir:ArithDialect",
226227
"@llvm-project//mlir:FuncDialect",
227228
"@llvm-project//mlir:FuncExtensions",
229+
"@llvm-project//mlir:TensorDialect",
228230

229231
"@llvm-project//mlir:Parser",
230232
"@llvm-project//mlir:Pass",

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 86 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ class AutoDiffBroadcastInDimRev
192192
AutoDiffBroadcastInDimRev, BroadcastInDimOp> {
193193
public:
194194
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
195-
MGradientUtilsReverse *gutils,
196-
SmallVector<Value> caches) const {
195+
MGradientUtilsReverse *gutils,
196+
SmallVector<Value> caches) const {
197197
auto op = cast<BroadcastInDimOp>(orig);
198198
auto inTy = op.getOperand().getType();
199199
auto outTy = op.getType();
@@ -205,16 +205,16 @@ class AutoDiffBroadcastInDimRev
205205

206206
SmallVector<int64_t> newDims;
207207
for (auto en : llvm::enumerate(outTy.getShape())) {
208-
if (llvm::is_contained(bcastDims, en.index())) continue;
208+
if (llvm::is_contained(bcastDims, en.index()))
209+
continue;
209210
newDims.push_back(en.index());
210211
}
211212

212213
Value zero = gutils->getShadowType(inTy)
213-
.cast<AutoDiffTypeInterface>()
214-
.createNullValue(builder, op.getLoc());
214+
.cast<AutoDiffTypeInterface>()
215+
.createNullValue(builder, op.getLoc());
215216

216-
auto red = builder.create<ReduceOp>(op.getLoc(),
217-
TypeRange(zero.getType()),
217+
auto red = builder.create<ReduceOp>(op.getLoc(), TypeRange(zero.getType()),
218218
inDiffe, zero, newDims);
219219
red.getBody().push_back(new Block());
220220
Block &body = red.getBody().front();
@@ -228,9 +228,9 @@ class AutoDiffBroadcastInDimRev
228228
bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(add));
229229

230230
Value res = red->getResult(0);
231-
Type resTy = gutils->getShadowType(op.getOperand().getType());
231+
Type resTy = gutils->getShadowType(op.getOperand().getType());
232232
if (res.getType() != resTy)
233-
res = builder.create<ReshapeOp>(op.getLoc(), resTy, res);
233+
res = builder.create<ReshapeOp>(op.getLoc(), resTy, res);
234234

235235
gutils->addToDiffe(op.getOperand(), res, builder);
236236
return success();
@@ -250,8 +250,8 @@ class AutoDiffSliceRev
250250
SliceOp> {
251251
public:
252252
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
253-
MGradientUtilsReverse *gutils,
254-
SmallVector<Value> caches) const {
253+
MGradientUtilsReverse *gutils,
254+
SmallVector<Value> caches) const {
255255
auto op = cast<SliceOp>(orig);
256256
auto inTy = op.getOperand().getType();
257257
auto outTy = op.getType();
@@ -263,21 +263,25 @@ class AutoDiffSliceRev
263263
SmallVector<int64_t> starts;
264264
SmallVector<int64_t> edge_padding_high;
265265
SmallVector<int64_t> interior_padding;
266-
for (auto &&[start, limit, stride, dim] : llvm::zip(
267-
op.getStartIndices(), op.getLimitIndices(), op.getStrides(), inTy.getShape())) {
266+
for (auto &&[start, limit, stride, dim] :
267+
llvm::zip(op.getStartIndices(), op.getLimitIndices(), op.getStrides(),
268+
inTy.getShape())) {
268269
starts.push_back(start);
269270
edge_padding_high.push_back(dim - limit);
270271
interior_padding.push_back(stride - 1);
271272
}
272273

273-
274-
auto zeroPad = RankedTensorType::get({}, inTy.getElementType()).cast<AutoDiffTypeInterface>().createNullValue(builder,
275-
op.getLoc());
276-
auto red = builder.create<stablehlo::PadOp>(op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts), builder.getDenseI64ArrayAttr(edge_padding_high), builder.getDenseI64ArrayAttr(interior_padding));
274+
auto zeroPad = RankedTensorType::get({}, inTy.getElementType())
275+
.cast<AutoDiffTypeInterface>()
276+
.createNullValue(builder, op.getLoc());
277+
auto red = builder.create<stablehlo::PadOp>(
278+
op.getLoc(), inDiffe, zeroPad, builder.getDenseI64ArrayAttr(starts),
279+
builder.getDenseI64ArrayAttr(edge_padding_high),
280+
builder.getDenseI64ArrayAttr(interior_padding));
277281

278282
gutils->addToDiffe(op.getOperand(), red->getResult(0), builder);
279283
return success();
280-
#if 0
284+
#if 0
281285

282286
Value idxs;
283287
{
@@ -351,7 +355,7 @@ class AutoDiffSliceRev
351355
// gutils->setDiffe(op.getOperand(), red->getResult(0), builder);
352356

353357
return success();
354-
#endif
358+
#endif
355359
}
356360

357361
SmallVector<Value> cacheValues(Operation *orig,
@@ -368,26 +372,27 @@ class AutoDiffReduceRev
368372
ReduceOp> {
369373
public:
370374
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
371-
MGradientUtilsReverse *gutils,
372-
SmallVector<Value> caches) const {
375+
MGradientUtilsReverse *gutils,
376+
SmallVector<Value> caches) const {
373377
auto op = cast<ReduceOp>(orig);
374378
if (!isEligibleForCompactPrint(op)) {
375-
orig->emitError() << "Unsupported operation in reduction rev autodiff(1): "
376-
<< *orig << "\n";
379+
orig->emitError()
380+
<< "Unsupported operation in reduction rev autodiff(1): " << *orig
381+
<< "\n";
377382
return failure();
378383
}
379384

380385
Operation &innerOp = op.getBody().front().front();
381-
386+
382387
auto inTy = op->getOperand(0).getType().cast<RankedTensorType>();
383388
auto zero = inTy.cast<AutoDiffTypeInterface>().createNullValue(builder,
384389
op.getLoc());
385390
auto inDiffe = gutils->diffe(op->getResult(0), builder);
386391
gutils->zeroDiffe(op->getResult(0), builder);
387-
388-
SmallVector<int64_t> toBroadcast;
389-
{
390-
size_t idx=0;
392+
393+
SmallVector<int64_t> toBroadcast;
394+
{
395+
size_t idx = 0;
391396
for (auto en : llvm::enumerate(inTy.getShape())) {
392397
if (llvm::is_contained(op.getDimensions(), en.index())) {
393398
// reduced op
@@ -396,56 +401,62 @@ class AutoDiffReduceRev
396401
toBroadcast.push_back(idx);
397402
idx++;
398403
}
399-
}
404+
}
400405

401406
if (isa<AddOp>(innerOp)) {
402-
if (!gutils->isConstantValue(op.getInputs()[0])) {
407+
if (!gutils->isConstantValue(op.getInputs()[0])) {
403408
Value bcast;
404409

405-
406-
bcast = builder.create<BroadcastInDimOp>(op.getLoc(), gutils->getShadowType(inTy), inDiffe, builder.getDenseI64ArrayAttr(toBroadcast));
410+
bcast = builder.create<BroadcastInDimOp>(
411+
op.getLoc(), gutils->getShadowType(inTy), inDiffe,
412+
builder.getDenseI64ArrayAttr(toBroadcast));
407413

408414
gutils->addToDiffe(op.getInputs()[0], bcast, builder);
409-
}
410-
if (!gutils->isConstantValue(op.getInitValues()[0])) {
415+
}
416+
if (!gutils->isConstantValue(op.getInitValues()[0])) {
411417
gutils->addToDiffe(op.getInitValues()[0], inDiffe, builder);
412-
}
413-
return success();
418+
}
419+
return success();
414420
}
415421

416422
if (isa<MaxOp>(innerOp) || isa<MinOp>(innerOp)) {
417-
// TODO: technically we should invert the order here to pick the last value (or divide by count) if multiple are the same as the
418-
// result
423+
// TODO: technically we should invert the order here to pick the last
424+
// value (or divide by count) if multiple are the same as the result
419425
auto ores = gutils->getNewFromOriginal(op->getResult(0));
420426

421427
if (!gutils->isConstantValue(op.getInputs()[0])) {
422428
auto oprev = gutils->getNewFromOriginal(op.getInputs()[0]);
423429
auto attr = builder.getDenseI64ArrayAttr(toBroadcast);
424-
auto bc = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(), ores, attr);
430+
auto bc = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(),
431+
ores, attr);
425432

426-
auto cmp = builder.create<CompareOp>(op.getLoc(), bc, oprev, ComparisonDirection::EQ);
433+
auto cmp = builder.create<CompareOp>(op.getLoc(), bc, oprev,
434+
ComparisonDirection::EQ);
427435

428-
auto bc2 = builder.create<BroadcastInDimOp>(op.getLoc(), oprev.getType(), inDiffe, attr);
436+
auto bc2 = builder.create<BroadcastInDimOp>(
437+
op.getLoc(), oprev.getType(), inDiffe, attr);
429438

430439
auto res = builder.create<SelectOp>(op.getLoc(), cmp, bc2, zero);
431440
gutils->addToDiffe(op.getInputs()[0], res, builder);
432441
}
433442
if (!gutils->isConstantValue(op.getInitValues()[0])) {
434443
auto oprev = gutils->getNewFromOriginal(op.getInitValues()[0]);
435444

436-
auto zeroI = inDiffe.getType().cast<AutoDiffTypeInterface>().createNullValue(builder,
437-
op.getLoc());
445+
auto zeroI =
446+
inDiffe.getType().cast<AutoDiffTypeInterface>().createNullValue(
447+
builder, op.getLoc());
438448

439-
auto cmp = builder.create<CompareOp>(op.getLoc(), ores, oprev, ComparisonDirection::EQ);
449+
auto cmp = builder.create<CompareOp>(op.getLoc(), ores, oprev,
450+
ComparisonDirection::EQ);
440451

441452
auto res = builder.create<SelectOp>(op.getLoc(), cmp, inDiffe, zeroI);
442453
gutils->addToDiffe(op.getInitValues()[0], res, builder);
443454
}
444455
return success();
445456
}
446-
457+
447458
orig->emitError() << "Unsupported operation in reduction rev autodiff(1): "
448-
<< *orig << "\n";
459+
<< *orig << "\n";
449460
return failure();
450461
}
451462

@@ -463,40 +474,43 @@ class AutoDiffConcatenateRev
463474
ConcatenateOp> {
464475
public:
465476
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
466-
MGradientUtilsReverse *gutils,
467-
SmallVector<Value> caches) const {
477+
MGradientUtilsReverse *gutils,
478+
SmallVector<Value> caches) const {
468479
auto op = cast<ConcatenateOp>(orig);
469480

470481
auto inDiffe = gutils->diffe(op->getResult(0), builder);
471482
gutils->zeroDiffe(op->getResult(0), builder);
472483

473484
auto dim = op.getDimension();
474485
size_t startDim = 0;
475-
for (auto &ope : op->getOpOperands()) {
476-
auto op = ope.get();
477-
auto inTy = gutils->getShadowType(op.getType());
478-
SmallVector<int64_t> start;
479-
SmallVector<int64_t> limit;
480-
SmallVector<int64_t> strides;
481-
SmallVector<int64_t> tys;
482-
auto RT = inTy.cast<RankedTensorType>();
483-
for (auto i=0; i<RT.getShape().size(); i++) {
484-
tys.push_back(RT.getShape()[i]);
485-
if (i == dim) {
486-
start.push_back(startDim);
487-
limit.push_back(startDim + RT.getShape()[i]);
488-
startDim += RT.getShape()[i];
489-
strides.push_back(1);
490-
continue;
491-
}
492-
start.push_back(0);
493-
limit.push_back(RT.getShape()[i]);
494-
strides.push_back(1);
486+
for (auto &ope : op->getOpOperands()) {
487+
auto op = ope.get();
488+
auto inTy = gutils->getShadowType(op.getType());
489+
SmallVector<int64_t> start;
490+
SmallVector<int64_t> limit;
491+
SmallVector<int64_t> strides;
492+
SmallVector<int64_t> tys;
493+
auto RT = inTy.cast<RankedTensorType>();
494+
for (auto i = 0; i < RT.getShape().size(); i++) {
495+
tys.push_back(RT.getShape()[i]);
496+
if (i == dim) {
497+
start.push_back(startDim);
498+
limit.push_back(startDim + RT.getShape()[i]);
499+
startDim += RT.getShape()[i];
500+
strides.push_back(1);
501+
continue;
495502
}
496-
if (gutils->isConstantValue(op)) continue;
497-
auto res = builder.create<SliceOp>(op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe, start, limit, strides);
498-
auto res2 = builder.create<ReshapeOp>(op.getLoc(), inTy, res);
499-
gutils->addToDiffe(op, res2, builder);
503+
start.push_back(0);
504+
limit.push_back(RT.getShape()[i]);
505+
strides.push_back(1);
506+
}
507+
if (gutils->isConstantValue(op))
508+
continue;
509+
auto res = builder.create<SliceOp>(
510+
op.getLoc(), RankedTensorType::get(tys, RT.getElementType()), inDiffe,
511+
start, limit, strides);
512+
auto res2 = builder.create<ReshapeOp>(op.getLoc(), inTy, res);
513+
gutils->addToDiffe(op, res2, builder);
500514
}
501515
return success();
502516
}

0 commit comments

Comments
 (0)