Skip to content

Commit 4dcd91a

Browse files
[PAC] Implement authentication for C++ member function pointers (#99576)
Introduces type based signing of member function pointers. To support this discrimination schema we no longer emit member function pointer to virtual methods and indices into a vtable but migrate to using thunks. This does mean member function pointers are no longer necessarily directly comparable, however as such comparisons are UB this is acceptable. We derive the discriminator from the C++ mangling of the type of the pointer being authenticated. Co-Authored-By: Akira Hatanaka [email protected] Co-Authored-By: John McCall [email protected] Co-authored-by: Ahmed Bougacha <[email protected]>
1 parent d3fb41d commit 4dcd91a

File tree

10 files changed

+847
-117
lines changed

10 files changed

+847
-117
lines changed

clang/include/clang/AST/ASTContext.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
12871287
getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD);
12881288

12891289
/// Return the "other" type-specific discriminator for the given type.
1290-
uint16_t getPointerAuthTypeDiscriminator(QualType T) const;
1290+
uint16_t getPointerAuthTypeDiscriminator(QualType T);
12911291

12921292
/// Apply Objective-C protocol qualifiers to the given type.
12931293
/// \param allowOnPointerType specifies if we can apply protocol

clang/include/clang/Basic/PointerAuthOptions.h

+3
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ struct PointerAuthOptions {
180180

181181
/// The ABI for variadic C++ virtual function pointers.
182182
PointerAuthSchema CXXVirtualVariadicFunctionPointers;
183+
184+
/// The ABI for C++ member function pointers.
185+
PointerAuthSchema CXXMemberFunctionPointers;
183186
};
184187

185188
} // end namespace clang

clang/lib/AST/ASTContext.cpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -3407,7 +3407,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
34073407
}
34083408
}
34093409

3410-
uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
3410+
uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) {
34113411
assert(!T->isDependentType() &&
34123412
"cannot compute type discriminator of a dependent type");
34133413

@@ -3417,11 +3417,13 @@ uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
34173417
if (T->isFunctionPointerType() || T->isFunctionReferenceType())
34183418
T = T->getPointeeType();
34193419

3420-
if (T->isFunctionType())
3420+
if (T->isFunctionType()) {
34213421
encodeTypeForFunctionPointerAuth(*this, Out, T);
3422-
else
3423-
llvm_unreachable(
3424-
"type discrimination of non-function type not implemented yet");
3422+
} else {
3423+
T = T.getUnqualifiedType();
3424+
std::unique_ptr<MangleContext> MC(createMangleContext());
3425+
MC->mangleCanonicalTypeName(T, Out);
3426+
}
34253427

34263428
return llvm::getPointerAuthStableSipHash(Str);
34273429
}

clang/lib/CodeGen/CGCall.cpp

+115-98
Original file line numberDiff line numberDiff line change
@@ -5034,7 +5034,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
50345034
ReturnValueSlot ReturnValue,
50355035
const CallArgList &CallArgs,
50365036
llvm::CallBase **callOrInvoke, bool IsMustTail,
5037-
SourceLocation Loc) {
5037+
SourceLocation Loc,
5038+
bool IsVirtualFunctionPointerThunk) {
50385039
// FIXME: We no longer need the types from CallArgs; lift up and simplify.
50395040

50405041
assert(Callee.isOrdinary() || Callee.isVirtual());
@@ -5098,7 +5099,11 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
50985099
RawAddress SRetAlloca = RawAddress::invalid();
50995100
llvm::Value *UnusedReturnSizePtr = nullptr;
51005101
if (RetAI.isIndirect() || RetAI.isInAlloca() || RetAI.isCoerceAndExpand()) {
5101-
if (!ReturnValue.isNull()) {
5102+
if (IsVirtualFunctionPointerThunk && RetAI.isIndirect()) {
5103+
SRetPtr = makeNaturalAddressForPointer(CurFn->arg_begin() +
5104+
IRFunctionArgs.getSRetArgNo(),
5105+
RetTy, CharUnits::fromQuantity(1));
5106+
} else if (!ReturnValue.isNull()) {
51025107
SRetPtr = ReturnValue.getAddress();
51035108
} else {
51045109
SRetPtr = CreateMemTemp(RetTy, "tmp", &SRetAlloca);
@@ -5877,119 +5882,131 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
58775882
CallArgs.freeArgumentMemory(*this);
58785883

58795884
// Extract the return value.
5880-
RValue Ret = [&] {
5881-
switch (RetAI.getKind()) {
5882-
case ABIArgInfo::CoerceAndExpand: {
5883-
auto coercionType = RetAI.getCoerceAndExpandType();
5884-
5885-
Address addr = SRetPtr.withElementType(coercionType);
5886-
5887-
assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
5888-
bool requiresExtract = isa<llvm::StructType>(CI->getType());
5885+
RValue Ret;
58895886

5890-
unsigned unpaddedIndex = 0;
5891-
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
5892-
llvm::Type *eltType = coercionType->getElementType(i);
5893-
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
5894-
Address eltAddr = Builder.CreateStructGEP(addr, i);
5895-
llvm::Value *elt = CI;
5896-
if (requiresExtract)
5897-
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
5898-
else
5899-
assert(unpaddedIndex == 0);
5900-
Builder.CreateStore(elt, eltAddr);
5887+
// If the current function is a virtual function pointer thunk, avoid copying
5888+
// the return value of the musttail call to a temporary.
5889+
if (IsVirtualFunctionPointerThunk) {
5890+
Ret = RValue::get(CI);
5891+
} else {
5892+
Ret = [&] {
5893+
switch (RetAI.getKind()) {
5894+
case ABIArgInfo::CoerceAndExpand: {
5895+
auto coercionType = RetAI.getCoerceAndExpandType();
5896+
5897+
Address addr = SRetPtr.withElementType(coercionType);
5898+
5899+
assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
5900+
bool requiresExtract = isa<llvm::StructType>(CI->getType());
5901+
5902+
unsigned unpaddedIndex = 0;
5903+
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
5904+
llvm::Type *eltType = coercionType->getElementType(i);
5905+
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
5906+
continue;
5907+
Address eltAddr = Builder.CreateStructGEP(addr, i);
5908+
llvm::Value *elt = CI;
5909+
if (requiresExtract)
5910+
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
5911+
else
5912+
assert(unpaddedIndex == 0);
5913+
Builder.CreateStore(elt, eltAddr);
5914+
}
5915+
[[fallthrough]];
59015916
}
5902-
[[fallthrough]];
5903-
}
5904-
5905-
case ABIArgInfo::InAlloca:
5906-
case ABIArgInfo::Indirect: {
5907-
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
5908-
if (UnusedReturnSizePtr)
5909-
PopCleanupBlock();
5910-
return ret;
5911-
}
59125917

5913-
case ABIArgInfo::Ignore:
5914-
// If we are ignoring an argument that had a result, make sure to
5915-
// construct the appropriate return value for our caller.
5916-
return GetUndefRValue(RetTy);
5918+
case ABIArgInfo::InAlloca:
5919+
case ABIArgInfo::Indirect: {
5920+
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
5921+
if (UnusedReturnSizePtr)
5922+
PopCleanupBlock();
5923+
return ret;
5924+
}
59175925

5918-
case ABIArgInfo::Extend:
5919-
case ABIArgInfo::Direct: {
5920-
llvm::Type *RetIRTy = ConvertType(RetTy);
5921-
if (RetAI.getCoerceToType() == RetIRTy && RetAI.getDirectOffset() == 0) {
5922-
switch (getEvaluationKind(RetTy)) {
5923-
case TEK_Complex: {
5924-
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
5925-
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
5926-
return RValue::getComplex(std::make_pair(Real, Imag));
5927-
}
5928-
case TEK_Aggregate: {
5929-
Address DestPtr = ReturnValue.getAddress();
5930-
bool DestIsVolatile = ReturnValue.isVolatile();
5926+
case ABIArgInfo::Ignore:
5927+
// If we are ignoring an argument that had a result, make sure to
5928+
// construct the appropriate return value for our caller.
5929+
return GetUndefRValue(RetTy);
5930+
5931+
case ABIArgInfo::Extend:
5932+
case ABIArgInfo::Direct: {
5933+
llvm::Type *RetIRTy = ConvertType(RetTy);
5934+
if (RetAI.getCoerceToType() == RetIRTy &&
5935+
RetAI.getDirectOffset() == 0) {
5936+
switch (getEvaluationKind(RetTy)) {
5937+
case TEK_Complex: {
5938+
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
5939+
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
5940+
return RValue::getComplex(std::make_pair(Real, Imag));
5941+
}
5942+
case TEK_Aggregate: {
5943+
Address DestPtr = ReturnValue.getAddress();
5944+
bool DestIsVolatile = ReturnValue.isVolatile();
59315945

5932-
if (!DestPtr.isValid()) {
5933-
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
5934-
DestIsVolatile = false;
5946+
if (!DestPtr.isValid()) {
5947+
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
5948+
DestIsVolatile = false;
5949+
}
5950+
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
5951+
return RValue::getAggregate(DestPtr);
5952+
}
5953+
case TEK_Scalar: {
5954+
// If the argument doesn't match, perform a bitcast to coerce it.
5955+
// This can happen due to trivial type mismatches.
5956+
llvm::Value *V = CI;
5957+
if (V->getType() != RetIRTy)
5958+
V = Builder.CreateBitCast(V, RetIRTy);
5959+
return RValue::get(V);
59355960
}
5936-
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
5937-
return RValue::getAggregate(DestPtr);
5961+
}
5962+
llvm_unreachable("bad evaluation kind");
59385963
}
5939-
case TEK_Scalar: {
5940-
// If the argument doesn't match, perform a bitcast to coerce it. This
5941-
// can happen due to trivial type mismatches.
5964+
5965+
// If coercing a fixed vector from a scalable vector for ABI
5966+
// compatibility, and the types match, use the llvm.vector.extract
5967+
// intrinsic to perform the conversion.
5968+
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
59425969
llvm::Value *V = CI;
5943-
if (V->getType() != RetIRTy)
5944-
V = Builder.CreateBitCast(V, RetIRTy);
5945-
return RValue::get(V);
5946-
}
5970+
if (auto *ScalableSrcTy =
5971+
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
5972+
if (FixedDstTy->getElementType() ==
5973+
ScalableSrcTy->getElementType()) {
5974+
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
5975+
V = Builder.CreateExtractVector(FixedDstTy, V, Zero,
5976+
"cast.fixed");
5977+
return RValue::get(V);
5978+
}
5979+
}
59475980
}
5948-
llvm_unreachable("bad evaluation kind");
5949-
}
59505981

5951-
// If coercing a fixed vector from a scalable vector for ABI
5952-
// compatibility, and the types match, use the llvm.vector.extract
5953-
// intrinsic to perform the conversion.
5954-
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
5955-
llvm::Value *V = CI;
5956-
if (auto *ScalableSrcTy =
5957-
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
5958-
if (FixedDstTy->getElementType() == ScalableSrcTy->getElementType()) {
5959-
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
5960-
V = Builder.CreateExtractVector(FixedDstTy, V, Zero, "cast.fixed");
5961-
return RValue::get(V);
5962-
}
5982+
Address DestPtr = ReturnValue.getValue();
5983+
bool DestIsVolatile = ReturnValue.isVolatile();
5984+
5985+
if (!DestPtr.isValid()) {
5986+
DestPtr = CreateMemTemp(RetTy, "coerce");
5987+
DestIsVolatile = false;
59635988
}
5964-
}
59655989

5966-
Address DestPtr = ReturnValue.getValue();
5967-
bool DestIsVolatile = ReturnValue.isVolatile();
5990+
// An empty record can overlap other data (if declared with
5991+
// no_unique_address); omit the store for such types - as there is no
5992+
// actual data to store.
5993+
if (!isEmptyRecord(getContext(), RetTy, true)) {
5994+
// If the value is offset in memory, apply the offset now.
5995+
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
5996+
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
5997+
}
59685998

5969-
if (!DestPtr.isValid()) {
5970-
DestPtr = CreateMemTemp(RetTy, "coerce");
5971-
DestIsVolatile = false;
5999+
return convertTempToRValue(DestPtr, RetTy, SourceLocation());
59726000
}
59736001

5974-
// An empty record can overlap other data (if declared with
5975-
// no_unique_address); omit the store for such types - as there is no
5976-
// actual data to store.
5977-
if (!isEmptyRecord(getContext(), RetTy, true)) {
5978-
// If the value is offset in memory, apply the offset now.
5979-
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
5980-
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
6002+
case ABIArgInfo::Expand:
6003+
case ABIArgInfo::IndirectAliased:
6004+
llvm_unreachable("Invalid ABI kind for return argument");
59816005
}
59826006

5983-
return convertTempToRValue(DestPtr, RetTy, SourceLocation());
5984-
}
5985-
5986-
case ABIArgInfo::Expand:
5987-
case ABIArgInfo::IndirectAliased:
5988-
llvm_unreachable("Invalid ABI kind for return argument");
5989-
}
5990-
5991-
llvm_unreachable("Unhandled ABIArgInfo::Kind");
5992-
} ();
6007+
llvm_unreachable("Unhandled ABIArgInfo::Kind");
6008+
}();
6009+
}
59936010

59946011
// Emit the assume_aligned check on the return value.
59956012
if (Ret.isScalar() && TargetDecl) {

clang/lib/CodeGen/CGPointerAuth.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,40 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD,
365365
return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType);
366366
}
367367

368+
CGPointerAuthInfo CodeGenModule::getMemberFunctionPointerAuthInfo(QualType FT) {
369+
assert(FT->getAs<MemberPointerType>() && "MemberPointerType expected");
370+
const auto &Schema = getCodeGenOpts().PointerAuth.CXXMemberFunctionPointers;
371+
if (!Schema)
372+
return CGPointerAuthInfo();
373+
374+
assert(!Schema.isAddressDiscriminated() &&
375+
"function pointers cannot use address-specific discrimination");
376+
377+
llvm::ConstantInt *Discriminator =
378+
getPointerAuthOtherDiscriminator(Schema, GlobalDecl(), FT);
379+
return CGPointerAuthInfo(Schema.getKey(), Schema.getAuthenticationMode(),
380+
/* IsIsaPointer */ false,
381+
/* AuthenticatesNullValues */ false, Discriminator);
382+
}
383+
384+
llvm::Constant *CodeGenModule::getMemberFunctionPointer(llvm::Constant *Pointer,
385+
QualType FT) {
386+
if (CGPointerAuthInfo PointerAuth = getMemberFunctionPointerAuthInfo(FT))
387+
return getConstantSignedPointer(
388+
Pointer, PointerAuth.getKey(), nullptr,
389+
cast_or_null<llvm::ConstantInt>(PointerAuth.getDiscriminator()));
390+
391+
return Pointer;
392+
}
393+
394+
llvm::Constant *CodeGenModule::getMemberFunctionPointer(const FunctionDecl *FD,
395+
llvm::Type *Ty) {
396+
QualType FT = FD->getType();
397+
FT = getContext().getMemberPointerType(
398+
FT, cast<CXXMethodDecl>(FD)->getParent()->getTypeForDecl());
399+
return getMemberFunctionPointer(getRawFunctionPointer(FD, Ty), FT);
400+
}
401+
368402
std::optional<PointerAuthQualifier>
369403
CodeGenModule::computeVTPointerAuthentication(const CXXRecordDecl *ThisClass) {
370404
auto DefaultAuthentication = getCodeGenOpts().PointerAuth.CXXVTablePointers;

clang/lib/CodeGen/CodeGenFunction.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -4374,7 +4374,8 @@ class CodeGenFunction : public CodeGenTypeCache {
43744374
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
43754375
ReturnValueSlot ReturnValue, const CallArgList &Args,
43764376
llvm::CallBase **callOrInvoke, bool IsMustTail,
4377-
SourceLocation Loc);
4377+
SourceLocation Loc,
4378+
bool IsVirtualFunctionPointerThunk = false);
43784379
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
43794380
ReturnValueSlot ReturnValue, const CallArgList &Args,
43804381
llvm::CallBase **callOrInvoke = nullptr,

clang/lib/CodeGen/CodeGenModule.h

+8
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,16 @@ class CodeGenModule : public CodeGenTypeCache {
973973
llvm::Constant *getFunctionPointer(llvm::Constant *Pointer,
974974
QualType FunctionType);
975975

976+
llvm::Constant *getMemberFunctionPointer(const FunctionDecl *FD,
977+
llvm::Type *Ty = nullptr);
978+
979+
llvm::Constant *getMemberFunctionPointer(llvm::Constant *Pointer,
980+
QualType FT);
981+
976982
CGPointerAuthInfo getFunctionPointerAuthInfo(QualType T);
977983

984+
CGPointerAuthInfo getMemberFunctionPointerAuthInfo(QualType FT);
985+
978986
CGPointerAuthInfo getPointerAuthInfoForPointeeType(QualType type);
979987

980988
CGPointerAuthInfo getPointerAuthInfoForType(QualType type);

0 commit comments

Comments
 (0)