WhileOp reverse derivative #160
Conversation
// The primal is augmented to store the number of iterations

auto newWhile = cast<WhileOp>(gutils->getNewFromOriginal(orig));
auto cond = &newWhile.getCond().front();
Usually one of the variables is a loop induction variable, which means we don't need to store this. I'm not sure whether this optimization would go here or in a canonicalization, though.
Way more powerful versions of this live here: https://github.com/llvm/Polygeist/blob/77c04bb2a7a2406ca9480bcc9e729b07d2c8d077/lib/polygeist/Passes/CanonicalizeFor.cpp#L662
The advantage of doing this early, though, is that we can potentially determine a static loop count and then give all of the augmented tensors static shapes instead of <?x32>.
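To make the payoff concrete, here is a minimal sketch (a hypothetical helper, not code from this PR) of how a known static trip count would let the augmented cache tensors get a static leading dimension instead of a dynamic one:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Hypothetical helper: choose the type of the per-iteration cache tensor.
// When the trip count is known statically, the leading dimension becomes
// static (e.g. tensor<10x32xf32>) rather than dynamic (tensor<?x32xf32>).
RankedTensorType getAugmentedCacheType(std::optional<int64_t> tripCount,
                                       RankedTensorType elementType) {
  llvm::SmallVector<int64_t> shape;
  shape.push_back(tripCount ? *tripCount : ShapedType::kDynamic);
  shape.append(elementType.getShape().begin(), elementType.getShape().end());
  return RankedTensorType::get(shape, elementType.getElementType());
}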
I was able to reuse the analysis from enzyme-hlo-unroll, which we can make more robust in the future to support more complex patterns. The forward pass is not needed anymore if there is no cache push inside the body and the loop is a for loop.
In this case I don't think we even care whether the value is a constant; we just cache that value regardless.
[and if it's constant, the cache fixups ought to clean it up automatically for us]
SplatElementsAttr::get(unrankedTensorType,
                       ArrayRef<Attribute>(IntegerAttr::get(
                           bodyBuilder.getI64Type(), 1))));
Value bodyIterVar =
As weird and annoying as it is, we should probably use the createAdd from the autodiff type interface. Even though here we have to raise it back into a stablehlo.add, it does mean that new types will use the right add [if, for example, stablehlo.add is not correct for them]. I'm okay with this as-is, though, if you feel strongly.
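As a sketch, routing the add through the interface might look roughly like this (the method spelling createAddOp is an assumption; verify against Enzyme's actual AutoDiffTypeInterface):

#include "mlir/IR/Builders.h"
// plus Enzyme's AutoDiffTypeInterface header (path depends on the build)

using namespace mlir;

// Sketch only: dispatch the addition through the type's autodiff
// interface so non-stablehlo types get their own correct add.
Value addViaTypeInterface(OpBuilder &builder, Location loc, Value lhs,
                          Value rhs) {
  auto iface = cast<AutoDiffTypeInterface>(lhs.getType());
  return iface.createAddOp(builder, loc, lhs, rhs);
}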
Ah, apologies: this add is for a new induction variable.
In that case the same comment applies: we should reuse an existing induction variable [present in almost all cases currently], when possible.
Alternatively [or perhaps in addition], we really should import most of that while-op optimization machinery from Polygeist [which also does redundant induction variable elimination somewhere, iirc].
std::optional<int64_t> getConstantStart();
std::optional<int64_t> getConstantLimit();

/// Needs to be constant
Yeah, here for the get-number-of-iterations helper, we can take a builder and compute the number of iterations.
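A rough shape of that helper, with assumed names and a canonical positive-step for-loop (not the PR's actual interface):

#include "stablehlo/dialect/StablehloOps.h"

using namespace mlir;

// Sketch, not the PR's code: emit IR computing the trip count as
// (limit - start) / step for a canonical for-style loop; rounding and
// negative-step handling are deliberately omitted.
Value emitNumIterations(OpBuilder &builder, Location loc, Value start,
                        Value limit, Value step) {
  Value diff = builder.create<stablehlo::SubtractOp>(loc, limit, start);
  return builder.create<stablehlo::DivOp>(loc, diff, step);
}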
return %results1 : tensor<f64>
}
}
Can you add forward and reverse checks here?
// CHECK-NEXT: } do {
// CHECK-NEXT: %14 = stablehlo.add %iterArg, %c_0 : tensor<i64>
// CHECK-NEXT: "enzyme.set"(%4, %iterArg_2) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
// CHECK-NEXT: %15 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
We should make sure that remove-enzyme-ops can deal with this %4 case without extra work.
%4 is defined right before this in a dominating way, allowing %15 to be replaced with %iterArg_2.
Then, since only sets remain on it, we should be able to remove the whole thing.
No while-special-case handling required here.
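For illustration, the get-forwarding step could be a pattern along these lines (op accessor names here are assumptions, not Enzyme's exact API):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
// plus Enzyme's dialect header for enzyme::GetOp / enzyme::SetOp

using namespace mlir;

// Sketch: replace an enzyme.get with the value of a dominating enzyme.set
// on the same gradient when no intervening op may write it. A real
// implementation needs proper dominance and aliasing checks; afterwards,
// a gradient whose only remaining users are sets can be erased outright.
struct ForwardSetToGet : public OpRewritePattern<enzyme::GetOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(enzyme::GetOp get,
                                PatternRewriter &rewriter) const override {
    for (Operation *prev = get->getPrevNode(); prev;
         prev = prev->getPrevNode()) {
      if (auto set = dyn_cast<enzyme::SetOp>(prev))
        if (set.getGradient() == get.getGradient()) {
          rewriter.replaceOp(get, set.getValue());
          return success();
        }
      // Conservatively stop at anything that might write the gradient.
      if (!isMemoryEffectFree(prev))
        return failure();
    }
    return failure();
  }
};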
// CHECK-NEXT: }
// CHECK-NEXT: %11 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
// CHECK-NEXT: %12 = arith.addf %11, %10#1 : tensor<f64>
// CHECK-NEXT: "enzyme.set"(%0, %12) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
And the same applies here with %0.
// CHECK-NEXT: %16 = "enzyme.pop"(%3) : (!enzyme.Cache<tensor<f64>>) -> tensor<f64>
// CHECK-NEXT: %17 = "enzyme.pop"(%2) : (!enzyme.Cache<tensor<f64>>) -> tensor<f64>
// CHECK-NEXT: %18 = stablehlo.multiply %15, %17 : tensor<f64>
// CHECK-NEXT: %19 = "enzyme.get"(%1) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
The %1 cache here confuses me. In principle your design is set up such that we don't have any gradient +='s that span outside of the while regions. Yet here %4 is not able to be mem2reg'd within scope, so something is clearly going awry...
cc @mofeing I think the same linearity argument from earlier applies here, where we don't need to care about conjugates [which is nice]