@@ -164,17 +164,22 @@ spv_result_t NumConsumedLocations(ValidationState_t& _, const Instruction* type,
164
164
}
165
165
break ;
166
166
case spv::Op::OpTypeMatrix:
167
- // Matrices consume locations equal to the underlying vector type for
168
- // each column.
169
- NumConsumedLocations (_, _.FindDef (type->GetOperandAs <uint32_t >(1 )),
170
- num_locations);
167
+ // Matrices consume locations equivalent to arrays of 4-component vectors.
168
+ if (_.ContainsSizedIntOrFloatType (type->id (), spv::Op::OpTypeInt, 64 ) ||
169
+ _.ContainsSizedIntOrFloatType (type->id (), spv::Op::OpTypeFloat, 64 )) {
170
+ *num_locations = 2 ;
171
+ } else {
172
+ *num_locations = 1 ;
173
+ }
171
174
*num_locations *= type->GetOperandAs <uint32_t >(2 );
172
175
break ;
173
176
case spv::Op::OpTypeArray: {
174
177
// Arrays consume locations equal to the underlying type times the number
175
178
// of elements in the array.
176
- NumConsumedLocations (_, _.FindDef (type->GetOperandAs <uint32_t >(1 )),
177
- num_locations);
179
+ if (auto error = NumConsumedLocations (
180
+ _, _.FindDef (type->GetOperandAs <uint32_t >(1 )), num_locations)) {
181
+ return error;
182
+ }
178
183
bool is_int = false ;
179
184
bool is_const = false ;
180
185
uint32_t value = 0 ;
@@ -244,10 +249,31 @@ uint32_t NumConsumedComponents(ValidationState_t& _, const Instruction* type) {
244
249
NumConsumedComponents (_, _.FindDef (type->GetOperandAs <uint32_t >(1 )));
245
250
num_components *= type->GetOperandAs <uint32_t >(2 );
246
251
break ;
247
- case spv::Op::OpTypeArray:
248
- // Skip the array.
249
- return NumConsumedComponents (_,
250
- _.FindDef (type->GetOperandAs <uint32_t >(1 )));
252
+ case spv::Op::OpTypeMatrix:
253
+ // Matrices consume all components of the location.
254
+ // Round up to next multiple of 4.
255
+ num_components =
256
+ NumConsumedComponents (_, _.FindDef (type->GetOperandAs <uint32_t >(1 )));
257
+ num_components *= type->GetOperandAs <uint32_t >(2 );
258
+ num_components = ((num_components + 3 ) / 4 ) * 4 ;
259
+ break ;
260
+ case spv::Op::OpTypeArray: {
261
+ // Arrays consume all components of the location.
262
+ // Round up to next multiple of 4.
263
+ num_components =
264
+ NumConsumedComponents (_, _.FindDef (type->GetOperandAs <uint32_t >(1 )));
265
+
266
+ bool is_int = false ;
267
+ bool is_const = false ;
268
+ uint32_t value = 0 ;
269
+ // Attempt to evaluate the number of array elements.
270
+ std::tie (is_int, is_const, value) =
271
+ _.EvalInt32IfConst (type->GetOperandAs <uint32_t >(2 ));
272
+ if (is_int && is_const) num_components *= value;
273
+
274
+ num_components = ((num_components + 3 ) / 4 ) * 4 ;
275
+ return num_components;
276
+ }
251
277
case spv::Op::OpTypePointer:
252
278
if (_.addressing_model () ==
253
279
spv::AddressingModel::PhysicalStorageBuffer64 &&
@@ -330,9 +356,10 @@ spv_result_t GetLocationsForVariable(
330
356
}
331
357
}
332
358
333
- // Vulkan 14.1.3: Tessellation control and mesh per-vertex outputs and
334
- // tessellation control, evaluation and geometry per-vertex inputs have a
335
- // layer of arraying that is not included in interface matching.
359
+ // Vulkan 15.1.3 (Interface Matching): Tessellation control and mesh
360
+ // per-vertex outputs and tessellation control, evaluation and geometry
361
+ // per-vertex inputs have a layer of arraying that is not included in
362
+ // interface matching.
336
363
bool is_arrayed = false ;
337
364
switch (entry_point->GetOperandAs <spv::ExecutionModel>(0 )) {
338
365
case spv::ExecutionModel::TessellationControl:
@@ -386,51 +413,33 @@ spv_result_t GetLocationsForVariable(
386
413
387
414
const std::string storage_class = is_output ? " output" : " input" ;
388
415
if (has_location) {
389
- auto sub_type = type;
390
- bool is_int = false ;
391
- bool is_const = false ;
392
- uint32_t array_size = 1 ;
393
- // If the variable is still arrayed, mark the locations/components per
394
- // index.
395
- if (type->opcode () == spv::Op::OpTypeArray) {
396
- // Determine the array size if possible and get the element type.
397
- std::tie (is_int, is_const, array_size) =
398
- _.EvalInt32IfConst (type->GetOperandAs <uint32_t >(2 ));
399
- if (!is_int || !is_const) array_size = 1 ;
400
- auto sub_type_id = type->GetOperandAs <uint32_t >(1 );
401
- sub_type = _.FindDef (sub_type_id);
402
- }
403
-
404
416
uint32_t num_locations = 0 ;
405
- if (auto error = NumConsumedLocations (_, sub_type , &num_locations))
417
+ if (auto error = NumConsumedLocations (_, type , &num_locations))
406
418
return error;
407
- uint32_t num_components = NumConsumedComponents (_, sub_type );
419
+ uint32_t num_components = NumConsumedComponents (_, type );
408
420
409
- for (uint32_t array_idx = 0 ; array_idx < array_size; ++array_idx) {
410
- uint32_t array_location = location + (num_locations * array_idx);
411
- uint32_t start = array_location * 4 ;
412
- if (kMaxLocations <= start) {
413
- // Too many locations, give up.
414
- break ;
415
- }
421
+ uint32_t start = location * 4 ;
422
+ uint32_t end = (location + num_locations) * 4 ;
423
+ if (num_components % 4 != 0 ) {
424
+ start += component;
425
+ end = start + num_components;
426
+ }
416
427
417
- uint32_t end = (array_location + num_locations) * 4 ;
418
- if (num_components != 0 ) {
419
- start += component;
420
- end = array_location * 4 + component + num_components;
421
- }
428
+ if (kMaxLocations <= start) {
429
+ // Too many locations, give up.
430
+ return SPV_SUCCESS;
431
+ }
422
432
423
- auto locs = locations;
424
- if (has_index && index == 1 ) locs = output_index1_locations;
433
+ auto locs = locations;
434
+ if (has_index && index == 1 ) locs = output_index1_locations;
425
435
426
- for (uint32_t i = start; i < end; ++i) {
427
- if (!locs->insert (i).second ) {
428
- return _.diag (SPV_ERROR_INVALID_DATA, entry_point)
429
- << (is_output ? _.VkErrorID (8722 ) : _.VkErrorID (8721 ))
430
- << " Entry-point has conflicting " << storage_class
431
- << " location assignment at location " << i / 4
432
- << " , component " << i % 4 ;
433
- }
436
+ for (uint32_t i = start; i < end; ++i) {
437
+ if (!locs->insert (i).second ) {
438
+ return _.diag (SPV_ERROR_INVALID_DATA, entry_point)
439
+ << (is_output ? _.VkErrorID (8722 ) : _.VkErrorID (8721 ))
440
+ << " Entry-point has conflicting " << storage_class
441
+ << " location assignment at location " << i / 4 << " , component "
442
+ << i % 4 ;
434
443
}
435
444
}
436
445
} else {
@@ -489,38 +498,19 @@ spv_result_t GetLocationsForVariable(
489
498
continue ;
490
499
}
491
500
492
- if (member->opcode () == spv::Op::OpTypeArray && num_components >= 1 &&
493
- num_components < 4 ) {
494
- // When an array has an element that takes less than a location in
495
- // size, calculate the used locations in a strided manner.
496
- for (uint32_t l = location; l < num_locations + location; ++l) {
497
- for (uint32_t c = component; c < component + num_components; ++c) {
498
- uint32_t check = 4 * l + c;
499
- if (!locations->insert (check).second ) {
500
- return _.diag (SPV_ERROR_INVALID_DATA, entry_point)
501
- << (is_output ? _.VkErrorID (8722 ) : _.VkErrorID (8721 ))
502
- << " Entry-point has conflicting " << storage_class
503
- << " location assignment at location " << l
504
- << " , component " << c;
505
- }
506
- }
507
- }
508
- } else {
509
- // TODO: There is a hole here is the member is an array of 3- or
510
- // 4-element vectors of 64-bit types.
511
- uint32_t end = (location + num_locations) * 4 ;
512
- if (num_components != 0 ) {
513
- start += component;
514
- end = location * 4 + component + num_components;
515
- }
516
- for (uint32_t l = start; l < end; ++l) {
517
- if (!locations->insert (l).second ) {
518
- return _.diag (SPV_ERROR_INVALID_DATA, entry_point)
519
- << (is_output ? _.VkErrorID (8722 ) : _.VkErrorID (8721 ))
520
- << " Entry-point has conflicting " << storage_class
521
- << " location assignment at location " << l / 4
522
- << " , component " << l % 4 ;
523
- }
501
+ uint32_t end = (location + num_locations) * 4 ;
502
+ if (num_components % 4 != 0 ) {
503
+ start += component;
504
+ end = location * 4 + component + num_components;
505
+ }
506
+
507
+ for (uint32_t l = start; l < end; ++l) {
508
+ if (!locations->insert (l).second ) {
509
+ return _.diag (SPV_ERROR_INVALID_DATA, entry_point)
510
+ << (is_output ? _.VkErrorID (8722 ) : _.VkErrorID (8721 ))
511
+ << " Entry-point has conflicting " << storage_class
512
+ << " location assignment at location " << l / 4
513
+ << " , component " << l % 4 ;
524
514
}
525
515
}
526
516
}
0 commit comments