Skip to content

Fixed issues with the request middleware source generator #8274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,14 @@ private void WriteCtorServiceResolution(List<RequestMiddlewareParameterInfo> par

case RequestMiddlewareParameterKind.SchemaService when !parameter.IsNullable:
_writer.WriteIndentedLine(
"var cp{0} = core.SchemaServices.GetRequiredService<global::{1}>();",
"var cp{0} = core.SchemaServices.GetRequiredService<{1}>();",
i,
parameter.TypeName);
break;

case RequestMiddlewareParameterKind.SchemaService when parameter.IsNullable:
_writer.WriteIndentedLine(
"var cp{0} = core.SchemaServices.GetService<global::{1}>();",
"var cp{0} = core.SchemaServices.GetService<{1}>();",
i,
parameter.TypeName);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,11 @@ ctor is not
kind = RequestMiddlewareParameterKind.Service;
}

ctorParameters.Add(new RequestMiddlewareParameterInfo(kind, parameterTypeName));
ctorParameters.Add(
new RequestMiddlewareParameterInfo(
kind,
parameterTypeName,
isNullable: !parameter.IsNonNullable()));
}

foreach (var parameter in invokeMethod.Parameters)
Expand All @@ -113,7 +117,11 @@ ctor is not
kind = RequestMiddlewareParameterKind.Service;
}

invokeParameters.Add(new RequestMiddlewareParameterInfo(kind, parameterTypeName));
invokeParameters.Add(
new RequestMiddlewareParameterInfo(
kind,
parameterTypeName,
isNullable: !parameter.IsNonNullable()));
}

syntaxInfo = new RequestMiddlewareInfo(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
namespace HotChocolate.Types;

public class RequestMiddlewareTests
{
[Fact]
public async Task GenerateSource_RequestMiddleware_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
#nullable enable
using System.Threading.Tasks;
using HotChocolate;
using HotChocolate.Execution;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;

public class Program
{
public static void Main(string[] args)
{
var builder = WebApplication.CreateBuilder(args);
builder.Services
.AddGraphQLServer()
.UseRequest<SomeRequestMiddleware>();
}
}

public class SomeRequestMiddleware(
RequestDelegate next,
#pragma warning disable CS9113
[SchemaService] Service1 service1,
[SchemaService] Service2? service2)
#pragma warning restore CS9113
{
public async ValueTask InvokeAsync(
IRequestContext context,
#pragma warning disable CS9113
Service1 service1,
Service2? service2)
#pragma warning restore CS9113
{
await next(context);
}
}

public class Service1;
public class Service2;
""").MatchMarkdownAsync();
}
}
22 changes: 21 additions & 1 deletion src/HotChocolate/Core/test/Types.Analyzers.Tests/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
using GreenDonut;
using GreenDonut.Data;
using HotChocolate.Data.Filters;
using HotChocolate.Execution;
using HotChocolate.Execution.Configuration;
using HotChocolate.Types.Analyzers;
using HotChocolate.Types.Pagination;
using Microsoft.AspNetCore.Builder;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.Extensions.DependencyInjection;

namespace HotChocolate.Types;

Expand All @@ -33,6 +37,12 @@ public static Snapshot GetGeneratedSourceSnapshot(string[] sourceTexts, string?
#elif NET9_0
.. Net90.References.All,
#endif
// HotChocolate.Execution
MetadataReference.CreateFromFile(typeof(RequestDelegate).Assembly.Location),

// HotChocolate.Execution.Abstractions
MetadataReference.CreateFromFile(typeof(IRequestExecutorBuilder).Assembly.Location),

// HotChocolate.Types
MetadataReference.CreateFromFile(typeof(ObjectTypeAttribute).Assembly.Location),
MetadataReference.CreateFromFile(typeof(Connection).Assembly.Location),
Expand All @@ -41,6 +51,10 @@ public static Snapshot GetGeneratedSourceSnapshot(string[] sourceTexts, string?
// HotChocolate.Abstractions
MetadataReference.CreateFromFile(typeof(ParentAttribute).Assembly.Location),

// HotChocolate.AspNetCore
MetadataReference.CreateFromFile(
typeof(HotChocolateAspNetCoreServiceCollectionExtensions).Assembly.Location),

// GreenDonut
MetadataReference.CreateFromFile(typeof(DataLoaderBase<,>).Assembly.Location),
MetadataReference.CreateFromFile(typeof(IDataLoader).Assembly.Location),
Expand All @@ -50,7 +64,13 @@ public static Snapshot GetGeneratedSourceSnapshot(string[] sourceTexts, string?
MetadataReference.CreateFromFile(typeof(IPredicateBuilder).Assembly.Location),

// HotChocolate.Data
MetadataReference.CreateFromFile(typeof(IFilterContext).Assembly.Location)
MetadataReference.CreateFromFile(typeof(IFilterContext).Assembly.Location),

// Microsoft.AspNetCore
MetadataReference.CreateFromFile(typeof(WebApplication).Assembly.Location),

// Microsoft.Extensions.DependencyInjection.Abstractions
MetadataReference.CreateFromFile(typeof(IServiceCollection).Assembly.Location)
];

// Create a Roslyn compilation for the syntax tree.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# GenerateSource_RequestMiddleware_MatchesSnapshot

```csharp
// <auto-generated/>

#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using HotChocolate;
using HotChocolate.Types;
using HotChocolate.Execution.Configuration;
using Microsoft.Extensions.DependencyInjection;

namespace HotChocolate.Execution.Generated
{
public static class TestsTypesMiddlewareFactoriesHASH
{
// global::SomeRequestMiddleware
private static global::HotChocolate.Execution.RequestCoreMiddleware CreateMiddleware0()
=> (core, next) =>
{
var cp1 = core.SchemaServices.GetRequiredService<global::Service1>();
var cp2 = core.SchemaServices.GetService<global::Service2>();
var middleware = new global::SomeRequestMiddleware(next, cp1, cp2);
return async context =>
{
var ip1 = context.Services.GetRequiredService<global::Service1>();
var ip2 = context.Services.GetService<global::Service2>();
await middleware.InvokeAsync(context, ip1, ip2).ConfigureAwait(false);
};
};

[InterceptsLocation("", 15, 14)]
public static global::HotChocolate.Execution.Configuration.IRequestExecutorBuilder UseRequestGen0<TMiddleware>(
this HotChocolate.Execution.Configuration.IRequestExecutorBuilder builder) where TMiddleware : class
=> builder.UseRequest(CreateMiddleware0());
}
}

#pragma warning disable CS9113 // Parameter is unread.
namespace System.Runtime.CompilerServices
{
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
file sealed class InterceptsLocationAttribute(string filePath, int line, int column) : Attribute;
}
#pragma warning restore CS9113 // Parameter is unread.
```
Loading