@@ -712,25 +712,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
712
712
713
713
// memsetRemainPattern: fills the remaining byte positions of a repeating
// pattern, one byte position at a time.
//
// Post-change version (`+` lines): for every byte index `step` in
// [StartOffset, PatternSize) it reads byte `step` of pPattern and calls
// hipMemset2DAsync on Stream with destination Ptr + step, pitch PatternSize,
// width 1 and Height = Size / PatternSize; each result is passed to
// UR_CHECK_ERROR. Byte positions below StartOffset are not written here.
//
// Pre-change version (`-` lines): identical loop shape, but the start index
// was hard-coded to 4 and there was no StartOffset parameter.
//
// NOTE(review): Height comes from an integer division, so this presumably
// relies on Size being a non-zero multiple of PatternSize (and PatternSize
// being non-zero) - confirm against the callers.
static inline void memsetRemainPattern (hipStream_t Stream, uint32_t PatternSize,
714
714
size_t Size, const void *pPattern,
715
- hipDeviceptr_t Ptr) {
715
+ hipDeviceptr_t Ptr,
716
+ uint32_t StartOffset) {
717
+ // Calculate the number of times the pattern needs to be applied
718
+ auto Height = Size / PatternSize;
716
719
717
- // Calculate the number of patterns, stride and the number of times the
718
- // pattern needs to be applied.
719
- auto NumberOfSteps = PatternSize / sizeof (uint8_t );
720
- auto Pitch = NumberOfSteps * sizeof (uint8_t );
721
- auto Height = Size / NumberOfSteps;
722
-
723
- for (auto step = 4u ; step < NumberOfSteps; ++step) {
720
+ for (auto step = StartOffset; step < PatternSize; ++step) {
724
721
// take 1 byte of the pattern
725
722
auto Value = *(static_cast <const uint8_t *>(pPattern) + step);
726
723
727
724
// offset the pointer to the part of the buffer we want to write to
728
- auto OffsetPtr = reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(Ptr) +
729
- (step * sizeof ( uint8_t )) );
725
+ auto OffsetPtr =
726
+ reinterpret_cast < void *>( reinterpret_cast < uint8_t *>(Ptr) + step );
730
727
731
728
// set all of the pattern chunks
732
- UR_CHECK_ERROR (hipMemset2DAsync (OffsetPtr, Pitch, Value, sizeof ( uint8_t ),
733
- Height, Stream));
729
+ UR_CHECK_ERROR (
730
+ hipMemset2DAsync (OffsetPtr, PatternSize, Value, 1u , Height, Stream));
734
731
}
735
732
}
736
733
@@ -743,11 +740,55 @@ static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
743
740
ur_result_t commonMemSetLargePattern (hipStream_t Stream, uint32_t PatternSize,
744
741
size_t Size, const void *pPattern,
745
742
hipDeviceptr_t Ptr) {
743
+ // Find the largest supported word size into which the pattern can be divided
744
+ auto BackendWordSize = PatternSize % 4u == 0u ? 4u
745
+ : PatternSize % 2u == 0u ? 2u
746
+ : 1u ;
747
+
748
+ // Calculate the number of patterns
749
+ auto NumberOfSteps = PatternSize / BackendWordSize;
750
+
751
+ // If the pattern is 1 word or the first word is repeated throughout, a fast
752
+ // continuous fill can be used without the need for slower strided fills
753
+ bool UseOnlyFirstValue{true };
754
+ auto checkIfFirstWordRepeats = [&UseOnlyFirstValue,
755
+ NumberOfSteps](const auto *pPatternWords) {
756
+ for (auto Step{1u }; (Step < NumberOfSteps) && UseOnlyFirstValue; ++Step) {
757
+ if (*(pPatternWords + Step) != *pPatternWords) {
758
+ UseOnlyFirstValue = false ;
759
+ }
760
+ }
761
+ };
746
762
747
- // Get 4-byte chunk of the pattern and call hipMemsetD32Async
748
- auto Count32 = Size / sizeof (uint32_t );
749
- auto Value = *(static_cast <const uint32_t *>(pPattern));
750
- UR_CHECK_ERROR (hipMemsetD32Async (Ptr, Value, Count32, Stream));
763
+ // Use a continuous fill for the first word in the pattern because it's faster
764
+ // than a strided fill. Then, overwrite the other values in subsequent steps.
765
+ switch (BackendWordSize) {
766
+ case 4u : {
767
+ auto *pPatternWords = static_cast <const uint32_t *>(pPattern);
768
+ checkIfFirstWordRepeats (pPatternWords);
769
+ UR_CHECK_ERROR (
770
+ hipMemsetD32Async (Ptr, *pPatternWords, Size / BackendWordSize, Stream));
771
+ break ;
772
+ }
773
+ case 2u : {
774
+ auto *pPatternWords = static_cast <const uint16_t *>(pPattern);
775
+ checkIfFirstWordRepeats (pPatternWords);
776
+ UR_CHECK_ERROR (
777
+ hipMemsetD16Async (Ptr, *pPatternWords, Size / BackendWordSize, Stream));
778
+ break ;
779
+ }
780
+ default : {
781
+ auto *pPatternWords = static_cast <const uint8_t *>(pPattern);
782
+ checkIfFirstWordRepeats (pPatternWords);
783
+ UR_CHECK_ERROR (
784
+ hipMemsetD8Async (Ptr, *pPatternWords, Size / BackendWordSize, Stream));
785
+ break ;
786
+ }
787
+ }
788
+
789
+ if (UseOnlyFirstValue) {
790
+ return UR_RESULT_SUCCESS;
791
+ }
751
792
752
793
// There is a bug in ROCm prior to 6.0.0 version which causes hipMemset2D
753
794
// to behave incorrectly when acting on host pinned memory.
@@ -761,7 +802,7 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
761
802
// we need to check that isManaged attribute is false.
762
803
if (ptrAttribs.hostPointer && !ptrAttribs.isManaged ) {
763
804
const auto NumOfCopySteps = Size / PatternSize;
764
- const auto Offset = sizeof ( uint32_t ) ;
805
+ const auto Offset = BackendWordSize ;
765
806
const auto LeftPatternSize = PatternSize - Offset;
766
807
const auto OffsetPatternPtr = reinterpret_cast <const void *>(
767
808
reinterpret_cast <const uint8_t *>(pPattern) + Offset);
@@ -776,10 +817,12 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
776
817
Stream));
777
818
}
778
819
} else {
779
- memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr);
820
+ memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr,
821
+ BackendWordSize);
780
822
}
781
823
#else
782
- memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr);
824
+ memsetRemainPattern (Stream, PatternSize, Size, pPattern, Ptr,
825
+ BackendWordSize);
783
826
#endif
784
827
return UR_RESULT_SUCCESS;
785
828
}
0 commit comments