Skip to content

failing enum/union cases #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
ZuseZ4 opened this issue Mar 1, 2023 · 6 comments
Closed

failing enum/union cases #11

ZuseZ4 opened this issue Mar 1, 2023 · 6 comments

Comments

@ZuseZ4
Copy link
Member

ZuseZ4 commented Mar 1, 2023

https://doc.rust-lang.org/beta/src/alloc/raw_vec.rs.html#448

https://doc.rust-lang.org/src/core/ptr/non_null.rs.html#197-203

@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Mar 1, 2023

#![feature(bench_black_box)]
use autodiff::autodiff;

// Two-variant enum used by the autodiff reproducer: `a` carries the
// differentiable `f32` payload, `b` carries an `i32` payload.
//
// In the generated LLVM IR (see the dumps in this issue) the discriminant is
// field 0 (`a` = 0, `b` = 1) and both payloads share field 1, so the `f32`
// and the `i32` alias the same 4 bytes.
//
// The lowercase variant names are kept exactly as in the original
// reproducer; the naming lint is silenced so the snippet builds without
// warnings and without changing the enum's layout or discriminants.
#[allow(non_camel_case_types)]
enum Foo {
    a(f32),
    b(i32),
}

// Reproducer for the enum/union autodiff failure.
//
// Builds a `Foo` whose active variant depends on the sign of `*x`, passes a
// reference to it through `black_box`, then matches on it. Only the
// `Foo::a` arm depends on `x` (returns f * f); the `Foo::b` arm returns the
// constant 4.0.
//
// `#[autodiff(d_bar, Reverse, Active)]` generates `d_bar`; judging by the
// call in `main` (`d_bar(&x, &mut dx, 1.0)`), `#[dup]` pairs `x` with a
// shadow `dx` and the trailing argument seeds the active return value.
// NOTE(review): confirm exact `d_bar` signature against the autodiff macro.
//
// Keep the statement order as-is: the bug only reproduces with this exact
// codegen shape.
#[autodiff(d_bar, Reverse, Active)]
fn bar(#[dup] x: &f32) -> f32 {
    let val: Foo = 
    if *x > 0.0 {
        Foo::a(*x)
    } else {
        Foo::b(12)
    };

    // Let `val` escape through `black_box` so the discriminant is reloaded
    // from memory after the call instead of being constant-folded.
    std::hint::black_box(&val);
    // Per this issue, SimplifyCFG rewrites this match into a `select` and
    // hoists the `Foo::a` payload load so it executes on both paths; the
    // payload is then treated as always holding an f32, even when the
    // active variant is `Foo::b(i32)`.
    match val  {
        Foo::a(f) => f * f,
        Foo::b(_) => 4.0,
    }
}

// Drives the reproducer: evaluates `bar` and its reverse-mode derivative
// `d_bar` (generated by `#[autodiff]`) at x = 1.0 and prints the results.
//
// For x > 0, `bar(x) = x * x`, so the mathematically expected derivative at
// x = 1.0 is 2 * x = 2.0.
// NOTE(review): reverse mode is assumed to *accumulate* into the shadow, so
// with the initial value 2.0 and seed 1.0 the printed `dx` should be
// 2.0 + 1.0 * 2.0 = 4.0 — confirm against the autodiff documentation.
fn main() {
    let x = 1.0;
    let mut dx = 2.0;
    let out = bar(&x);
    let dout = d_bar(&x, &mut dx, 1.0);
    println!("x: {x}, bar(x): {out}");
    // `dx` is the gradient written by the reverse pass; printing it is what
    // makes a wrong derivative observable.
    println!("dx: {dx}");
    println!("d_bar returned: {dout:?}");
}

which results in the following preprocessed IR:

define internal float @preprocess__ZN11broken_enum3bar17h73c03a99634c7e77E(float* align 4 %0) unnamed_addr #63 !dbg !63984 {
  %2 = alloca %182, align 4
  call void @llvm.dbg.value(metadata float* %0, metadata !63986, metadata !DIExpression()), !dbg !63991
  call void @llvm.dbg.declare(metadata %182* %2, metadata !63987, metadata !DIExpression()) #64, !dbg !63992
  %3 = load float, float* %0, align 4, !dbg !63993
  %4 = fcmp ogt float %3, 0.000000e+00, !dbg !63993
  br i1 %4, label %9, label %5, !dbg !63993

5:                                                ; preds = %1
  %6 = bitcast %182* %2 to %183*, !dbg !63994
  %7 = getelementptr inbounds %183, %183* %6, i32 0, i32 1, !dbg !63994
  store i32 12, i32* %7, align 4, !dbg !63994
  %8 = bitcast %182* %2 to i32*, !dbg !63994
  store i32 1, i32* %8, align 4, !dbg !63994
  br label %14, !dbg !63995

9:                                                ; preds = %1
  %10 = load float, float* %0, align 4, !dbg !63996
  %11 = bitcast %182* %2 to %184*, !dbg !63997
  %12 = getelementptr inbounds %184, %184* %11, i32 0, i32 1, !dbg !63997
  store float %10, float* %12, align 4, !dbg !63997
  %13 = bitcast %182* %2 to i32*, !dbg !63997
  store i32 0, i32* %13, align 4, !dbg !63997
  br label %14, !dbg !63995

14:                                               ; preds = %9, %5
  %15 = call align 4 %182* @_ZN4core4hint9black_box17h09b07334ec86fe27E(%182* align 4 %2) #64, !dbg !63998
  %16 = bitcast %182* %2 to i32*, !dbg !63999
  %17 = load i32, i32* %16, align 4, !dbg !63999, !range !1705, !noundef !23
  %18 = zext i32 %17 to i64, !dbg !63999
  switch i64 %18, label %19 [
    i64 0, label %20
    i64 1, label %25
  ], !dbg !64000

19:                                               ; preds = %14
  unreachable, !dbg !63999

20:                                               ; preds = %14
  %21 = bitcast %182* %2 to %184*, !dbg !64001
  %22 = getelementptr inbounds %184, %184* %21, i32 0, i32 1, !dbg !64001
  %23 = load float, float* %22, align 4, !dbg !64001
  call void @llvm.dbg.value(metadata float %23, metadata !63989, metadata !DIExpression()), !dbg !64002
  %24 = fmul float %23, %23, !dbg !64003
  br label %25, !dbg !64004

25:                                               ; preds = %14, %20
  %26 = phi float [ %24, %20 ], [ 4.000000e+00, %14 ], !dbg !64005
  ret float %26, !dbg !64006
}

@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Mar 1, 2023

Wip Explorer link: https://fwd.gymni.ch/9cLTiD

@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Mar 20, 2023

out2.ll.txt

@wsmoses
Copy link
Member

wsmoses commented Mar 20, 2023

Oh I see the issue here.

The load in the `Foo::a` arm of the match gets compiled to execute unconditionally, so the analysis concludes that it is always safe to treat the payload as a float.

@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Mar 21, 2023

Will check which pass it is. Tmp:
afterPasses.ll.txt

@ZuseZ4
Copy link
Member Author

ZuseZ4 commented Mar 21, 2023

notes:
EarlyCSEPass (around line 500 of the dump) removes the second load from the input %x and instead stores the previously loaded value.

SimplifyCFGPass (around line 1000) removes the switch and replaces it with a select, after hoisting the load of the `Foo::a` payload (`%f`) so that it executes unconditionally.
@wsmoses found the place.

before:

*** IR Dump Before SimplifyCFGPass on _ZN11broken_enum3bar17hfdde5c03da18b49aE ***
; Function Attrs: noinline nonlazybind uwtable
define hidden float @_ZN11broken_enum3bar17hfdde5c03da18b49aE(float* noalias noundef readonly align 4 dereferenceable(4) %x) unnamed_addr #0 !dbg !428 {
start:
  %val = alloca %Foo, align 4
  call void @llvm.dbg.value(metadata float* %x, metadata !432, metadata !DIExpression()), !dbg !437
  call void @llvm.dbg.declare(metadata %Foo* %val, metadata !433, metadata !DIExpression()), !dbg !438
  %0 = bitcast %Foo* %val to i8*, !dbg !439
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0), !dbg !439
  %_4 = load float, float* %x, align 4, !dbg !440
  %_3 = fcmp ogt float %_4, 0.000000e+00, !dbg !440
  br i1 %_3, label %bb1, label %bb2, !dbg !440

bb2:                                              ; preds = %start
  %1 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 1, i64 0, !dbg !441
  store i32 12, i32* %1, align 4, !dbg !441
  %2 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 0, !dbg !441
  store i32 1, i32* %2, align 4, !dbg !441
  br label %bb3, !dbg !442

bb1:                                              ; preds = %start
  %3 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 1, !dbg !443
  %4 = bitcast [1 x i32]* %3 to float*, !dbg !443
  store float %_4, float* %4, align 4, !dbg !443
  %5 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 0, !dbg !443
  store i32 0, i32* %5, align 4, !dbg !443
  br label %bb3, !dbg !442

bb3:                                              ; preds = %bb2, %bb1
  call fastcc void @_ZN4core4hint9black_box17h2a85d879c8ed7aa7E(%Foo* noalias noundef nonnull readonly align 4 dereferenceable(8) %val), !dbg !444
  %6 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 0, !dbg !445
  %7 = load i32, i32* %6, align 4, !dbg !445, !range !446, !noundef !23
  %trunc = icmp ne i32 %7, 0, !dbg !447
  switch i1 %trunc, label %bb6 [
    i1 false, label %bb7
    i1 true, label %bb5
  ], !dbg !447

bb6:                                              ; preds = %bb3
  unreachable, !dbg !445

bb7:                                              ; preds = %bb3
  %8 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 1, !dbg !448
  %9 = bitcast [1 x i32]* %8 to float*, !dbg !448
  %f = load float, float* %9, align 4, !dbg !448
  call void @llvm.dbg.value(metadata float %f, metadata !435, metadata !DIExpression()), !dbg !449
  %10 = fmul float %f, %f, !dbg !450
  br label %bb8, !dbg !451

bb5:                                              ; preds = %bb3
  br label %bb8, !dbg !452

bb8:                                              ; preds = %bb7, %bb5
  %.0 = phi float [ 4.000000e+00, %bb5 ], [ %10, %bb7 ], !dbg !453
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0), !dbg !454
  ret float %.0, !dbg !455
}

after:


*** IR Dump Before RequireAnalysisPass<llvm::GlobalsAA, llvm::Module> on [module] ***
; Function Attrs: noinline nonlazybind uwtable
define hidden float @_ZN11broken_enum3bar17hfdde5c03da18b49aE(float* noalias noundef readonly align 4 dereferenceable(4) %x) unnamed_addr #0 !dbg !428 {
start:
  %val = alloca %Foo, align 4
  call void @llvm.dbg.value(metadata float* %x, metadata !432, metadata !DIExpression()), !dbg !437
  call void @llvm.dbg.declare(metadata %Foo* %val, metadata !433, metadata !DIExpression()), !dbg !438
  %0 = bitcast %Foo* %val to i8*, !dbg !439
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %0), !dbg !439
  %_4 = load float, float* %x, align 4, !dbg !440
  %_3 = fcmp ogt float %_4, 0.000000e+00, !dbg !440
  br i1 %_3, label %bb1, label %bb2, !dbg !440

bb2:                                              ; preds = %start
  %1 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 1, i64 0, !dbg !441
  store i32 12, i32* %1, align 4, !dbg !441
  %2 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 0, !dbg !441
  store i32 1, i32* %2, align 4, !dbg !441
  br label %bb3, !dbg !442

bb1:                                              ; preds = %start
  %3 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 1, !dbg !443
  %4 = bitcast [1 x i32]* %3 to float*, !dbg !443
  store float %_4, float* %4, align 4, !dbg !443
  %5 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 0, !dbg !443
  store i32 0, i32* %5, align 4, !dbg !443
  br label %bb3, !dbg !442

bb3:                                              ; preds = %bb2, %bb1
  call fastcc void @_ZN4core4hint9black_box17h2a85d879c8ed7aa7E(%Foo* noalias noundef nonnull readonly align 4 dereferenceable(8) %val), !dbg !444
  %6 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 0, !dbg !445
  %7 = load i32, i32* %6, align 4, !dbg !445, !range !446, !noundef !23
  %trunc = icmp ne i32 %7, 0, !dbg !447
  %switch = icmp ult i1 %trunc, true, !dbg !447
  %8 = getelementptr inbounds %Foo, %Foo* %val, i64 0, i32 1, !dbg !447
  %9 = bitcast [1 x i32]* %8 to float*, !dbg !447
  %f = load float, float* %9, align 4, !dbg !447
  %10 = fmul float %f, %f, !dbg !447
  %.0 = select i1 %switch, float %10, float 4.000000e+00, !dbg !447
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %0), !dbg !448
  ret float %.0, !dbg !449
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants