WhileOp reverse derivative #160

Merged · merged 4 commits into from Nov 10, 2024

Conversation

@Pangoraw (Collaborator) commented Nov 8, 2024

  • conjugates

// The primal is augmented to store the number of iterations

auto newWhile = cast<WhileOp>(gutils->getNewFromOriginal(orig));
auto cond = &newWhile.getCond().front();
Member:

Usually one of the variables is a loop induction variable -- which means we don't need to store this. Not sure if this optimization would go here or in a canonicalization though.

Way more powerful versions of this would be here: https://github.com/llvm/Polygeist/blob/77c04bb2a7a2406ca9480bcc9e729b07d2c8d077/lib/polygeist/Passes/CanonicalizeFor.cpp#L662

The pro of doing this early, though, is that we can potentially determine a static loop count and then make all of the augmented tensors statically shaped instead of <?x32>.
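
For illustration, a minimal sketch (the helper name and exact matching are assumptions, not from this PR) of detecting such an induction variable: a body block argument that is yielded as itself plus a constant step already counts iterations, so the primal need not be augmented.

// Hypothetical sketch: find a body block argument advanced by a constant
// step each iteration; it can serve as the iteration counter directly.
// (Assumes mlir/IR/Matchers.h and the stablehlo dialect headers.)
static std::optional<unsigned> findInductionVar(stablehlo::WhileOp whileOp) {
  Block &body = whileOp.getBody().front();
  Operation *ret = body.getTerminator();
  for (auto [idx, arg] : llvm::enumerate(body.getArguments())) {
    auto add = ret->getOperand(idx).getDefiningOp<stablehlo::AddOp>();
    if (!add)
      continue;
    // Match iv_next = iv + c (or c + iv) with c a constant.
    Value other;
    if (add.getLhs() == arg)
      other = add.getRhs();
    else if (add.getRhs() == arg)
      other = add.getLhs();
    if (other && matchPattern(other, m_Constant()))
      return idx;
  }
  return std::nullopt;
}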

@Pangoraw (Collaborator, Author), Nov 9, 2024:

I was able to reuse the analysis from enzyme-hlo-unroll, which we can make more robust in the future to support more complex patterns. The forward pass is not needed anymore if there is no cache push inside the body and the loop is a for loop.

Member:

In this case I don't think we even care whether the value is a constant; we just cache that value regardless.

Member:

[and if it's constant, cache fixups ought to clean it up automatically for us]

SplatElementsAttr::get(unrankedTensorType,
                       ArrayRef<Attribute>(IntegerAttr::get(
                           bodyBuilder.getI64Type(), 1))));
Value bodyIterVar =
Member:

As weird and annoying as it is, we should probably use the createAdd from the autodiff type interface. Even though here we have to raise it back into a stablehlo.add, it does mean that new types will use the right add [if, for example, stablehlo.add is not correct for them]. I'm okay with this though if you feel strongly.
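
For reference, a minimal sketch of routing the increment through the interface (assuming Enzyme-MLIR's AutoDiffTypeInterface exposes createAddOp(builder, loc, a, b); the helper name is hypothetical):

// Hypothetical helper: build the increment via the autodiff type
// interface instead of hard-coding stablehlo::AddOp, so types with a
// custom notion of addition get the right op.
static Value createIterIncrement(OpBuilder &builder, Location loc,
                                 Value iterVar, Value step) {
  auto iface = cast<AutoDiffTypeInterface>(iterVar.getType());
  // For tensor<i64> this materializes a stablehlo.add.
  return iface.createAddOp(builder, loc, iterVar, step);
}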

Member:

Ah, apologies, this add is for a new induction variable.

In that case the same comment applies: we should reuse an existing induction variable [present in almost all cases], when possible.

Member:

Alternatively [or perhaps in addition], we really should import most of that while-op optimization stuff from Polygeist [which also does redundant induction variable elimination somewhere, iirc].

std::optional<int64_t> getConstantStart();
std::optional<int64_t> getConstantLimit();

/// Needs to be constant
Member:

Yeah, here for the get-number-of-iterations helper, we can take a builder and compute the number of iterations even when the bounds aren't constant.
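
For the non-constant case, a minimal sketch of what that builder-based computation could look like (the helper name and the positive-step assumption are mine; the splat constant is built the same way as in the diff above):

// Hypothetical sketch: materialize the trip count
// (limit - start + step - 1) / step in IR, assuming a positive step and
// tensor<i64> bounds, so start/limit/step need not be compile-time
// constants.
static Value emitNumIterations(OpBuilder &b, Location loc, Value start,
                               Value limit, Value step) {
  auto ty = cast<ShapedType>(step.getType());
  Value one = b.create<stablehlo::ConstantOp>(
      loc, SplatElementsAttr::get(ty, ArrayRef<Attribute>(IntegerAttr::get(
                                          b.getI64Type(), 1))));
  Value span = b.create<stablehlo::SubtractOp>(loc, limit, start);
  Value bias = b.create<stablehlo::SubtractOp>(loc, step, one);
  Value biased = b.create<stablehlo::AddOp>(loc, span, bias);
  return b.create<stablehlo::DivOp>(loc, biased, step);
}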

return %results1 : tensor<f64>
}
}

Member:

can you add forward and reverse checks here?

// CHECK-NEXT: } do {
// CHECK-NEXT: %14 = stablehlo.add %iterArg, %c_0 : tensor<i64>
// CHECK-NEXT: "enzyme.set"(%4, %iterArg_2) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
// CHECK-NEXT: %15 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
Member:

We should make sure that remove-enzyme-ops can deal with this %4 case without extra work.

%4 is defined right before this in a dominating way, allowing %15 to be replaced with %iterArg_2.

Then, now that only sets are called on it, we should be able to remove the whole thing.

No while special-case handling required here. See the sketch below.
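
To make the intent concrete, a minimal sketch (the enzyme op accessor names are assumptions) of the fixup remove-enzyme-ops could apply: forward a dominating enzyme.set's value to later enzyme.get's on the same gradient, then erase the gradient once only sets remain.

// Hypothetical sketch; a real pass must pick the *nearest* dominating
// set. This simplified version assumes at most one set reaches each get.
static void forwardSetsToGets(Value gradient, DominanceInfo &domInfo) {
  for (Operation *user : llvm::make_early_inc_range(gradient.getUsers())) {
    auto get = dyn_cast<enzyme::GetOp>(user);
    if (!get)
      continue;
    for (Operation *other : gradient.getUsers()) {
      auto set = dyn_cast<enzyme::SetOp>(other);
      if (set && domInfo.dominates(set.getOperation(), get.getOperation())) {
        get.getResult().replaceAllUsesWith(set.getValue());
        get.erase();
        break;
      }
    }
  }
  // Once every remaining user is a set, the gradient cell is dead.
  if (llvm::all_of(gradient.getUsers(),
                   [](Operation *op) { return isa<enzyme::SetOp>(op); })) {
    for (Operation *user : llvm::make_early_inc_range(gradient.getUsers()))
      user->erase();
    if (Operation *def = gradient.getDefiningOp())
      def->erase();
  }
}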

// CHECK-NEXT: }
// CHECK-NEXT: %11 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
// CHECK-NEXT: %12 = arith.addf %11, %10#1 : tensor<f64>
// CHECK-NEXT: "enzyme.set"(%0, %12) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
Member:

and also here with %0

// CHECK-NEXT: %16 = "enzyme.pop"(%3) : (!enzyme.Cache<tensor<f64>>) -> tensor<f64>
// CHECK-NEXT: %17 = "enzyme.pop"(%2) : (!enzyme.Cache<tensor<f64>>) -> tensor<f64>
// CHECK-NEXT: %18 = stablehlo.multiply %15, %17 : tensor<f64>
// CHECK-NEXT: %19 = "enzyme.get"(%1) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
Member:

The %1 cache here confuses me. In principle, your design is set up such that we don't have any gradient +='s that span outside the while regions. Yet here %4 is not able to be mem2reg'd within scope. So something is clearly going awry...

@wsmoses (Member) commented Nov 9, 2024

cc @mofeing I think the same linearity argument from earlier applies here, where we don't need to care about conjugates [which is nice]

@wsmoses wsmoses merged commit e7e0cc3 into EnzymeAD:main Nov 10, 2024
3 of 9 checks passed
@Pangoraw Pangoraw deleted the while-rev branch November 13, 2024 10:21