Skip to content

[SPIR-V] Support SPV_INTEL_int4 extension #141031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

vmaksimo
Copy link
Contributor

@vmaksimo vmaksimo marked this pull request as ready for review May 22, 2025 11:21
@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Viktoria Maximova (vmaksimo)

Changes

Adds support for native 4-bit type.

Spec:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_int4.asciidoc


Full diff: https://github.com/llvm/llvm-project/pull/141031.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+17-4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll (+20)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll (+29)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll (+25)
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e6cb8cee66a60..fbaca4e0e4d81 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -99,7 +99,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_ternary_bitwise_function",
          SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function},
         {"SPV_INTEL_2d_block_io",
-         SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}};
+         SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
+        {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index ac397fc486e19..d9fcb5623b220 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -154,7 +154,8 @@ unsigned SPIRVGlobalRegistry::adjustOpTypeIntWidth(unsigned Width) const {
     report_fatal_error("Unsupported integer width!");
   const SPIRVSubtarget &ST = cast<SPIRVSubtarget>(CurMF->getSubtarget());
   if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers))
+          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
+      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4))
     return Width;
   if (Width <= 8)
     Width = 8;
@@ -174,9 +175,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
   const SPIRVSubtarget &ST =
       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-    if ((!isPowerOf2_32(Width) || Width < 8) &&
-        ST.canUseExtension(
-            SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+    if (Width == 4 && ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+      MIRBuilder.buildInstr(SPIRV::OpExtension)
+          .addImm(SPIRV::Extension::SPV_INTEL_int4);
+      MIRBuilder.buildInstr(SPIRV::OpCapability)
+          .addImm(SPIRV::Capability::Int4TypeINTEL);
+    } else if ((!isPowerOf2_32(Width) || Width < 8) &&
+               ST.canUseExtension(
+                   SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
       MIRBuilder.buildInstr(SPIRV::OpExtension)
           .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
       MIRBuilder.buildInstr(SPIRV::OpCapability)
@@ -1563,6 +1569,13 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
   const MachineInstr *NewMI =
       createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
+        const Type *ET = getTypeForSPIRVType(ElemType);
+        if (ET->isIntegerTy() && ET->getIntegerBitWidth() == 4 &&
+            cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget())
+                .canUseExtension(SPIRV::Extension::SPV_INTEL_int4)) {
+          MIRBuilder.buildInstr(SPIRV::OpCapability)
+              .addImm(SPIRV::Capability::Int4CooperativeMatrixINTEL);
+        }
         return MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
             .addDef(createTypeVReg(MIRBuilder))
             .addUse(getSPIRVTypeID(ElemType))
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 578e82881f6e8..29ec90d2ae8df 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -128,7 +128,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
   bool IsExtendedInts =
       ST.canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
-      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+      ST.canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+      ST.canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
   auto extendedScalarsAndVectors =
       [IsExtendedInts](const LegalityQuery &Query) {
         const LLT Ty = Query.Types[0];
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index b6a2da6e2045d..2d6d0e37c8a1c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -492,7 +492,8 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
   bool IsExtendedInts =
       ST->canUseExtension(
           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers) ||
-      ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions);
+      ST->canUseExtension(SPIRV::Extension::SPV_KHR_bit_instructions) ||
+      ST->canUseExtension(SPIRV::Extension::SPV_INTEL_int4);
 
   for (MachineBasicBlock *MBB : post_order(&MF)) {
     if (MBB->empty())
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 851495bda4979..51a59e441b5b4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -316,6 +316,7 @@ defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
 defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
 defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
 defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
+defm SPV_INTEL_int4 : ExtensionOperand<123>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -521,6 +522,8 @@ defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_
 defm Subgroup2DBlockIOINTEL : CapabilityOperand<6228, 0, 0, [SPV_INTEL_2d_block_io], []>;
 defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
 defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
+defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
+defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
new file mode 100644
index 0000000000000..02f023276bf5d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/cooperative_matrix.ll
@@ -0,0 +1,20 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4,+SPV_KHR_cooperative_matrix %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: Capability Int4TypeINTEL
+; CHECK-DAG: Capability CooperativeMatrixKHR
+; CHECK-DAG: Extension "SPV_INTEL_int4"
+; CHECK-DAG: Capability Int4CooperativeMatrixINTEL
+; CHECK-DAG: Extension "SPV_KHR_cooperative_matrix"
+
+; CHECK: %[[#Int4Ty:]] = OpTypeInt 4 0
+; CHECK: %[[#CoopMatTy:]] = OpTypeCooperativeMatrixKHR %[[#Int4Ty]]
+; CHECK: CompositeConstruct %[[#CoopMatTy]]
+
+define spir_kernel void @foo() {
+entry:
+  %call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef 0)
+  ret void
+}
+
+declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i4, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i4 noundef)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
new file mode 100644
index 0000000000000..17202ab243e8f
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/negative.ll
@@ -0,0 +1,29 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_arbitrary_precision_integers %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-4
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-INT-8
+; No error would be reported in comparison to Khronos llvm-spirv, because type adjustments to integer size are made 
+; in case no appropriate extension is enabled. Here we expect that the type is adjusted to 8 bits.
+
+; CHECK-SPIRV: Capability ArbitraryPrecisionIntegersINTEL
+; CHECK-SPIRV: Extension "SPV_INTEL_arbitrary_precision_integers"
+; CHECK-INT-4: %[[#Int4:]] = OpTypeInt 4 0
+; CHECK-INT-8: %[[#Int4:]] = OpTypeInt 8 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]]  1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+  %0 = alloca i4
+  store i4 1, ptr %0
+  %1 = load i4, ptr %0
+  call spir_func void @boo(i4 %1)
+  ret void
+}
+
+declare spir_func void @boo(i4)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
new file mode 100644
index 0000000000000..f1bee0b963613
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_int4/trivial.ll
@@ -0,0 +1,25 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - | FileCheck %s
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_int4 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: Capability Int4TypeINTEL
+; CHECK: Extension "SPV_INTEL_int4"
+; CHECK: %[[#Int4:]] = OpTypeInt  4 0
+; CHECK: OpTypeFunction %[[#]] %[[#Int4]]
+; CHECK: %[[#Int4PtrTy:]] = OpTypePointer Function %[[#Int4]]
+; CHECK: %[[#Const:]] = OpConstant %[[#Int4]]  1
+
+; CHECK: %[[#Int4Ptr:]] = OpVariable %[[#Int4PtrTy]] Function
+; CHECK: OpStore %[[#Int4Ptr]] %[[#Const]]
+; CHECK: %[[#Load:]] = OpLoad %[[#Int4]] %[[#Int4Ptr]]
+; CHECK: OpFunctionCall %[[#]] %[[#]] %[[#Load]]
+
+define spir_kernel void @foo() {
+entry:
+  %0 = alloca i4
+  store i4 1, ptr %0
+  %1 = load i4, ptr %0
+  call spir_func void @boo(i4 %1)
+  ret void
+}
+
+declare spir_func void @boo(i4)

@vmaksimo
Copy link
Contributor Author

@MrSidims @michalpaszkowski please take a look, thanks!

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also extend llvm/docs/SPIRVUsage.rst with a description of the new extension

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants