@@ -23,14 +23,9 @@ namespace System.IO
23
23
* of the UnmanagedMemoryStream.
24
24
* 3) You clean up the memory when appropriate. The UnmanagedMemoryStream
25
25
* currently will do NOTHING to free this memory.
26
- * 4) All calls to Write and WriteByte may not be threadsafe currently.
27
- *
28
- * It may become necessary to add in some sort of
29
- * DeallocationMode enum, specifying whether we unmap a section of memory,
30
- * call free, run a user-provided delegate to free the memory, etc.
31
- * We'll suggest user write a subclass of UnmanagedMemoryStream that uses
32
- * a SafeHandle subclass to hold onto the memory.
33
- *
26
+ * 4) This type is not thread safe. However, the implementation should prevent buffer
27
+ * overruns or returning uninitialized memory when Reads and Writes are called
28
+ * concurrently in thread unsafe manner.
34
29
*/
35
30
36
31
/// <summary>
@@ -40,10 +35,10 @@ public class UnmanagedMemoryStream : Stream
40
35
{
41
36
private SafeBuffer ? _buffer ;
42
37
private unsafe byte * _mem ;
43
- private long _length ;
44
- private long _capacity ;
45
- private long _position ;
46
- private long _offset ;
38
+ private nuint _capacity ;
39
+ private nuint _offset ;
40
+ private nuint _length ; // nuint to guarantee atomic access on 32-bit platforms
41
+ private long _position ; // long to allow seeking to any location beyond the length of the stream.
47
42
private FileAccess _access ;
48
43
private bool _isOpen ;
49
44
private CachedCompletedInt32Task _lastReadTask ; // The last successful task returned from ReadAsync
@@ -123,10 +118,10 @@ protected void Initialize(SafeBuffer buffer, long offset, long length, FileAcces
123
118
}
124
119
}
125
120
126
- _offset = offset ;
121
+ _offset = ( nuint ) offset ;
127
122
_buffer = buffer ;
128
- _length = length ;
129
- _capacity = length ;
123
+ _length = ( nuint ) length ;
124
+ _capacity = ( nuint ) length ;
130
125
_access = access ;
131
126
_isOpen = true ;
132
127
}
@@ -171,8 +166,8 @@ protected unsafe void Initialize(byte* pointer, long length, long capacity, File
171
166
172
167
_mem = pointer ;
173
168
_offset = 0 ;
174
- _length = length ;
175
- _capacity = capacity ;
169
+ _length = ( nuint ) length ;
170
+ _capacity = ( nuint ) capacity ;
176
171
_access = access ;
177
172
_isOpen = true ;
178
173
}
@@ -259,7 +254,7 @@ public override long Length
259
254
get
260
255
{
261
256
EnsureNotClosed ( ) ;
262
- return Interlocked . Read ( ref _length ) ;
257
+ return ( long ) _length ;
263
258
}
264
259
}
265
260
@@ -271,7 +266,7 @@ public long Capacity
271
266
get
272
267
{
273
268
EnsureNotClosed ( ) ;
274
- return _capacity ;
269
+ return ( long ) _capacity ;
275
270
}
276
271
}
277
272
@@ -283,14 +278,14 @@ public override long Position
283
278
get
284
279
{
285
280
if ( ! CanSeek ) ThrowHelper . ThrowObjectDisposedException_StreamClosed ( null ) ;
286
- return Interlocked . Read ( ref _position ) ;
281
+ return _position ;
287
282
}
288
283
set
289
284
{
290
285
ArgumentOutOfRangeException . ThrowIfNegative ( value ) ;
291
286
if ( ! CanSeek ) ThrowHelper . ThrowObjectDisposedException_StreamClosed ( null ) ;
292
287
293
- Interlocked . Exchange ( ref _position , value ) ;
288
+ _position = value ;
294
289
}
295
290
}
296
291
@@ -308,11 +303,10 @@ public unsafe byte* PositionPointer
308
303
EnsureNotClosed ( ) ;
309
304
310
305
// Use a temp to avoid a race
311
- long pos = Interlocked . Read ( ref _position ) ;
312
- if ( pos > _capacity )
306
+ long pos = _position ;
307
+ if ( pos > ( long ) _capacity )
313
308
throw new IndexOutOfRangeException ( SR . IndexOutOfRange_UMSPosition ) ;
314
- byte * ptr = _mem + pos ;
315
- return ptr ;
309
+ return _mem + pos ;
316
310
}
317
311
set
318
312
{
@@ -327,7 +321,7 @@ public unsafe byte* PositionPointer
327
321
if ( newPosition < 0 )
328
322
throw new ArgumentOutOfRangeException ( nameof ( value ) , SR . ArgumentOutOfRange_UnmanagedMemStreamLength ) ;
329
323
330
- Interlocked . Exchange ( ref _position , newPosition ) ;
324
+ _position = newPosition ;
331
325
}
332
326
}
333
327
@@ -367,8 +361,13 @@ internal int ReadCore(Span<byte> buffer)
367
361
368
362
// Use a local variable to avoid a race where another thread
369
363
// changes our position after we decide we can read some bytes.
370
- long pos = Interlocked . Read ( ref _position ) ;
371
- long len = Interlocked . Read ( ref _length ) ;
364
+ long pos = _position ;
365
+
366
+ // Use a volatile read to prevent reading of the uninitialized memory. This volatile read
367
+ // and matching volatile write that set _length avoids reordering of NativeMemory.Clear
368
+ // operations with reading of the buffer below.
369
+ long len = ( long ) Volatile . Read ( ref _length ) ;
370
+
372
371
long n = Math . Min ( len - pos , buffer . Length ) ;
373
372
if ( n <= 0 )
374
373
{
@@ -407,7 +406,7 @@ internal int ReadCore(Span<byte> buffer)
407
406
}
408
407
}
409
408
410
- Interlocked . Exchange ( ref _position , pos + n ) ;
409
+ _position = pos + n ;
411
410
return nInt;
412
411
}
413
412
@@ -484,11 +483,16 @@ public override int ReadByte()
484
483
EnsureNotClosed ( ) ;
485
484
EnsureReadable ( ) ;
486
485
487
- long pos = Interlocked . Read ( ref _position ) ; // Use a local to avoid a race condition
488
- long len = Interlocked . Read ( ref _length ) ;
486
+ long pos = _position ; // Use a local to avoid a race condition
487
+
488
+ // Use a volatile read to prevent reading of the uninitialized memory. This volatile read
489
+ // and matching volatile write that set _length avoids reordering of NativeMemory.Clear
490
+ // operations with reading of the buffer below.
491
+ long len = ( long ) Volatile . Read ( ref _length ) ;
492
+
489
493
if ( pos >= len )
490
494
return - 1 ;
491
- Interlocked . Exchange ( ref _position , pos + 1 ) ;
495
+ _position = pos + 1 ;
492
496
int result ;
493
497
if ( _buffer != null )
494
498
{
@@ -529,35 +533,33 @@ public override long Seek(long offset, SeekOrigin loc)
529
533
{
530
534
EnsureNotClosed ( ) ;
531
535
536
+ long newPosition ;
532
537
switch ( loc )
533
538
{
534
539
case SeekOrigin . Begin :
535
- if ( offset < 0 )
540
+ newPosition = offset ;
541
+ if ( newPosition < 0 )
536
542
throw new IOException ( SR . IO_SeekBeforeBegin ) ;
537
- Interlocked . Exchange ( ref _position , offset ) ;
538
543
break ;
539
544
540
545
case SeekOrigin . Current :
541
- long pos = Interlocked . Read ( ref _position ) ;
542
- if ( offset + pos < 0 )
546
+ newPosition = _position + offset ;
547
+ if ( newPosition < 0 )
543
548
throw new IOException ( SR . IO_SeekBeforeBegin ) ;
544
- Interlocked . Exchange ( ref _position , offset + pos ) ;
545
549
break ;
546
550
547
551
case SeekOrigin . End :
548
- long len = Interlocked . Read ( ref _length ) ;
549
- if ( len + offset < 0 )
552
+ newPosition = ( long ) _length + offset ;
553
+ if ( newPosition < 0 )
550
554
throw new IOException ( SR . IO_SeekBeforeBegin ) ;
551
- Interlocked . Exchange ( ref _position , len + offset ) ;
552
555
break ;
553
556
554
557
default :
555
558
throw new ArgumentException ( SR . Argument_InvalidSeekOrigin ) ;
556
559
}
557
560
558
- long finalPos = Interlocked . Read ( ref _position ) ;
559
- Debug . Assert ( finalPos >= 0 , "_position >= 0" ) ;
560
- return finalPos ;
561
+ _position = newPosition ;
562
+ return newPosition ;
561
563
}
562
564
563
565
/// <summary>
@@ -573,22 +575,22 @@ public override void SetLength(long value)
573
575
EnsureNotClosed ( ) ;
574
576
EnsureWriteable ( ) ;
575
577
576
- if ( value > _capacity )
578
+ if ( value > ( long ) _capacity )
577
579
throw new IOException ( SR . IO_FixedCapacity ) ;
578
580
579
- long pos = Interlocked . Read ( ref _position ) ;
580
- long len = Interlocked . Read ( ref _length ) ;
581
+ long len = ( long ) _length ;
581
582
if ( value > len )
582
583
{
583
584
unsafe
584
585
{
585
586
NativeMemory . Clear ( _mem + len , ( nuint ) ( value - len ) ) ;
586
587
}
587
588
}
588
- Interlocked . Exchange ( ref _length , value ) ;
589
- if ( pos > value )
589
+ Volatile . Write ( ref _length , ( nuint ) value ) ; // volatile to prevent reading of uninitialized memory
590
+
591
+ if ( _position > value )
590
592
{
591
- Interlocked . Exchange ( ref _position , value ) ;
593
+ _position = value ;
592
594
}
593
595
}
594
596
@@ -625,16 +627,16 @@ internal unsafe void WriteCore(ReadOnlySpan<byte> buffer)
625
627
EnsureNotClosed ( ) ;
626
628
EnsureWriteable ( ) ;
627
629
628
- long pos = Interlocked . Read ( ref _position ) ; // Use a local to avoid a race condition
629
- long len = Interlocked . Read ( ref _length ) ;
630
+ long pos = _position ; // Use a local to avoid a race condition
631
+ long len = ( long ) _length ;
630
632
long n = pos + buffer . Length ;
631
633
// Check for overflow
632
634
if ( n < 0 )
633
635
{
634
636
throw new IOException ( SR . IO_StreamTooLong ) ;
635
637
}
636
638
637
- if ( n > _capacity )
639
+ if ( n > ( long ) _capacity )
638
640
{
639
641
throw new NotSupportedException ( SR . IO_FixedCapacity ) ;
640
642
}
@@ -648,16 +650,16 @@ internal unsafe void WriteCore(ReadOnlySpan<byte> buffer)
648
650
NativeMemory . Clear ( _mem + len , ( nuint ) ( pos - len ) ) ;
649
651
}
650
652
651
- // set length after zeroing memory to avoid race condition of accessing unzeroed memory
653
+ // set length after zeroing memory to avoid race condition of accessing uninitialized memory
652
654
if ( n > len )
653
655
{
654
- Interlocked . Exchange ( ref _length , n ) ;
656
+ Volatile . Write ( ref _length , ( nuint ) n ) ; // volatile to prevent reading of uninitialized memory
655
657
}
656
658
}
657
659
658
660
if ( _buffer != null )
659
661
{
660
- long bytesLeft = _capacity - pos ;
662
+ long bytesLeft = ( long ) _capacity - pos ;
661
663
if ( bytesLeft < buffer . Length )
662
664
{
663
665
throw new ArgumentException ( SR . Arg_BufferTooSmall ) ;
@@ -682,8 +684,7 @@ internal unsafe void WriteCore(ReadOnlySpan<byte> buffer)
682
684
Buffer . Memmove ( ref * ( _mem + pos ) , ref MemoryMarshal . GetReference ( buffer ) , ( nuint ) buffer . Length ) ;
683
685
}
684
686
685
- Interlocked . Exchange ( ref _position , n ) ;
686
- return ;
687
+ _position = n ;
687
688
}
688
689
689
690
/// <summary>
@@ -754,16 +755,16 @@ public override void WriteByte(byte value)
754
755
EnsureNotClosed ( ) ;
755
756
EnsureWriteable ( ) ;
756
757
757
- long pos = Interlocked . Read ( ref _position ) ; // Use a local to avoid a race condition
758
- long len = Interlocked . Read ( ref _length ) ;
758
+ long pos = _position ; // Use a local to avoid a race condition
759
+ long len = ( long ) _length ;
759
760
long n = pos + 1 ;
760
761
if ( pos >= len )
761
762
{
762
763
// Check for overflow
763
764
if ( n < 0 )
764
765
throw new IOException ( SR . IO_StreamTooLong ) ;
765
766
766
- if ( n > _capacity )
767
+ if ( n > ( long ) _capacity )
767
768
throw new NotSupportedException ( SR . IO_FixedCapacity ) ;
768
769
769
770
// Check to see whether we are now expanding the stream and must
@@ -779,8 +780,7 @@ public override void WriteByte(byte value)
779
780
}
780
781
}
781
782
782
- // set length after zeroing memory to avoid race condition of accessing unzeroed memory
783
- Interlocked . Exchange ( ref _length , n ) ;
783
+ Volatile . Write ( ref _length , ( nuint ) n ) ; // volatile to prevent reading of uninitialized memory
784
784
}
785
785
}
786
786
@@ -810,7 +810,7 @@ public override void WriteByte(byte value)
810
810
_mem [ pos ] = value ;
811
811
}
812
812
}
813
- Interlocked . Exchange ( ref _position , n ) ;
813
+ _position = n ;
814
814
}
815
815
}
816
816
}
0 commit comments