@@ -103,16 +103,48 @@ void amx_stz(bool pair, unsigned int z_row, const void * ptr)
103
103
AMX_STZ (oprand );
104
104
}
105
105
106
- void amx_fma32 (bool vector , unsigned int x_offset , unsigned int y_offset , int z_row )
106
/*
 * Issue an AMX FMA16 instruction with explicit X/Y lane masking.
 *
 * vector    - when true, sets bit 63 to select vector (outer-product off) mode.
 * x_offset  - byte offset into the X register pool (9 bits, placed at bits 10..18).
 * y_offset  - byte offset into the Y register pool (9 bits, placed at bits 0..8).
 * z_row     - destination Z row (6 bits, placed at bits 20..25).
 * x_mode/x_mask, y_mode/y_mask - per-operand mask mode (2 bits) and mask value
 *   (5 bits), packed at bits 46/41 (X) and 37/32 (Y).
 *
 * NOTE(review): field placement assumed to follow the reverse-engineered Apple
 * AMX FMA16 operand encoding — confirm against the AMX_FMA16 macro's ISA docs.
 */
void amx_fma16_masked (bool vector , unsigned int x_offset , unsigned int y_offset , int z_row , uint8_t x_mode , uint8_t x_mask , uint8_t y_mode , uint8_t y_mask )
{
    uint64_t op = vector ? (1ULL << 63) : 0;

    /* Pack all bit-fields in one expression, high fields first. */
    op |= (((uint64_t )x_mode   & 0x3  ) << 46)
        | (((uint64_t )x_mask   & 0x1F ) << 41)
        | (((uint64_t )y_mode   & 0x3  ) << 37)
        | (((uint64_t )y_mask   & 0x1F ) << 32)
        | (((uint64_t )z_row    & 0x3F ) << 20)
        | (((uint64_t )x_offset & 0x1FF) << 10)
        |  ((uint64_t )y_offset & 0x1FF);

    AMX_FMA16 (op );
}
122
+
123
/*
 * Unmasked AMX FMA16: convenience wrapper that forwards to
 * amx_fma16_masked with all mask modes and mask values zeroed
 * (i.e. no lanes masked out).
 */
void amx_fma16 (bool vector , unsigned int x_offset , unsigned int y_offset , int z_row )
{
    amx_fma16_masked (vector , x_offset , y_offset , z_row ,
                      /*x_mode=*/0 , /*x_mask=*/0 ,
                      /*y_mode=*/0 , /*y_mask=*/0 );
}
127
+
128
/*
 * Issue an AMX FMA32 instruction with explicit X/Y lane masking.
 *
 * vector    - when true, sets bit 63 to select vector (outer-product off) mode.
 * x_offset  - byte offset into the X register pool (9 bits, placed at bits 10..18).
 * y_offset  - byte offset into the Y register pool (9 bits, placed at bits 0..8).
 * z_row     - destination Z row (6 bits, placed at bits 20..25).
 * x_mode/x_mask, y_mode/y_mask - per-operand mask mode (2 bits) and mask value
 *   (5 bits), packed at bits 46/41 (X) and 37/32 (Y).
 *
 * NOTE(review): field layout mirrors amx_fma16_masked's FMA16 encoding —
 * confirm against the AMX_FMA32 macro's ISA docs.
 */
void amx_fma32_masked (bool vector , unsigned int x_offset , unsigned int y_offset , int z_row , uint8_t x_mode , uint8_t x_mask , uint8_t y_mode , uint8_t y_mask )
{
    uint64_t op = vector ? (1ULL << 63) : 0;

    /* Pack all bit-fields in one expression, high fields first. */
    op |= (((uint64_t )x_mode   & 0x3  ) << 46)
        | (((uint64_t )x_mask   & 0x1F ) << 41)
        | (((uint64_t )y_mode   & 0x3  ) << 37)
        | (((uint64_t )y_mask   & 0x1F ) << 32)
        | (((uint64_t )z_row    & 0x3F ) << 20)
        | (((uint64_t )x_offset & 0x1FF) << 10)
        |  ((uint64_t )y_offset & 0x1FF);

    AMX_FMA32 (op );
}
117
144
145
/*
 * Unmasked AMX FMA32: convenience wrapper that forwards to
 * amx_fma32_masked with all mask modes and mask values zeroed
 * (i.e. no lanes masked out).
 */
void amx_fma32 (bool vector , unsigned int x_offset , unsigned int y_offset , int z_row )
{
    amx_fma32_masked (vector , x_offset , y_offset , z_row ,
                      /*x_mode=*/0 , /*x_mask=*/0 ,
                      /*y_mode=*/0 , /*y_mask=*/0 );
}
149
+
118
150
#endif // AMX_USABILITY_H
0 commit comments