Skip to content

Commit a01656a

Browse files
committed
hardware intrinsic in BitArray.*Shift
1 parent e29dedd commit a01656a

File tree

1 file changed

+129
-42
lines changed
  • src/libraries/System.Private.CoreLib/src/System/Collections

1 file changed

+129
-42
lines changed

src/libraries/System.Private.CoreLib/src/System/Collections/BitArray.cs

Lines changed: 129 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -513,51 +513,90 @@ public BitArray RightShift(int count)
513513
return this;
514514
}
515515

516-
Span<int> intSpan = MemoryMarshal.Cast<byte, int>((Span<byte>)_array);
517-
516+
Span<byte> thisSpan = new Span<byte>(_array, 0, GetByteArrayLengthFromBitLength(_bitLength));
518517
int toIndex = 0;
519-
int ints = GetInt32ArrayLengthFromBitLength(_bitLength);
518+
520519
if (count < _bitLength)
521520
{
522-
// We can not use Math.DivRem without taking a dependency on System.Runtime.Extensions
523-
(int fromIndex, int shiftCount) = Math.DivRem(count, 32);
524-
int extraBits = (int)((uint)_bitLength % 32);
521+
(int fromIndex, int shiftCount) = Math.DivRem(count, BitsPerByte);
525522
if (shiftCount == 0)
526523
{
527-
// Cannot use `(1u << extraBits) - 1u` as the mask
528-
// because for extraBits == 0, we need the mask to be 111...111, not 0.
529-
// In that case, we are shifting a uint by 32, which could be considered undefined.
530-
// The result of a shift operation is undefined ... if the right operand
531-
// is greater than or equal to the width in bits of the promoted left operand,
532-
// https://learn.microsoft.com/cpp/c-language/bitwise-shift-operators?view=vs-2017
533-
// However, the compiler protects us from undefined behaviour by constraining the
534-
// right operand to between 0 and width - 1 (inclusive), i.e. right_operand = (right_operand % width).
535-
uint mask = uint.MaxValue >> (BitsPerInt32 - extraBits);
536-
intSpan[ints - 1] &= ReverseIfBE((int)mask);
537-
538-
intSpan.Slice((int)fromIndex, ints - fromIndex).CopyTo(intSpan);
539-
toIndex = ints - fromIndex;
524+
thisSpan.Slice(fromIndex).CopyTo(thisSpan);
525+
toIndex = thisSpan.Length - fromIndex;
540526
}
541527
else
542528
{
543-
int lastIndex = ints - 1;
529+
if (Vector512.IsHardwareAccelerated)
530+
{
531+
toIndex = Apply<Vector512<byte>>(shiftCount, fromIndex, thisSpan);
532+
}
533+
else if (Vector256.IsHardwareAccelerated)
534+
{
535+
toIndex = Apply<Vector256<byte>>(shiftCount, fromIndex, thisSpan);
536+
}
537+
else if (Vector128.IsHardwareAccelerated)
538+
{
539+
toIndex = Apply<Vector128<byte>>(shiftCount, fromIndex, thisSpan);
540+
}
541+
fromIndex += toIndex;
542+
543+
ref byte p = ref MemoryMarshal.GetReference(thisSpan);
544544

545-
while (fromIndex < lastIndex)
545+
int carry32Count = BitsPerInt32 - shiftCount;
546+
while (fromIndex < thisSpan.Length - 4)
546547
{
547-
uint right = (uint)ReverseIfBE(intSpan[fromIndex]) >> shiftCount;
548-
int left = ReverseIfBE(intSpan[++fromIndex]) << (BitsPerInt32 - shiftCount);
549-
intSpan[toIndex++] = ReverseIfBE(left | (int)right);
548+
int lo = ReverseIfBE(Unsafe.ReadUnaligned<int>(ref Unsafe.AddByteOffset(ref p, (uint)fromIndex))) >>> shiftCount;
549+
int hi = Unsafe.AddByteOffset(ref p, (uint)(fromIndex + 4)) << carry32Count;
550+
int result = ReverseIfBE(hi | lo);
551+
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref p, toIndex), result);
552+
553+
fromIndex += 4;
554+
toIndex += 4;
550555
}
551556

552-
uint mask = uint.MaxValue >> (BitsPerInt32 - extraBits);
553-
mask &= (uint)ReverseIfBE(intSpan[fromIndex]);
554-
intSpan[toIndex++] = ReverseIfBE((int)(mask >> shiftCount));
557+
int carryCount = BitsPerByte - shiftCount;
558+
while (fromIndex < thisSpan.Length)
559+
{
560+
int lo = thisSpan[fromIndex] >>> shiftCount;
561+
int hi =
562+
fromIndex + 1 < thisSpan.Length
563+
? thisSpan[fromIndex + 1] << carryCount
564+
: 0;
565+
566+
thisSpan[toIndex] = (byte)(hi | lo);
567+
568+
fromIndex++;
569+
toIndex++;
570+
}
555571
}
556572
}
557573

558-
intSpan.Slice(toIndex, ints - toIndex).Clear();
574+
thisSpan.Slice(toIndex).Clear();
559575
_version++;
560576
return this;
577+
578+
// Vectorized portion of the right shift for a sub-byte bit offset.
// Each iteration reads one vector of bytes at the source position and one vector
// starting a byte later, combines them so every output byte holds the source byte
// shifted right by shiftCount plus the bits carried in from the following byte,
// and stores the result starting at index 0 of thisSpan.
// Returns the number of bytes written; the caller handles the remaining bytes.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static int Apply<TVector>(int shiftCount, int fromIndex, Span<byte> thisSpan)
    where TVector : ISimdVector<TVector, byte>
{
    int width = TVector.ElementCount;

    // Highest source index for which both loads (at src and src + 1) stay inside the span.
    int lastStart = thisSpan.Length - (width + 1);

    // Number of bit positions the following byte's bits are moved left by.
    int spill = BitsPerByte - shiftCount;

    ref byte origin = ref MemoryMarshal.GetReference(thisSpan);

    int written = 0;
    for (int src = fromIndex; src <= lastStart; src += width, written += width)
    {
        // Both loads happen before the store: source and destination share the same buffer.
        TVector current = TVector.LoadUnsafe(ref origin, (uint)src);
        TVector following = TVector.LoadUnsafe(ref origin, (uint)(src + 1));

        TVector merged = (current >>> shiftCount) | (following << spill);
        merged.StoreUnsafe(ref origin, (uint)written);
    }

    return written;
}
561600
}
562601

563602
/// <summary>
@@ -576,41 +615,89 @@ public BitArray LeftShift(int count)
576615
return this;
577616
}
578617

579-
Span<int> intSpan = MemoryMarshal.Cast<byte, int>((Span<byte>)_array);
618+
Span<byte> thisSpan = new Span<byte>(_array, 0, GetByteArrayLengthFromBitLength(_bitLength));
580619

581620
int lengthToClear;
582621
if (count < _bitLength)
583622
{
584-
int lastIndex = (int)((uint)(_bitLength - 1) / BitsPerInt32);
585-
586-
(lengthToClear, int shiftCount) = Math.DivRem(count, BitsPerInt32);
623+
(lengthToClear, int shiftCount) = Math.DivRem(count, BitsPerByte);
587624

588625
if (shiftCount == 0)
589626
{
590-
intSpan.Slice(0, lastIndex + 1 - lengthToClear).CopyTo(intSpan.Slice(lengthToClear));
627+
thisSpan.Slice(0, thisSpan.Length - lengthToClear).CopyTo(thisSpan.Slice(lengthToClear));
591628
}
592629
else
593630
{
594-
int fromindex = lastIndex - lengthToClear;
631+
int toIndex = thisSpan.Length;
632+
int fromIndex = toIndex - lengthToClear;
633+
634+
if (Vector512.IsHardwareAccelerated)
635+
{
636+
toIndex = Apply<Vector512<byte>>(shiftCount, fromIndex, thisSpan);
637+
}
638+
else if (Vector256.IsHardwareAccelerated)
639+
{
640+
toIndex = Apply<Vector256<byte>>(shiftCount, fromIndex, thisSpan);
641+
}
642+
else if (Vector128.IsHardwareAccelerated)
643+
{
644+
toIndex = Apply<Vector128<byte>>(shiftCount, fromIndex, thisSpan);
645+
}
646+
fromIndex = toIndex - lengthToClear;
647+
648+
ref byte p = ref MemoryMarshal.GetReference(thisSpan);
649+
650+
int carryCount = BitsPerByte - shiftCount;
651+
while (fromIndex >= 5)
652+
{
653+
int hi = ReverseIfBE(Unsafe.ReadUnaligned<int>(ref Unsafe.AddByteOffset(ref p, (uint)(fromIndex -= 4)))) << shiftCount;
654+
int lo = Unsafe.AddByteOffset(ref p, (uint)(fromIndex - 1)) >>> carryCount;
655+
int result = ReverseIfBE(hi | lo);
656+
Unsafe.WriteUnaligned(ref Unsafe.AddByteOffset(ref p, toIndex -= 4), result);
657+
}
595658

596-
while (fromindex > 0)
659+
while (--fromIndex >= 0)
597660
{
598-
int left = ReverseIfBE(intSpan[fromindex]) << shiftCount;
599-
uint right = (uint)ReverseIfBE(intSpan[--fromindex]) >> (BitsPerInt32 - shiftCount);
600-
intSpan[lastIndex] = ReverseIfBE(left | (int)right);
601-
lastIndex--;
661+
int hi = thisSpan[fromIndex] << shiftCount;
662+
int lo =
663+
fromIndex > 0
664+
? thisSpan[fromIndex - 1] >>> carryCount
665+
: 0;
666+
667+
thisSpan[--toIndex] = (byte)(hi | lo);
602668
}
603-
intSpan[lastIndex] = ReverseIfBE(ReverseIfBE(intSpan[fromindex]) << shiftCount);
669+
670+
Debug.Assert(toIndex == lengthToClear);
604671
}
605672
}
606673
else
607674
{
608-
lengthToClear = GetInt32ArrayLengthFromBitLength(_bitLength); // Clear all
675+
lengthToClear = thisSpan.Length; // Clear all
609676
}
610677

611-
intSpan.Slice(0, lengthToClear).Clear();
678+
thisSpan.Slice(0, lengthToClear).Clear();
612679
_version++;
613680
return this;
681+
682+
// Vectorized portion of the left shift for a sub-byte bit offset.
// Walks from the end of thisSpan toward the start. Each iteration reads one vector of
// bytes ending at the current source position and one vector starting a byte earlier,
// combines them so every output byte holds the source byte shifted left by shiftCount
// plus the bits carried in from the preceding byte, and stores the result just below
// the previous destination position.
// Returns the destination index reached; the caller handles the remaining lower bytes.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static int Apply<TVector>(int shiftCount, int fromIndex, Span<byte> thisSpan)
    where TVector : ISimdVector<TVector, byte>
{
    int width = TVector.ElementCount;

    // Number of bit positions the preceding byte's bits are moved right by.
    int spill = BitsPerByte - shiftCount;

    ref byte origin = ref MemoryMarshal.GetReference(thisSpan);

    int src = fromIndex;
    int dst = thisSpan.Length;

    // src > width guarantees that, after stepping back, the load at (src - 1) is still at index >= 0.
    while (src > width)
    {
        src -= width;
        dst -= width;

        // Both loads happen before the store: source and destination share the same buffer.
        TVector upper = TVector.LoadUnsafe(ref origin, (nuint)src);
        TVector lower = TVector.LoadUnsafe(ref origin, (nuint)(src - 1));

        TVector merged = (upper << shiftCount) | (lower >>> spill);
        merged.StoreUnsafe(ref origin, (nuint)dst);
    }

    return dst;
}
614701
}
615702

616703
/// <summary>

0 commit comments

Comments
 (0)