@@ -1319,26 +1319,35 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1319
1319
if (EQClasses.size () < 2 )
1320
1320
return ;
1321
1321
1322
+ auto CopyMetaDataFromTo = [&](Instruction *Src, Instruction *Dst) {
1323
+ SmallVector<std::pair<unsigned , MDNode *>, 4 > MD;
1324
+ Src->getAllMetadata (MD);
1325
+ for (const auto [ID, Node] : MD) {
1326
+ Dst->setMetadata (ID, Node);
1327
+ }
1328
+ };
1329
+
1322
1330
// For each class, determine if all instructions are of type int, FP or ptr.
1323
1331
// This information will help us determine the type instructions should be
1324
1332
// casted into.
1325
1333
MapVector<EqClassKey, Bitset<3 >> ClassAllTy;
1326
- for (auto C : EQClasses) {
1327
- if (all_of (EQClasses[C.first ],
1328
- [](Instruction *I) {
1329
- return I->getType ()->isIntOrIntVectorTy ();
1330
- }))
1331
- ClassAllTy[C.first ].set (0 );
1332
- else if (all_of (EQClasses[C.first ],
1333
- [](Instruction *I) {
1334
- return I->getType ()->isFPOrFPVectorTy ();
1335
- }))
1336
- ClassAllTy[C.first ].set (1 );
1337
- else if (all_of (EQClasses[C.first ],
1338
- [](Instruction *I) {
1339
- return I->getType ()->isPtrOrPtrVectorTy ();
1340
- }))
1341
- ClassAllTy[C.first ].set (2 );
1334
+ for (const auto &C : EQClasses) {
1335
+ auto CommonTypeKind = [](Instruction *I) {
1336
+ if (I->getType ()->isIntOrIntVectorTy ())
1337
+ return 0 ;
1338
+ if (I->getType ()->isFPOrFPVectorTy ())
1339
+ return 1 ;
1340
+ if (I->getType ()->isPtrOrPtrVectorTy ())
1341
+ return 2 ;
1342
+ return -1 ; // Invalid type kind
1343
+ };
1344
+
1345
+ int FirstTypeKind = CommonTypeKind (EQClasses[C.first ][0 ]);
1346
+ if (FirstTypeKind != -1 && all_of (EQClasses[C.first ], [&](Instruction *I) {
1347
+ return CommonTypeKind (I) == FirstTypeKind;
1348
+ })) {
1349
+ ClassAllTy[C.first ].set (FirstTypeKind);
1350
+ }
1342
1351
}
1343
1352
1344
1353
// Loop over all equivalence classes and try to merge them. Keep track of
@@ -1362,6 +1371,11 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1362
1371
if (Ptr1 != Ptr2 || AS1 != AS2 || IsLoad1 != IsLoad2 || TySize1 < TySize2)
1363
1372
continue ;
1364
1373
1374
+ // An All-FP class should only be merged into another All-FP class.
1375
+ if ((ClassAllTy[EC1.first ].test (1 ) && !ClassAllTy[EC2.first ].test (1 )) ||
1376
+ (!ClassAllTy[EC1.first ].test (2 ) && ClassAllTy[EC2.first ].test (2 )))
1377
+ continue ;
1378
+
1365
1379
// Ensure all instructions in EC2 can be bitcasted into NewTy.
1366
1380
// / TODO: NewTyBits is needed as stuctured binded variables cannot be
1367
1381
// / captured by a lambda until C++20.
@@ -1381,13 +1395,14 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1381
1395
NewTy = Type::getFloatTy (Ctx);
1382
1396
else if (NewTyBits == 64 )
1383
1397
NewTy = Type::getDoubleTy (Ctx);
1384
- } else if (ClassAllTy[EC1.first ].test (2 ) && ClassAllTy[EC2.first ].test (2 )) {
1398
+ } else if (ClassAllTy[EC1.first ].test (2 ) &&
1399
+ ClassAllTy[EC2.first ].test (2 )) {
1385
1400
NewTy = PointerType::get (Ctx, AS2);
1386
1401
}
1387
1402
1388
1403
for (auto *Inst : EC2.second ) {
1389
- auto *Ptr = getLoadStorePointerOperand (Inst);
1390
- auto *OrigTy = Inst->getType ();
1404
+ Value *Ptr = getLoadStorePointerOperand (Inst);
1405
+ Type *OrigTy = Inst->getType ();
1391
1406
if (OrigTy == NewTy)
1392
1407
continue ;
1393
1408
if (auto *LI = dyn_cast<LoadInst>(Inst)) {
@@ -1406,6 +1421,7 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1406
1421
SI->getValueOperand ()->getName () + " .cast" );
1407
1422
auto *NewStore = Builder.CreateStore (
1408
1423
Cast, getLoadStorePointerOperand (SI), SI->isVolatile ());
1424
+ CopyMetaDataFromTo (SI, NewStore);
1409
1425
SI->eraseFromParent ();
1410
1426
EQClasses[EC1.first ].emplace_back (NewStore);
1411
1427
}
@@ -1415,7 +1431,7 @@ void Vectorizer::insertCastsToMergeClasses(EquivalenceClassMap &EQClasses) {
1415
1431
// basic block. This is important to ensure that the instructions are
1416
1432
// vectorized in the correct order.
1417
1433
std::sort (EQClasses[EC1.first ].begin (), EQClasses[EC1.first ].end (),
1418
- [](Instruction *A, Instruction *B) {
1434
+ [](const Instruction *A, const Instruction *B) {
1419
1435
return A && B && A->comesBefore (B);
1420
1436
});
1421
1437
ClassesToErase.insert (EC2.first );
0 commit comments