Skip to content

Commit c45c8b0

Browse files
authored
Merge branch 'main' into mst/execution-preparations
2 parents 348ec20 + 0d38466 commit c45c8b0

10 files changed

+336
-20
lines changed

src/HotChocolate/Core/src/Types.Analyzers/FileBuilders/RequestMiddlewareFileBuilder.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,14 @@ private void WriteCtorServiceResolution(List<RequestMiddlewareParameterInfo> par
121121

122122
case RequestMiddlewareParameterKind.SchemaService when !parameter.IsNullable:
123123
_writer.WriteIndentedLine(
124-
"var cp{0} = core.SchemaServices.GetRequiredService<global::{1}>();",
124+
"var cp{0} = core.SchemaServices.GetRequiredService<{1}>();",
125125
i,
126126
parameter.TypeName);
127127
break;
128128

129129
case RequestMiddlewareParameterKind.SchemaService when parameter.IsNullable:
130130
_writer.WriteIndentedLine(
131-
"var cp{0} = core.SchemaServices.GetService<global::{1}>();",
131+
"var cp{0} = core.SchemaServices.GetService<{1}>();",
132132
i,
133133
parameter.TypeName);
134134
break;

src/HotChocolate/Core/src/Types.Analyzers/Inspectors/RequestMiddlewareInspector.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,11 @@ ctor is not
8787
kind = RequestMiddlewareParameterKind.Service;
8888
}
8989

90-
ctorParameters.Add(new RequestMiddlewareParameterInfo(kind, parameterTypeName));
90+
ctorParameters.Add(
91+
new RequestMiddlewareParameterInfo(
92+
kind,
93+
parameterTypeName,
94+
isNullable: !parameter.IsNonNullable()));
9195
}
9296

9397
foreach (var parameter in invokeMethod.Parameters)
@@ -113,7 +117,11 @@ ctor is not
113117
kind = RequestMiddlewareParameterKind.Service;
114118
}
115119

116-
invokeParameters.Add(new RequestMiddlewareParameterInfo(kind, parameterTypeName));
120+
invokeParameters.Add(
121+
new RequestMiddlewareParameterInfo(
122+
kind,
123+
parameterTypeName,
124+
isNullable: !parameter.IsNonNullable()));
117125
}
118126

119127
syntaxInfo = new RequestMiddlewareInfo(

src/HotChocolate/Core/src/Types/SemanticNonNullTypeInterceptor.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ public override void OnBeforeCompleteType(ITypeCompletionContext completionConte
104104
continue;
105105
}
106106

107+
if (field.Name == "id")
108+
{
109+
continue;
110+
}
111+
107112
var levels = GetSemanticNonNullLevels(field.Type);
108113

109114
if (levels.Count < 1)

src/HotChocolate/Core/src/Types/Types/Relay/Serialization/CompositeNodeIdValueSerializer.cs

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace HotChocolate.Types.Relay;
1616
public abstract class CompositeNodeIdValueSerializer<T> : INodeIdValueSerializer
1717
{
1818
private const byte _partSeparator = (byte)':';
19+
private const byte _escape = (byte)'\\';
1920
private static readonly Encoding _utf8 = Encoding.UTF8;
2021

2122
public virtual bool IsSupported(Type type) => type == typeof(T) || type == typeof(T?);
@@ -87,19 +88,21 @@ public NodeIdFormatterResult Format(Span<byte> buffer, object value, out int wri
8788
/// </returns>
8889
protected static bool TryFormatIdPart(Span<byte> buffer, string value, out int written)
8990
{
90-
var requiredCapacity = _utf8.GetByteCount(value) + 1;
91+
var requiredCapacity = _utf8.GetByteCount(value) * 2 + 1; // * 2 to allow for escaping.
9192
if (buffer.Length < requiredCapacity)
9293
{
9394
written = 0;
9495
return false;
9596
}
9697

97-
var stringBytes = buffer;
98-
Utf8GraphQLParser.ConvertToBytes(value, ref stringBytes);
98+
Span<byte> utf8Bytes = stackalloc byte[_utf8.GetByteCount(value)];
99+
_utf8.GetBytes(value, utf8Bytes);
99100

100-
buffer = buffer.Slice(stringBytes.Length);
101+
var bytesWritten = WriteEscapedBytes(utf8Bytes, buffer);
102+
103+
buffer = buffer[bytesWritten..];
101104
buffer[0] = _partSeparator;
102-
written = stringBytes.Length + 1;
105+
written = bytesWritten + 1;
103106
return true;
104107
}
105108

@@ -125,7 +128,8 @@ protected static bool TryFormatIdPart(Span<byte> buffer, Guid value, out int wri
125128
{
126129
if (compress)
127130
{
128-
if (buffer.Length < 17)
131+
const int requiredCapacity = 16 * 2 + 1; // * 2 to allow for escaping.
132+
if (buffer.Length < requiredCapacity)
129133
{
130134
written = 0;
131135
return false;
@@ -135,16 +139,17 @@ protected static bool TryFormatIdPart(Span<byte> buffer, Guid value, out int wri
135139
#pragma warning disable CS9191
136140
MemoryMarshal.TryWrite(span, ref value);
137141
#pragma warning restore CS9191
138-
span.CopyTo(buffer);
139-
buffer = buffer.Slice(16);
142+
var bytesWritten = WriteEscapedBytes(span, buffer);
143+
144+
buffer = buffer[bytesWritten..];
140145
buffer[0] = _partSeparator;
141-
written = 17;
146+
written = bytesWritten + 1;
142147
return true;
143148
}
144149

145150
if (Utf8Formatter.TryFormat(value, buffer, out written, format: 'N'))
146151
{
147-
buffer = buffer.Slice(written);
152+
buffer = buffer[written..];
148153
if (buffer.Length < 1)
149154
{
150155
return false;
@@ -344,8 +349,9 @@ protected static unsafe bool TryParseIdPart(
344349
[NotNullWhen(true)] out string? value,
345350
out int consumed)
346351
{
347-
var index = buffer.IndexOf(_partSeparator);
348-
var valueSpan = index == -1 ? buffer : buffer.Slice(0, index);
352+
var index = IndexOfPartSeparator(buffer);
353+
var valueSpan = index == -1 ? buffer : buffer[..index];
354+
valueSpan = Unescape(valueSpan);
349355
fixed (byte* b = valueSpan)
350356
{
351357
value = _utf8.GetString(b, valueSpan.Length);
@@ -379,11 +385,13 @@ protected static bool TryParseIdPart(
379385
out int consumed,
380386
bool compress = true)
381387
{
382-
var index = buffer.IndexOf(_partSeparator);
383-
var valueSpan = index == -1 ? buffer : buffer.Slice(0, index);
388+
var index = IndexOfPartSeparator(buffer);
389+
var valueSpan = index == -1 ? buffer : buffer[..index];
384390

385391
if (compress)
386392
{
393+
valueSpan = Unescape(valueSpan);
394+
387395
if (valueSpan.Length != 16)
388396
{
389397
value = default;
@@ -396,7 +404,7 @@ protected static bool TryParseIdPart(
396404
return true;
397405
}
398406

399-
if (Utf8Parser.TryParse(valueSpan, out Guid parsedValue, out _))
407+
if (Utf8Parser.TryParse(valueSpan, out Guid parsedValue, out _, standardFormat: 'N'))
400408
{
401409
value = parsedValue;
402410
consumed = index + 1;
@@ -547,4 +555,82 @@ protected static bool TryParseIdPart(
547555
consumed = 0;
548556
return false;
549557
}
558+
559+
/// <summary>
560+
/// Writes the given unescaped bytes with the part separator (<c>:</c>) escaped, into the given
561+
/// span.
562+
/// </summary>
563+
/// <param name="unescapedBytes">The unescaped bytes to write as escaped.</param>
564+
/// <param name="escapedBytes">The span into which the escaped bytes should be written.</param>
565+
/// <returns>The number of bytes written.</returns>
566+
private static int WriteEscapedBytes(ReadOnlySpan<byte> unescapedBytes, Span<byte> escapedBytes)
567+
{
568+
var index = 0;
569+
570+
foreach (var b in unescapedBytes)
571+
{
572+
if (b == _partSeparator)
573+
{
574+
escapedBytes[index++] = _escape;
575+
}
576+
577+
escapedBytes[index++] = b;
578+
}
579+
580+
return index;
581+
}
582+
583+
/// <summary>
584+
/// Unescapes part separators (<c>:</c>) in the given span of bytes.
585+
/// </summary>
586+
/// <param name="escapedBytes">A span with the bytes to be unescaped.</param>
587+
/// <returns>A span with the unescaped bytes.</returns>
588+
private static ReadOnlySpan<byte> Unescape(ReadOnlySpan<byte> escapedBytes)
589+
{
590+
Span<byte> unescapedBytes = new byte[escapedBytes.Length];
591+
592+
var index = 0;
593+
var skipNext = false;
594+
595+
for (var i = 0; i < escapedBytes.Length; i++)
596+
{
597+
if (skipNext)
598+
{
599+
skipNext = false;
600+
continue;
601+
}
602+
603+
if (escapedBytes[i] == _escape
604+
&& i + 1 < escapedBytes.Length
605+
&& escapedBytes[i + 1] == _partSeparator)
606+
{
607+
unescapedBytes[index++] = _partSeparator;
608+
skipNext = true;
609+
}
610+
else
611+
{
612+
unescapedBytes[index++] = escapedBytes[i];
613+
}
614+
}
615+
616+
return unescapedBytes[..index];
617+
}
618+
619+
/// <summary>
620+
/// Finds the index of the first non-escaped part separator (<c>:</c>) in the given buffer.
621+
/// </summary>
622+
/// <param name="buffer">The buffer to search.</param>
623+
/// <returns>The index of the non-escaped part separator.</returns>
624+
private static int IndexOfPartSeparator(ReadOnlySpan<byte> buffer)
625+
{
626+
for (var i = 0; i < buffer.Length; i++)
627+
{
628+
if (buffer[i] == _partSeparator && (i == 0 || buffer[i - 1] != _escape))
629+
{
630+
return i;
631+
}
632+
}
633+
634+
return -1;
635+
}
550636
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
namespace HotChocolate.Types;
2+
3+
public class RequestMiddlewareTests
4+
{
5+
[Fact]
6+
public async Task GenerateSource_RequestMiddleware_MatchesSnapshot()
7+
{
8+
await TestHelper.GetGeneratedSourceSnapshot(
9+
"""
10+
#nullable enable
11+
using System.Threading.Tasks;
12+
using HotChocolate;
13+
using HotChocolate.Execution;
14+
using Microsoft.AspNetCore.Builder;
15+
using Microsoft.Extensions.DependencyInjection;
16+
17+
public class Program
18+
{
19+
public static void Main(string[] args)
20+
{
21+
var builder = WebApplication.CreateBuilder(args);
22+
builder.Services
23+
.AddGraphQLServer()
24+
.UseRequest<SomeRequestMiddleware>();
25+
}
26+
}
27+
28+
public class SomeRequestMiddleware(
29+
RequestDelegate next,
30+
#pragma warning disable CS9113
31+
[SchemaService] Service1 service1,
32+
[SchemaService] Service2? service2)
33+
#pragma warning restore CS9113
34+
{
35+
public async ValueTask InvokeAsync(
36+
IRequestContext context,
37+
#pragma warning disable CS9113
38+
Service1 service1,
39+
Service2? service2)
40+
#pragma warning restore CS9113
41+
{
42+
await next(context);
43+
}
44+
}
45+
46+
public class Service1;
47+
public class Service2;
48+
""").MatchMarkdownAsync();
49+
}
50+
}

src/HotChocolate/Core/test/Types.Analyzers.Tests/TestHelper.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88
using GreenDonut;
99
using GreenDonut.Data;
1010
using HotChocolate.Data.Filters;
11+
using HotChocolate.Execution;
12+
using HotChocolate.Execution.Configuration;
1113
using HotChocolate.Types.Analyzers;
1214
using HotChocolate.Types.Pagination;
15+
using Microsoft.AspNetCore.Builder;
1316
using Microsoft.CodeAnalysis;
1417
using Microsoft.CodeAnalysis.CSharp;
18+
using Microsoft.Extensions.DependencyInjection;
1519

1620
namespace HotChocolate.Types;
1721

@@ -33,6 +37,12 @@ public static Snapshot GetGeneratedSourceSnapshot(string[] sourceTexts, string?
3337
#elif NET9_0
3438
.. Net90.References.All,
3539
#endif
40+
// HotChocolate.Execution
41+
MetadataReference.CreateFromFile(typeof(RequestDelegate).Assembly.Location),
42+
43+
// HotChocolate.Execution.Abstractions
44+
MetadataReference.CreateFromFile(typeof(IRequestExecutorBuilder).Assembly.Location),
45+
3646
// HotChocolate.Types
3747
MetadataReference.CreateFromFile(typeof(ObjectTypeAttribute).Assembly.Location),
3848
MetadataReference.CreateFromFile(typeof(Connection).Assembly.Location),
@@ -41,6 +51,10 @@ public static Snapshot GetGeneratedSourceSnapshot(string[] sourceTexts, string?
4151
// HotChocolate.Abstractions
4252
MetadataReference.CreateFromFile(typeof(ParentAttribute).Assembly.Location),
4353

54+
// HotChocolate.AspNetCore
55+
MetadataReference.CreateFromFile(
56+
typeof(HotChocolateAspNetCoreServiceCollectionExtensions).Assembly.Location),
57+
4458
// GreenDonut
4559
MetadataReference.CreateFromFile(typeof(DataLoaderBase<,>).Assembly.Location),
4660
MetadataReference.CreateFromFile(typeof(IDataLoader).Assembly.Location),
@@ -50,7 +64,13 @@ public static Snapshot GetGeneratedSourceSnapshot(string[] sourceTexts, string?
5064
MetadataReference.CreateFromFile(typeof(IPredicateBuilder).Assembly.Location),
5165

5266
// HotChocolate.Data
53-
MetadataReference.CreateFromFile(typeof(IFilterContext).Assembly.Location)
67+
MetadataReference.CreateFromFile(typeof(IFilterContext).Assembly.Location),
68+
69+
// Microsoft.AspNetCore
70+
MetadataReference.CreateFromFile(typeof(WebApplication).Assembly.Location),
71+
72+
// Microsoft.Extensions.DependencyInjection.Abstractions
73+
MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location)
5474
];
5575

5676
// Create a Roslyn compilation for the syntax tree.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# GenerateSource_RequestMiddleware_MatchesSnapshot
2+
3+
```csharp
4+
// <auto-generated/>
5+
6+
#nullable enable
7+
#pragma warning disable
8+
9+
using System;
10+
using System.Runtime.CompilerServices;
11+
using HotChocolate;
12+
using HotChocolate.Types;
13+
using HotChocolate.Execution.Configuration;
14+
using Microsoft.Extensions.DependencyInjection;
15+
16+
namespace HotChocolate.Execution.Generated
17+
{
18+
public static class TestsTypesMiddlewareFactoriesHASH
19+
{
20+
// global::SomeRequestMiddleware
21+
private static global::HotChocolate.Execution.RequestCoreMiddleware CreateMiddleware0()
22+
=> (core, next) =>
23+
{
24+
var cp1 = core.SchemaServices.GetRequiredService<global::Service1>();
25+
var cp2 = core.SchemaServices.GetService<global::Service2>();
26+
var middleware = new global::SomeRequestMiddleware(next, cp1, cp2);
27+
return async context =>
28+
{
29+
var ip1 = context.Services.GetRequiredService<global::Service1>();
30+
var ip2 = context.Services.GetService<global::Service2>();
31+
await middleware.InvokeAsync(context, ip1, ip2).ConfigureAwait(false);
32+
};
33+
};
34+
35+
[InterceptsLocation("", 15, 14)]
36+
public static global::HotChocolate.Execution.Configuration.IRequestExecutorBuilder UseRequestGen0<TMiddleware>(
37+
this HotChocolate.Execution.Configuration.IRequestExecutorBuilder builder) where TMiddleware : class
38+
=> builder.UseRequest(CreateMiddleware0());
39+
}
40+
}
41+
42+
#pragma warning disable CS9113 // Parameter is unread.
43+
namespace System.Runtime.CompilerServices
44+
{
45+
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
46+
file sealed class InterceptsLocationAttribute(string filePath, int line, int column) : Attribute;
47+
}
48+
#pragma warning restore CS9113 // Parameter is unread.
49+
```

0 commit comments

Comments
 (0)