diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 89ea7ef4dbe89..bdfa5f7741ad3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1062,6 +1062,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecRes_EXTRACT_SUBVECTOR(SDNode* N);
   SDValue WidenVecRes_INSERT_SUBVECTOR(SDNode *N);
   SDValue WidenVecRes_INSERT_VECTOR_ELT(SDNode* N);
+  SDValue WidenVecRes_ATOMIC_LOAD(AtomicSDNode *N);
   SDValue WidenVecRes_LOAD(SDNode* N);
   SDValue WidenVecRes_VP_LOAD(VPLoadSDNode *N);
   SDValue WidenVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 8eee7a4c61fe6..6b3467573a0a2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4625,6 +4625,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     break;
   case ISD::EXTRACT_SUBVECTOR: Res = WidenVecRes_EXTRACT_SUBVECTOR(N); break;
   case ISD::INSERT_VECTOR_ELT: Res = WidenVecRes_INSERT_VECTOR_ELT(N); break;
+  case ISD::ATOMIC_LOAD:
+    Res = WidenVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N));
+    break;
   case ISD::LOAD:              Res = WidenVecRes_LOAD(N); break;
   case ISD::STEP_VECTOR:
   case ISD::SPLAT_VECTOR:
@@ -6014,6 +6017,77 @@ SDValue DAGTypeLegalizer::WidenVecRes_INSERT_VECTOR_ELT(SDNode *N) {
                      N->getOperand(1), N->getOperand(2));
 }
 
+/// Either return the same load or provide appropriate casts
+/// from the load and return that.
+static SDValue loadElement(SDValue LdOp, EVT FirstVT, EVT WidenVT,
+                           TypeSize LdWidth, TypeSize FirstVTWidth, SDLoc dl,
+                           SelectionDAG &DAG) {
+  assert(TypeSize::isKnownLE(LdWidth, FirstVTWidth));
+  TypeSize WidenWidth = WidenVT.getSizeInBits();
+  if (!FirstVT.isVector()) {
+    unsigned NumElts =
+        WidenWidth.getFixedValue() / FirstVTWidth.getFixedValue();
+    EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), FirstVT, NumElts);
+    SDValue VecOp = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NewVecVT, LdOp);
+    return DAG.getNode(ISD::BITCAST, dl, WidenVT, VecOp);
+  } else if (FirstVT == WidenVT)
+    return LdOp;
+  else {
+    // TODO: We don't currently have any tests that exercise this code path.
+    assert(!"Unimplemented");
+  }
+}
+
+static std::optional<EVT> findMemType(SelectionDAG &DAG,
+                                      const TargetLowering &TLI, unsigned Width,
+                                      EVT WidenVT, unsigned Align,
+                                      unsigned WidenEx);
+
+SDValue DAGTypeLegalizer::WidenVecRes_ATOMIC_LOAD(AtomicSDNode *LD) {
+  EVT WidenVT =
+      TLI.getTypeToTransformTo(*DAG.getContext(), LD->getValueType(0));
+  EVT LdVT = LD->getMemoryVT();
+  SDLoc dl(LD);
+  assert(LdVT.isVector() && WidenVT.isVector() && "Expected vectors");
+  assert(LdVT.isScalableVector() == WidenVT.isScalableVector() &&
+         "Must be scalable");
+  assert(LdVT.getVectorElementType() == WidenVT.getVectorElementType() &&
+         "Expected equivalent element types");
+
+  // Load information
+  SDValue Chain = LD->getChain();
+  SDValue BasePtr = LD->getBasePtr();
+  MachineMemOperand::Flags MMOFlags = LD->getMemOperand()->getFlags();
+  AAMDNodes AAInfo = LD->getAAInfo();
+
+  TypeSize LdWidth = LdVT.getSizeInBits();
+  TypeSize WidenWidth = WidenVT.getSizeInBits();
+  TypeSize WidthDiff = WidenWidth - LdWidth;
+
+  // Find the vector type that can load from.
+  std::optional<EVT> FirstVT =
+      findMemType(DAG, TLI, LdWidth.getKnownMinValue(), WidenVT, /*LdAlign=*/0,
+                  WidthDiff.getKnownMinValue());
+
+  if (!FirstVT)
+    return SDValue();
+
+  SmallVector<EVT, 8> MemVTs;
+  TypeSize FirstVTWidth = FirstVT->getSizeInBits();
+
+  SDValue LdOp = DAG.getAtomicLoad(ISD::NON_EXTLOAD, dl, *FirstVT, *FirstVT,
+                                   Chain, BasePtr, LD->getMemOperand());
+
+  // Load the element with one instruction.
+  SDValue Result =
+      loadElement(LdOp, *FirstVT, WidenVT, LdWidth, FirstVTWidth, dl, DAG);
+
+  // Modified the chain - switch anything that used the old chain to use
+  // the new one.
+  ReplaceValueWith(SDValue(LD, 1), LdOp.getValue(1));
+  return Result;
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_LOAD(SDNode *N) {
   LoadSDNode *LD = cast<LoadSDNode>(N);
   ISD::LoadExtType ExtType = LD->getExtensionType();
@@ -7897,27 +7971,7 @@ SDValue DAGTypeLegalizer::GenWidenVectorLoads(SmallVectorImpl<SDValue> &LdChain,
 
   // Check if we can load the element with one instruction.
   if (MemVTs.empty()) {
-    assert(TypeSize::isKnownLE(LdWidth, FirstVTWidth));
-    if (!FirstVT->isVector()) {
-      unsigned NumElts =
-          WidenWidth.getFixedValue() / FirstVTWidth.getFixedValue();
-      EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), *FirstVT, NumElts);
-      SDValue VecOp = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, NewVecVT, LdOp);
-      return DAG.getNode(ISD::BITCAST, dl, WidenVT, VecOp);
-    }
-    if (FirstVT == WidenVT)
-      return LdOp;
-
-    // TODO: We don't currently have any tests that exercise this code path.
-    assert(WidenWidth.getFixedValue() % FirstVTWidth.getFixedValue() == 0);
-    unsigned NumConcat =
-        WidenWidth.getFixedValue() / FirstVTWidth.getFixedValue();
-    SmallVector<SDValue, 16> ConcatOps(NumConcat);
-    SDValue UndefVal = DAG.getUNDEF(*FirstVT);
-    ConcatOps[0] = LdOp;
-    for (unsigned i = 1; i != NumConcat; ++i)
-      ConcatOps[i] = UndefVal;
-    return DAG.getNode(ISD::CONCAT_VECTORS, dl, WidenVT, ConcatOps);
+    return loadElement(LdOp, *FirstVT, WidenVT, LdWidth, FirstVTWidth, dl, DAG);
   }
 
   // Load vector by using multiple loads from largest vector to scalar.
