2
2
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3
3
4
4
using DocumentFormat . OpenXml . Features ;
5
+ using DocumentFormat . OpenXml . Framework ;
5
6
using System ;
6
7
using System . Collections . Generic ;
7
8
using System . Diagnostics ;
8
9
using System . Diagnostics . CodeAnalysis ;
9
10
using System . IO ;
10
- using System . IO . Packaging ;
11
+ using System . Threading ;
11
12
12
13
namespace DocumentFormat . OpenXml . Packaging
13
14
{
@@ -236,9 +237,7 @@ public IEnumerable<OpenXmlPart> GetParentParts()
236
237
/// <returns>The content stream of the part. </returns>
237
238
public Stream GetStream ( FileMode mode )
238
239
{
239
- ThrowIfObjectDisposed ( ) ;
240
-
241
- return PackagePart . GetStream ( mode , Features . GetRequired < IPackageFeature > ( ) . Package . FileOpenAccess ) ;
240
+ return GetStream ( mode , Features . GetRequired < IPackageFeature > ( ) . Package . FileOpenAccess ) ;
242
241
}
243
242
244
243
/// <summary>
@@ -251,7 +250,20 @@ public Stream GetStream(FileMode mode, FileAccess access)
251
250
{
252
251
ThrowIfObjectDisposed ( ) ;
253
252
254
- return PackagePart . GetStream ( mode , access ) ;
253
+ var stream = PackagePart . GetStream ( mode , access ) ;
254
+
255
+ if ( mode is FileMode . Create || stream . Length == 0 )
256
+ {
257
+ UnloadRootElement ( ) ;
258
+ return new UnloadingRootElementStream ( this , stream ) ;
259
+ }
260
+
261
+ if ( stream . CanWrite )
262
+ {
263
+ return new UnloadingRootElementStream ( this , stream ) ;
264
+ }
265
+
266
+ return stream ;
255
267
}
256
268
257
269
/// <summary>
@@ -605,5 +617,76 @@ internal sealed override OpenXmlPart ThisOpenXmlPart
605
617
internal MarkupCompatibilityProcessSettings ? MCSettings { get ; set ; }
606
618
607
619
#endregion
620
+
621
+ /// <summary>
622
+ /// A <see cref="Stream"/> used by <see cref="GetStream(FileMode, FileAccess)" /> to unload the root if updated.
623
+ /// </summary>
624
+ private sealed class UnloadingRootElementStream : DelegatingStream
625
+ {
626
+ private readonly OpenXmlPart _part ;
627
+
628
+ private bool _hasWritten ;
629
+
630
+ public UnloadingRootElementStream ( OpenXmlPart part , Stream innerStream )
631
+ : base ( innerStream )
632
+ {
633
+ _part = part ;
634
+ }
635
+
636
+ protected override void Dispose ( bool disposing )
637
+ {
638
+ if ( disposing && _hasWritten )
639
+ {
640
+ _part . UnloadRootElement ( ) ;
641
+ }
642
+
643
+ base . Dispose ( disposing ) ;
644
+ }
645
+
646
+ public override void Write ( byte [ ] buffer , int offset , int count )
647
+ {
648
+ NotifyOfWrite ( ) ;
649
+ base . Write ( buffer , offset , count ) ;
650
+ }
651
+
652
+ #if NET46_OR_GREATER || NET || NETSTANDARD
653
+ public override System . Threading . Tasks . Task WriteAsync ( byte [ ] buffer , int offset , int count , CancellationToken cancellationToken )
654
+ {
655
+ NotifyOfWrite ( ) ;
656
+ return base . WriteAsync ( buffer , offset , count , cancellationToken ) ;
657
+ }
658
+ #endif
659
+
660
+ #if NET6_0_OR_GREATER
661
+ public override void Write ( ReadOnlySpan < byte > buffer )
662
+ {
663
+ NotifyOfWrite ( ) ;
664
+ base . Write ( buffer ) ;
665
+ }
666
+
667
+ public override System . Threading . Tasks . ValueTask WriteAsync ( ReadOnlyMemory < byte > buffer , CancellationToken cancellationToken = default )
668
+ {
669
+ NotifyOfWrite ( ) ;
670
+ return base . WriteAsync ( buffer , cancellationToken ) ;
671
+ }
672
+ #endif
673
+
674
+ public override void WriteByte ( byte value )
675
+ {
676
+ NotifyOfWrite ( ) ;
677
+ base . WriteByte ( value ) ;
678
+ }
679
+
680
+ public override IAsyncResult BeginWrite ( byte [ ] buffer , int offset , int count , AsyncCallback ? callback , object ? state )
681
+ {
682
+ NotifyOfWrite ( ) ;
683
+ return base . BeginWrite ( buffer , offset , count , callback , state ) ;
684
+ }
685
+
686
+ private void NotifyOfWrite ( )
687
+ {
688
+ _hasWritten = true ;
689
+ }
690
+ }
608
691
}
609
692
}
0 commit comments