Skip to content

Commit bc00d4b

Browse files
authored
Add memory invariant assumptions for virtual calls to improve devirtualization (#4596)
Assumptions: (1) class methods do not change the object's vtable pointer; (2) loads through the vtable pointer are invariant (vtables are immutable throughout the program's lifetime).
1 parent 321701e commit bc00d4b

File tree

7 files changed

+163
-21
lines changed

7 files changed

+163
-21
lines changed

gen/classes.cpp

+13-7
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,8 @@ DValue *DtoDynamicCastInterface(const Loc &loc, DValue *val, Type *_to) {
422422

423423
////////////////////////////////////////////////////////////////////////////////
424424

425-
LLValue *DtoVirtualFunctionPointer(DValue *inst, FuncDeclaration *fdecl) {
425+
std::pair<llvm::Value *, llvm::Value *>
426+
DtoVirtualFunctionPointer(DValue *inst, FuncDeclaration *fdecl) {
426427
// sanity checks
427428
assert(fdecl->isVirtual());
428429
assert(!fdecl->isFinalFunc());
@@ -440,17 +441,22 @@ LLValue *DtoVirtualFunctionPointer(DValue *inst, FuncDeclaration *fdecl) {
440441
const auto irtc = getIrType(tc->sym->type, true)->isClass();
441442
const auto vtblType = irtc->getVtblType();
442443

443-
LLValue *funcval = vthis;
444+
LLValue *vtable = vthis;
444445
// get the vtbl for objects
445-
funcval = DtoGEP(irtc->getMemoryLLType(), funcval, 0u, 0);
446+
vtable = DtoGEP(irtc->getMemoryLLType(), vthis, 0u, 0);
446447
// load vtbl ptr
447-
funcval = DtoLoad(vtblType->getPointerTo(), funcval);
448+
vtable = DtoLoad(vtblType->getPointerTo(), vtable);
448449
// index vtbl
449450
const std::string name = fdecl->toChars();
450451
const auto vtblname = name + "@vtbl";
451-
funcval = DtoGEP(vtblType, funcval, 0, fdecl->vtblIndex, vtblname.c_str());
452-
// load opaque pointer
452+
LLValue *funcval =
453+
DtoGEP(vtblType, vtable, 0, fdecl->vtblIndex, vtblname.c_str());
454+
// load opaque pointer.
453455
funcval = DtoAlignedLoad(vtblType->getElementType(), funcval);
456+
// Because vtables are immutable, LLVM's !invariant.load
457+
// can be applied (helps with devirtualization).
458+
llvm::cast<llvm::LoadInst>(funcval)->setMetadata(
459+
"invariant.load", llvm::MDNode::get(gIR->context(), {}));
454460

455461
IF_LOG Logger::cout() << "funcval: " << *funcval << '\n';
456462

@@ -462,5 +468,5 @@ LLValue *DtoVirtualFunctionPointer(DValue *inst, FuncDeclaration *fdecl) {
462468

463469
IF_LOG Logger::cout() << "funcval casted: " << *funcval << '\n';
464470

465-
return funcval;
471+
return std::make_pair(funcval, vtable);
466472
}

gen/classes.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#pragma once
1616

17+
#include <utility>
18+
1719
#include "gen/structs.h"
1820

1921
class ClassDeclaration;
@@ -37,4 +39,6 @@ DValue *DtoDynamicCastObject(const Loc &loc, DValue *val, Type *to);
3739

3840
DValue *DtoDynamicCastInterface(const Loc &loc, DValue *val, Type *to);
3941

40-
llvm::Value *DtoVirtualFunctionPointer(DValue *inst, FuncDeclaration *fdecl);
42+
/// Returns a pair of (function pointer, vtable pointer), in that order.
43+
std::pair<llvm::Value *, llvm::Value *>
44+
DtoVirtualFunctionPointer(DValue *inst, FuncDeclaration *fdecl);

gen/dvalue.cpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,13 @@ LLValue *DSliceValue::getPtr() {
105105

106106
////////////////////////////////////////////////////////////////////////////////
107107

108-
DFuncValue::DFuncValue(Type *t, FuncDeclaration *fd, LLValue *v, LLValue *vt)
109-
: DRValue(t, v), func(fd), vthis(vt) {}
108+
DFuncValue::DFuncValue(Type *t, FuncDeclaration *fd, LLValue *v, LLValue *vt,
109+
LLValue *vtable)
110+
: DRValue(t, v), func(fd), vthis(vt), vtable(vtable) {}
110111

111-
DFuncValue::DFuncValue(FuncDeclaration *fd, LLValue *v, LLValue *vt)
112-
: DFuncValue(fd->type, fd, v, vt) {}
112+
DFuncValue::DFuncValue(FuncDeclaration *fd, LLValue *v, LLValue *vt,
113+
LLValue *vtable)
114+
: DFuncValue(fd->type, fd, v, vt, vtable) {}
113115

114116
bool DFuncValue::definedInFuncEntryBB() {
115117
return isDefinedInFuncEntryBB(val) &&
@@ -235,7 +237,7 @@ DRValue *DDcomputeLValue::getRVal() {
235237
llvm_unreachable("getRVal() for memory-only type");
236238
return nullptr;
237239
}
238-
240+
239241
LLValue *rval = DtoLoad(lltype, val);
240242

241243
const auto ty = type->toBasetype()->ty;

gen/dvalue.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,18 @@ class DSliceValue : public DRValue {
137137
llvm::Value *getPtr();
138138
};
139139

140-
/// Represents a D function value with optional this/context pointer.
140+
/// Represents a D function value with optional this/context pointer, and
141+
/// optional vtable pointer.
141142
class DFuncValue : public DRValue {
142143
public:
143144
FuncDeclaration *func;
144145
llvm::Value *vthis;
146+
llvm::Value *vtable;
145147

146148
DFuncValue(Type *t, FuncDeclaration *fd, llvm::Value *v,
147-
llvm::Value *vt = nullptr);
148-
DFuncValue(FuncDeclaration *fd, llvm::Value *v, llvm::Value *vt = nullptr);
149+
llvm::Value *vt = nullptr, llvm::Value *vtable = nullptr);
150+
DFuncValue(FuncDeclaration *fd, llvm::Value *v, llvm::Value *vt = nullptr,
151+
llvm::Value *vtable = nullptr);
149152

150153
bool definedInFuncEntryBB() override;
151154

gen/ms-cxx-helper.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ void cloneBlocks(const std::vector<llvm::BasicBlock *> &srcblocks,
9797
for (auto &II : *bb) {
9898
llvm::Instruction *Inst = &II;
9999
llvm::Instruction *newInst = nullptr;
100-
if (funclet &&
101-
!llvm::isa<llvm::DbgInfoIntrinsic>(Inst)) { // IntrinsicInst?
100+
if (funclet && !llvm::isa<llvm::IntrinsicInst>(Inst)) {
102101
if (auto IInst = llvm::dyn_cast<llvm::InvokeInst>(Inst)) {
103102
auto invoke = llvm::InvokeInst::Create(
104103
IInst, llvm::OperandBundleDef("funclet", funclet));

gen/toir.cpp

+45-3
Original file line numberDiff line numberDiff line change
@@ -771,13 +771,54 @@ class ToElemVisitor : public Visitor {
771771

772772
// get func value if any
773773
DFuncValue *dfnval = fnval->isFunc();
774+
775+
// If this is a virtual function call, the object is passed by reference
776+
// through the `this` parameter, and therefore the optimizer has to assume
777+
// that the vtable field might be overwritten. This prevents optimization of
778+
// subsequent virtual calls on the same object. We help the optimizer by
779+
// allowing it to assume that the vtable field's content is the same after
780+
// the call. Equivalent D code:
781+
// ```
782+
// auto saved_vtable = a.__vptr; // emitted as part of `a.foo()`,
783+
// // except when e->directcall==true for
784+
// // final method calls.
785+
// a.foo();
786+
// assume(a.__vptr == saved_vtable); // <-- added assumption
787+
// ```
788+
// Only emit this extra code at optimization level -O2 and higher.
789+
// This optimization is only valid for D class method calls (not C++).
790+
bool canEmitVTableUnchangedAssumption =
791+
dfnval && dfnval->func && (dfnval->func->_linkage == LINK::d) &&
792+
(optLevel() >= 2);
793+
774794
if (dfnval && dfnval->func) {
775795
assert(!DtoIsMagicIntrinsic(dfnval->func));
796+
797+
// If loading the vtable was not needed for the function call, we have to load
798+
// it here to do the "assume" optimization below.
799+
if (canEmitVTableUnchangedAssumption && !dfnval->vtable &&
800+
dfnval->vthis && dfnval->func->isVirtual()) {
801+
dfnval->vtable =
802+
DtoLoad(getVoidPtrType(),
803+
DtoBitCast(dfnval->vthis, getVoidPtrType()->getPointerTo()),
804+
"saved_vtable");
805+
}
776806
}
777807

778808
DValue *result =
779809
DtoCallFunction(e->loc, e->type, fnval, e->arguments, sretPointer);
780810

811+
if (canEmitVTableUnchangedAssumption && dfnval->vtable) {
812+
// Reload vtable ptr. It's the first element so instead of GEP+load we can
813+
// do a void* load+bitcast (at this point in the code we don't have easy
814+
// access to the type of the class to do a GEP).
815+
auto vtable = DtoLoad(
816+
dfnval->vtable->getType(),
817+
DtoBitCast(dfnval->vthis, dfnval->vtable->getType()->getPointerTo()));
818+
auto cmp = p->ir->CreateICmpEQ(vtable, dfnval->vtable);
819+
p->ir->CreateCall(GET_INTRINSIC_DECL(assume), {cmp});
820+
}
821+
781822
if (delayedDtorVar) {
782823
delayedDtorVar->edtor = delayedDtorExp;
783824
pushVarDtorCleanup(p, delayedDtorVar);
@@ -1044,16 +1085,17 @@ class ToElemVisitor : public Visitor {
10441085

10451086
// Get the actual function value to call.
10461087
LLValue *funcval = nullptr;
1088+
LLValue *vtable = nullptr;
10471089
if (nonFinal) {
10481090
DtoResolveFunction(fdecl);
1049-
funcval = DtoVirtualFunctionPointer(l, fdecl);
1091+
std::tie(funcval, vtable) = DtoVirtualFunctionPointer(l, fdecl);
10501092
} else {
10511093
funcval = DtoCallee(fdecl);
10521094
}
10531095
assert(funcval);
10541096

10551097
LLValue *vthis = (DtoIsInMemoryOnly(l->type) ? DtoLVal(l) : DtoRVal(l));
1056-
result = new DFuncValue(fdecl, funcval, vthis);
1098+
result = new DFuncValue(fdecl, funcval, vthis, vtable);
10571099
} else {
10581100
llvm_unreachable("Unknown target for VarDeclaration.");
10591101
}
@@ -1950,7 +1992,7 @@ class ToElemVisitor : public Visitor {
19501992

19511993
if (e->e1->op != EXP::super_ && e->e1->op != EXP::dotType &&
19521994
e->func->isVirtual() && !e->func->isFinalFunc()) {
1953-
castfptr = DtoVirtualFunctionPointer(u, e->func);
1995+
castfptr = DtoVirtualFunctionPointer(u, e->func).first;
19541996
} else if (e->func->isAbstract()) {
19551997
llvm_unreachable("Delegate to abstract method not implemented.");
19561998
} else if (e->func->toParent()->isInterfaceDeclaration()) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Tests that class member function calls do not prevent devirtualization (vtable cannot change in class member call).
2+
3+
// RUN: %ldc -output-ll -of=%t.ll %s -O3 && FileCheck %s < %t.ll
4+
5+
class A {
6+
void foo();
7+
final void oof();
8+
}
9+
class B : A {
10+
override void foo();
11+
}
12+
13+
// CHECK-LABEL: define{{.*}}ggg
14+
void ggg()
15+
{
16+
A a = new A();
17+
// CHECK: call {{.*}}_D29devirtualization_assumevtable1A3foo
18+
a.foo();
19+
// CHECK: call {{.*}}_D29devirtualization_assumevtable1A3foo
20+
a.foo();
21+
}
22+
23+
// CHECK-LABEL: define{{.*}}hhh
24+
void hhh()
25+
{
26+
A a = new A();
27+
// CHECK: call {{.*}}_D29devirtualization_assumevtable1A3oof
28+
a.oof();
29+
// CHECK: call {{.*}}_D29devirtualization_assumevtable1A3foo
30+
a.foo();
31+
}
32+
33+
// CHECK-LABEL: define{{.*}}directcall
34+
void directcall()
35+
{
36+
A a = new A();
37+
// CHECK: call {{.*}}_D29devirtualization_assumevtable1A3foo
38+
a.A.foo();
39+
// CHECK: call {{.*}}_D29devirtualization_assumevtable1A3foo
40+
a.foo();
41+
}
42+
// CHECK-LABEL: define{{.*}}exacttypeunknown
43+
void exacttypeunknown(A a, A b)
44+
{
45+
// CHECK: %[[FOO:[0-9a-z]+]] = load {{.*}}!invariant
46+
// CHECK: call{{.*}} void %[[FOO]](
47+
a.foo();
48+
// CHECK: call{{.*}} void %[[FOO]](
49+
a.foo();
50+
51+
a = b;
52+
// CHECK: %[[FOO2:[0-9a-z]+]] = load {{.*}}!invariant
53+
// CHECK: call{{.*}} void %[[FOO2]](
54+
a.foo();
55+
}
56+
57+
// No vtable loading and assume calls for struct method calls.
58+
struct S {
59+
void foo();
60+
}
61+
// CHECK-LABEL: define{{.*}}structS
62+
void structS(S s)
63+
{
64+
// CHECK-NOT: llvm.assume
65+
// CHECK-NOT: load
66+
s.foo();
67+
// CHECK: ret void
68+
}
69+
70+
// The devirtualization is not valid for C++ methods.
71+
extern(C++)
72+
class CPPClass {
73+
void foo();
74+
void oof();
75+
}
76+
77+
// CHECK-LABEL: define{{.*}}exactCPPtypeunknown
78+
void exactCPPtypeunknown(CPPClass a)
79+
{
80+
// CHECK: %[[FOO:[0-9a-z]+]] = load {{.*}}!invariant
81+
// CHECK: call{{.*}} void %[[FOO]](
82+
a.foo();
83+
// CHECK: %[[FOO2:[0-9a-z]+]] = load {{.*}}!invariant
84+
// CHECK: call{{.*}} void %[[FOO2]](
85+
a.foo();
86+
}

0 commit comments

Comments
 (0)