diff --git a/llvm/test/CodeGen/X86/atomic-load-store.ll b/llvm/test/CodeGen/X86/atomic-load-store.ll
index 39e9fdfa5e62b..9ee8b4fc5ac7f 100644
--- a/llvm/test/CodeGen/X86/atomic-load-store.ll
+++ b/llvm/test/CodeGen/X86/atomic-load-store.ll
@@ -146,6 +146,64 @@ define <1 x i64> @atomic_vec1_i64_align(ptr %x) nounwind {
   ret <1 x i64> %ret
 }
 
+define <2 x i8> @atomic_vec2_i8(ptr %x) {
+; CHECK3-LABEL: atomic_vec2_i8:
+; CHECK3:       ## %bb.0:
+; CHECK3-NEXT:    movzwl (%rdi), %eax
+; CHECK3-NEXT:    movd %eax, %xmm0
+; CHECK3-NEXT:    retq
+;
+; CHECK0-LABEL: atomic_vec2_i8:
+; CHECK0:       ## %bb.0:
+; CHECK0-NEXT:    movw (%rdi), %cx
+; CHECK0-NEXT:    ## implicit-def: $eax
+; CHECK0-NEXT:    movw %cx, %ax
+; CHECK0-NEXT:    movd %eax, %xmm0
+; CHECK0-NEXT:    retq
+  %ret = load atomic <2 x i8>, ptr %x acquire, align 4
+  ret <2 x i8> %ret
+}
+
+define <2 x i16> @atomic_vec2_i16(ptr %x) {
+; CHECK-LABEL: atomic_vec2_i16:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movl (%rdi), %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x i16>, ptr %x acquire, align 4
+  ret <2 x i16> %ret
+}
+
+define <2 x ptr addrspace(270)> @atomic_vec2_ptr270(ptr %x) {
+; CHECK-LABEL: atomic_vec2_ptr270:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movq (%rdi), %rax
+; CHECK-NEXT:    movq %rax, %xmm0
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x ptr addrspace(270)>, ptr %x acquire, align 8
+  ret <2 x ptr addrspace(270)> %ret
+}
+
+define <2 x i32> @atomic_vec2_i32_align(ptr %x) {
+; CHECK-LABEL: atomic_vec2_i32_align:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movq (%rdi), %rax
+; CHECK-NEXT:    movq %rax, %xmm0
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x i32>, ptr %x acquire, align 8
+  ret <2 x i32> %ret
+}
+
+define <2 x float> @atomic_vec2_float_align(ptr %x) {
+; CHECK-LABEL: atomic_vec2_float_align:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movq (%rdi), %rax
+; CHECK-NEXT:    movq %rax, %xmm0
+; CHECK-NEXT:    retq
+  %ret = load atomic <2 x float>, ptr %x acquire, align 8
+  ret <2 x float> %ret
+}
+
 define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
 ; CHECK3-LABEL: atomic_vec1_ptr:
 ; CHECK3:       ## %bb.0:
@@ -295,6 +353,26 @@ define <2 x i32> @atomic_vec2_i32(ptr %x) nounwind {
   ret <2 x i32> %ret
 }
 
+define <4 x i8> @atomic_vec4_i8(ptr %x) nounwind {
+; CHECK-LABEL: atomic_vec4_i8:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movl (%rdi), %eax
+; CHECK-NEXT:    movd %eax, %xmm0
+; CHECK-NEXT:    retq
+  %ret = load atomic <4 x i8>, ptr %x acquire, align 4
+  ret <4 x i8> %ret
+}
+
+define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
+; CHECK-LABEL: atomic_vec4_i16:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    movq (%rdi), %rax
+; CHECK-NEXT:    movq %rax, %xmm0
+; CHECK-NEXT:    retq
+  %ret = load atomic <4 x i16>, ptr %x acquire, align 8
+  ret <4 x i16> %ret
+}
+
 define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
 ; CHECK-LABEL: atomic_vec4_float_align:
 ; CHECK:       ## %bb.0: