diff --git a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/BaseSerializer.cs b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/BaseSerializer.cs index 844b21f34032..764920b2edfd 100644 --- a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/BaseSerializer.cs +++ b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/BaseSerializer.cs @@ -28,6 +28,24 @@ protected static string ReadString(Stream stream) } } + protected static string ReadStringValue(Stream stream, int size) + { + byte[] bytes = ArrayPool.Shared.Rent(size); + try + { +#if NET7_0_OR_GREATER + stream.ReadExactly(bytes, 0, size); +#else + _ = stream.Read(bytes, 0, size); +#endif + return Encoding.UTF8.GetString(bytes, 0, size); + } + finally + { + ArrayPool.Shared.Return(bytes); + } + } + protected static void WriteString(Stream stream, string str) { int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str); @@ -47,6 +65,21 @@ protected static void WriteString(Stream stream, string str) } } + protected static void WriteStringValue(Stream stream, string str) + { + int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str); + byte[] bytes = ArrayPool.Shared.Rent(stringutf8TotalBytes); + try + { + Encoding.UTF8.GetBytes(str, bytes); + stream.Write(bytes, 0, stringutf8TotalBytes); + } + finally + { + ArrayPool.Shared.Return(bytes); + } + } + protected static void WriteStringSize(Stream stream, string str) { int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str); @@ -57,9 +90,10 @@ protected static void WriteStringSize(Stream stream, string str) } } - protected static void WriteSize(Stream stream) + protected static void WriteSize(Stream stream) + where T : struct { - int sizeInBytes = GetSize(); + int sizeInBytes = GetSize(); Span len = stackalloc byte[sizeof(int)]; if (BitConverter.TryWriteBytes(len, sizeInBytes)) @@ -139,6 +173,14 @@ protected static string ReadString(Stream stream) return Encoding.UTF8.GetString(bytes); } + protected static string ReadStringValue(Stream stream, int size) + { + byte[] bytes = new byte[size]; + _ = stream.Read(bytes, 0, bytes.Length); + + return Encoding.UTF8.GetString(bytes); + } + protected static void WriteString(Stream stream, string str) { byte[] bytes = Encoding.UTF8.GetBytes(str); @@ -147,6 +189,12 @@ protected static void WriteString(Stream stream, string str) stream.Write(bytes, 0, bytes.Length); } + protected static void WriteStringValue(Stream stream, string str) + { + byte[] bytes = Encoding.UTF8.GetBytes(str); + stream.Write(bytes, 0, bytes.Length); + } + protected static void WriteStringSize(Stream stream, string str) { byte[] bytes = Encoding.UTF8.GetBytes(str); @@ -154,11 +202,12 @@ protected static void WriteStringSize(Stream stream, string str) stream.Write(len, 0, len.Length); } - protected static void WriteSize(Stream stream) + protected static void WriteSize(Stream stream) + where T : struct { - int sizeInBytes = GetSize(); + int sizeInBytes = GetSize(); byte[] len = BitConverter.GetBytes(sizeInBytes); - stream.Write(len, 0, sizeInBytes); + stream.Write(len, 0, len.Length); } protected static void WriteInt(Stream stream, int value) @@ -227,7 +276,7 @@ protected static void WriteField(Stream stream, ushort id, string? value) WriteShort(stream, id); WriteStringSize(stream, value); - WriteString(stream, value); + WriteStringValue(stream, value); } protected static void WriteField(Stream stream, string? value) @@ -258,7 +307,7 @@ protected static void WriteField(Stream stream, ushort id, bool? value) } WriteShort(stream, id); - WriteSize(stream); + WriteSize(stream); WriteBool(stream, value.Value); } @@ -270,7 +319,7 @@ protected static void WriteField(Stream stream, ushort id, byte? value) } WriteShort(stream, id); - WriteSize(stream); + WriteSize(stream); WriteByte(stream, value.Value); } @@ -290,6 +339,7 @@ protected static void WriteAtPosition(Stream stream, int value, long position) Type type when type == typeof(long) => sizeof(long), Type type when type == typeof(short) => sizeof(short), Type type when type == typeof(bool) => sizeof(bool), + Type type when type == typeof(byte) => sizeof(byte), _ => 0, }; diff --git a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/CommandLineOptionMessagesSerializer.cs b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/CommandLineOptionMessagesSerializer.cs index 6de12ac5b6f3..953bdb51f11e 100644 --- a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/CommandLineOptionMessagesSerializer.cs +++ b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/CommandLineOptionMessagesSerializer.cs @@ -59,7 +59,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case CommandLineOptionMessagesFieldsId.ModulePath: - moduleName = ReadString(stream); + moduleName = ReadStringValue(stream, fieldSize); break; case CommandLineOptionMessagesFieldsId.CommandLineOptionMessageList: @@ -96,11 +96,11 @@ private static List ReadCommandLineOptionMessagesPaylo switch (fieldId) { case CommandLineOptionMessageFieldsId.Name: - name = ReadString(stream); + name = ReadStringValue(stream, fieldSize); break; case CommandLineOptionMessageFieldsId.Description: - description = ReadString(stream); + description = ReadStringValue(stream, fieldSize); break; case CommandLineOptionMessageFieldsId.IsHidden: diff --git a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/DiscoveredTestMessagesSerializer.cs b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/DiscoveredTestMessagesSerializer.cs index 5ee7be6dadc7..354b7ccaa6d9 100644 --- a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/DiscoveredTestMessagesSerializer.cs +++ b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/DiscoveredTestMessagesSerializer.cs @@ -51,7 +51,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case DiscoveredTestMessagesFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; case DiscoveredTestMessagesFieldsId.DiscoveredTestMessageList: @@ -87,11 +87,11 @@ private static List ReadDiscoveredTestMessagesPayload(Str switch (fieldId) { case DiscoveredTestMessageFieldsId.Uid: - uid = ReadString(stream); + uid = ReadStringValue(stream, fieldSize); break; case DiscoveredTestMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; default: diff --git a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/FileArtifactMessagesSerializer.cs b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/FileArtifactMessagesSerializer.cs index 8329f2ec2177..4e12cb152722 100644 --- a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/FileArtifactMessagesSerializer.cs +++ b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/FileArtifactMessagesSerializer.cs @@ -67,7 +67,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case FileArtifactMessagesFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; case FileArtifactMessagesFieldsId.FileArtifactMessageList: @@ -103,27 +103,27 @@ private static List ReadFileArtifactMessagesPayload(Stream switch (fieldId) { case FileArtifactMessageFieldsId.FullPath: - fullPath = ReadString(stream); + fullPath = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.Description: - description = ReadString(stream); + description = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.TestUid: - testUid = ReadString(stream); + testUid = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.TestDisplayName: - testDisplayName = ReadString(stream); + testDisplayName = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; default: diff --git a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestResultMessagesSerializer.cs b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestResultMessagesSerializer.cs index 689bae9c6734..f40520a8707d 100644 --- a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestResultMessagesSerializer.cs +++ b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestResultMessagesSerializer.cs @@ -99,7 +99,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case TestResultMessagesFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; case TestResultMessagesFieldsId.SuccessfulTestMessageList: @@ -143,11 +143,11 @@ private static List ReadSuccessfulTestMessagesPaylo switch (fieldId) { case SuccessfulTestResultMessageFieldsId.Uid: - uid = ReadString(stream); + uid = ReadStringValue(stream, fieldSize); break; case SuccessfulTestResultMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; case SuccessfulTestResultMessageFieldsId.State: @@ -155,11 +155,11 @@ private static List ReadSuccessfulTestMessagesPaylo break; case SuccessfulTestResultMessageFieldsId.Reason: - reason = ReadString(stream); + reason = ReadStringValue(stream, fieldSize); break; case SuccessfulTestResultMessageFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; default: @@ -194,11 +194,11 @@ private static List ReadFailedTestMessagesPayload(Strea switch (fieldId) { case FailedTestResultMessageFieldsId.Uid: - uid = ReadString(stream); + uid = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.State: @@ -206,19 +206,19 @@ private static List ReadFailedTestMessagesPayload(Strea break; case FailedTestResultMessageFieldsId.Reason: - reason = ReadString(stream); + reason = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.ErrorMessage: - errorMessage = ReadString(stream); + errorMessage = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.ErrorStackTrace: - errorStackTrace = ReadString(stream); + errorStackTrace = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; default: diff --git a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestSessionEventSerializer.cs b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestSessionEventSerializer.cs index d0ddd0247a17..53f560cf3cb7 100644 --- a/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestSessionEventSerializer.cs +++ b/src/Cli/dotnet/commands/dotnet-test/IPC/Serializers/TestSessionEventSerializer.cs @@ -37,7 +37,7 @@ public object Deserialize(Stream stream) for (int i = 0; i < fieldCount; i++) { - int fieldId = ReadShort(stream); + ushort fieldId = ReadShort(stream); int fieldSize = ReadInt(stream); switch (fieldId) @@ -47,11 +47,11 @@ public object Deserialize(Stream stream) break; case TestSessionEventFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; case TestSessionEventFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; default: