Skip to content

Commit 522f9f3

Browse files
authored
Unload root element if Part.GetStream updates the underlying value (#1760)
Fixes #1755
1 parent f1fecd3 commit 522f9f3

File tree

2 files changed

+160
-5
lines changed

2 files changed

+160
-5
lines changed

src/DocumentFormat.OpenXml.Framework/Packaging/OpenXmlPart.cs

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using DocumentFormat.OpenXml.Features;
5+
using DocumentFormat.OpenXml.Framework;
56
using System;
67
using System.Collections.Generic;
78
using System.Diagnostics;
89
using System.Diagnostics.CodeAnalysis;
910
using System.IO;
10-
using System.IO.Packaging;
11+
using System.Threading;
1112

1213
namespace DocumentFormat.OpenXml.Packaging
1314
{
@@ -236,9 +237,7 @@ public IEnumerable<OpenXmlPart> GetParentParts()
236237
/// <returns>The content stream of the part. </returns>
237238
public Stream GetStream(FileMode mode)
238239
{
239-
ThrowIfObjectDisposed();
240-
241-
return PackagePart.GetStream(mode, Features.GetRequired<IPackageFeature>().Package.FileOpenAccess);
240+
return GetStream(mode, Features.GetRequired<IPackageFeature>().Package.FileOpenAccess);
242241
}
243242

244243
/// <summary>
@@ -251,7 +250,20 @@ public Stream GetStream(FileMode mode, FileAccess access)
251250
{
252251
ThrowIfObjectDisposed();
253252

254-
return PackagePart.GetStream(mode, access);
253+
var stream = PackagePart.GetStream(mode, access);
254+
255+
if (mode is FileMode.Create || stream.Length == 0)
256+
{
257+
UnloadRootElement();
258+
return new UnloadingRootElementStream(this, stream);
259+
}
260+
261+
if (stream.CanWrite)
262+
{
263+
return new UnloadingRootElementStream(this, stream);
264+
}
265+
266+
return stream;
255267
}
256268

257269
/// <summary>
@@ -605,5 +617,76 @@ internal sealed override OpenXmlPart ThisOpenXmlPart
605617
internal MarkupCompatibilityProcessSettings? MCSettings { get; set; }
606618

607619
#endregion
620+
621+
/// <summary>
622+
/// A <see cref="Stream"/> used by <see cref="GetStream(FileMode, FileAccess)" /> to unload the root if updated.
623+
/// </summary>
624+
private sealed class UnloadingRootElementStream : DelegatingStream
625+
{
626+
private readonly OpenXmlPart _part;
627+
628+
private bool _hasWritten;
629+
630+
public UnloadingRootElementStream(OpenXmlPart part, Stream innerStream)
631+
: base(innerStream)
632+
{
633+
_part = part;
634+
}
635+
636+
protected override void Dispose(bool disposing)
637+
{
638+
if (disposing && _hasWritten)
639+
{
640+
_part.UnloadRootElement();
641+
}
642+
643+
base.Dispose(disposing);
644+
}
645+
646+
public override void Write(byte[] buffer, int offset, int count)
647+
{
648+
NotifyOfWrite();
649+
base.Write(buffer, offset, count);
650+
}
651+
652+
#if NET46_OR_GREATER || NET || NETSTANDARD
653+
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
654+
{
655+
NotifyOfWrite();
656+
return base.WriteAsync(buffer, offset, count, cancellationToken);
657+
}
658+
#endif
659+
660+
#if NET6_0_OR_GREATER
661+
public override void Write(ReadOnlySpan<byte> buffer)
662+
{
663+
NotifyOfWrite();
664+
base.Write(buffer);
665+
}
666+
667+
public override System.Threading.Tasks.ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
668+
{
669+
NotifyOfWrite();
670+
return base.WriteAsync(buffer, cancellationToken);
671+
}
672+
#endif
673+
674+
public override void WriteByte(byte value)
675+
{
676+
NotifyOfWrite();
677+
base.WriteByte(value);
678+
}
679+
680+
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
681+
{
682+
NotifyOfWrite();
683+
return base.BeginWrite(buffer, offset, count, callback, state);
684+
}
685+
686+
private void NotifyOfWrite()
687+
{
688+
_hasWritten = true;
689+
}
690+
}
608691
}
609692
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using DocumentFormat.OpenXml.Spreadsheet;
5+
using System;
6+
using System.IO;
7+
using System.Text;
8+
using Xunit;
9+
10+
namespace DocumentFormat.OpenXml.Packaging.Tests;
11+
12+
public class OpenXmlPartTests
13+
{
14+
[InlineData(FileAccess.Write)]
15+
[InlineData(FileAccess.ReadWrite)]
16+
[Theory]
17+
public void GetStreamWrite(FileAccess access)
18+
{
19+
// Arrange
20+
const string expected = """<x:sst xmlns:x="http://schemas.openxmlformats.org/spreadsheetml/2006/main"><x:si><x:t>Test</x:t></x:si></x:sst>""";
21+
var stream = new MemoryStream();
22+
{
23+
using var package = SpreadsheetDocument.Create(stream, SpreadsheetDocumentType.Workbook);
24+
var wb = package.AddWorkbookPart();
25+
26+
var part = wb.AddNewPart<SharedStringTablePart>();
27+
28+
part.SharedStringTable = new SharedStringTable();
29+
30+
using var partStream = part.GetStream(FileMode.Create, access);
31+
32+
var bytes = Encoding.UTF8.GetBytes(expected);
33+
partStream.Write(bytes, 0, bytes.Length);
34+
}
35+
36+
// Reopen package
37+
stream.Position = 0;
38+
using var spreadsheet = SpreadsheetDocument.Open(stream, isEditable: false);
39+
40+
// Assert
41+
Assert.Equal(expected, spreadsheet.WorkbookPart.SharedStringTablePart.RootElement.OuterXml);
42+
}
43+
44+
[InlineData(FileAccess.Write)]
45+
[InlineData(FileAccess.ReadWrite)]
46+
[Theory]
47+
public void GetStreamWriteNoUpdates(FileAccess access)
48+
{
49+
// Arrange
50+
const string expected = """<x:sst xmlns:x="http://schemas.openxmlformats.org/spreadsheetml/2006/main" />""";
51+
var stream = new MemoryStream();
52+
{
53+
using var package = SpreadsheetDocument.Create(stream, SpreadsheetDocumentType.Workbook);
54+
var wb = package.AddWorkbookPart();
55+
56+
var part = wb.AddNewPart<SharedStringTablePart>();
57+
58+
part.SharedStringTable = new SharedStringTable();
59+
60+
package.Save();
61+
62+
using var partStream = part.GetStream(FileMode.Open, access);
63+
}
64+
65+
// Reopen package
66+
stream.Position = 0;
67+
using var spreadsheet = SpreadsheetDocument.Open(stream, isEditable: false);
68+
69+
// Assert
70+
Assert.Equal(expected, spreadsheet.WorkbookPart.SharedStringTablePart.RootElement.OuterXml);
71+
}
72+
}

0 commit comments

Comments
 (0)