Skip to content

Commit 260ea40

Browse files
committed
feat(nestjs): add middleware support (#21)
1 parent 3ceb6cf commit 260ea40

File tree

10 files changed

+191
-103
lines changed

10 files changed

+191
-103
lines changed

fixtures/nestjs/src/app.module.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ import { AppService } from './app.service'
99
})
1010
export class AppModule implements NestModule {
1111
configure(consumer: MiddlewareConsumer): void {
12-
consumer.apply(AppMiddleware).forRoutes('*')
12+
consumer.apply(AppMiddleware).forRoutes('*path')
1313
}
1414
}

packages/adapter/adapter-nestjs/src/middleware/middleware-customer.ts

+4-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ import type { Awaitable } from '@unioc/shared'
55
import { MiddlewareConfigProxyBuilder } from './middleware-config-proxy'
66

77
export type MiddlewareType = Type | ((...args: any[]) => Awaitable<unknown>)
8-
8+
export type ExcludeMiddlewareRoute = string | RouteInfo
9+
export type IncludeMiddlewareRoute = ExcludeMiddlewareRoute | Type<any>
910
export interface IMiddlewareApplyData {
10-
includedRoutes: (string | Type<any> | RouteInfo)[]
11-
excludedRoutes: (string | RouteInfo)[]
11+
includedRoutes: IncludeMiddlewareRoute[]
12+
excludedRoutes: ExcludeMiddlewareRoute[]
1213
}
1314

1415
export class MiddlewareCustomerBuilder implements MiddlewareConsumer {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import type { NestModule } from '@nestjs/common'
2+
import type { IClassWrapper } from '@unioc/core'
3+
import type { NestJSMethodOperator, NestJSRestfulScanner } from '../restful'
4+
import type { IMiddlewareApplyData, MiddlewareType } from './middleware-customer'
5+
import { MiddlewareCustomerBuilder } from './middleware-customer'
6+
7+
export class MiddlewareResolver {
8+
constructor(
9+
private readonly _nestJSScanner: NestJSRestfulScanner,
10+
private readonly _modules: IClassWrapper[],
11+
) {}
12+
13+
getNestJSRestfulScanner(): NestJSRestfulScanner {
14+
return this._nestJSScanner
15+
}
16+
17+
private _middlewareCustomerMap: Map<MiddlewareType, IMiddlewareApplyData> = new Map<MiddlewareType, IMiddlewareApplyData>()
18+
isResolved: boolean = false
19+
20+
async resolveAll(): Promise<Map<MiddlewareType, IMiddlewareApplyData>> {
21+
if (this.isResolved)
22+
return this._middlewareCustomerMap
23+
24+
for (const moduleWrapper of this._modules) {
25+
const resolveInstance: NestModule = await moduleWrapper.resolve()
26+
if (!('configure' in resolveInstance) || typeof resolveInstance.configure !== 'function')
27+
continue
28+
29+
const middlewareCustomer = new MiddlewareCustomerBuilder(this.getNestJSRestfulScanner().getPluginContext())
30+
await resolveInstance.configure(middlewareCustomer)
31+
this._middlewareCustomerMap = middlewareCustomer.merge(this._middlewareCustomerMap)
32+
}
33+
this.isResolved = true
34+
return this._middlewareCustomerMap
35+
}
36+
37+
isResolvedOperatorToMiddleware: boolean = false
38+
private _operatorToMiddlewareMap: Map<NestJSMethodOperator, Set<MiddlewareType>> = new Map<NestJSMethodOperator, Set<MiddlewareType>>()
39+
40+
async resolveOperatorToMiddleware(): Promise<Map<NestJSMethodOperator, Set<MiddlewareType>>> {
41+
if (this.isResolvedOperatorToMiddleware)
42+
return this._operatorToMiddlewareMap
43+
44+
await this.resolveAll()
45+
for (const methodOperator of this.getNestJSRestfulScanner().findAllMethodOperator())
46+
this._operatorToMiddlewareMap.set(methodOperator, await methodOperator.getMatchedMiddlewares())
47+
48+
this.isResolvedOperatorToMiddleware = true
49+
return this._operatorToMiddlewareMap
50+
}
51+
}

packages/adapter/adapter-nestjs/src/restful/controller-wrapper.ts

+3-27
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import type { IClass } from '@unioc/shared'
2-
import type { IHttpMethod, IHttpMethodMetadata, IRestControllerOperator, IRestControllerWrapper, RestController } from '@unioc/web'
2+
import type { IHttpMethodMetadata, IRestControllerOperator, IRestControllerWrapper, RestController } from '@unioc/web'
33
import type { NestJSRestfulScanner } from './restful-scanner'
4-
import { RequestMethod } from '@nestjs/common'
54
import { METHOD_METADATA, PATH_METADATA } from '@nestjs/common/constants.js'
65
import { isReflectable } from '@unioc/meta'
7-
import { HttpMethod, RestControllerWrapper } from '@unioc/web'
6+
import { RestControllerWrapper } from '@unioc/web'
87
import { NestJSRestControllerOperator } from './controller-operator'
98

109
export class NestJSControllerWrapper extends RestControllerWrapper implements IRestControllerWrapper {
@@ -24,29 +23,6 @@ export class NestJSControllerWrapper extends RestControllerWrapper implements IR
2423
]
2524
}
2625

27-
private _toHttpMethod(method: RequestMethod): IHttpMethod {
28-
switch (method) {
29-
case RequestMethod.GET:
30-
return HttpMethod.GET
31-
case RequestMethod.POST:
32-
return HttpMethod.POST
33-
case RequestMethod.PUT:
34-
return HttpMethod.PUT
35-
case RequestMethod.DELETE:
36-
return HttpMethod.DELETE
37-
case RequestMethod.PATCH:
38-
return HttpMethod.PATCH
39-
case RequestMethod.OPTIONS:
40-
return HttpMethod.OPTIONS
41-
case RequestMethod.HEAD:
42-
return HttpMethod.HEAD
43-
case RequestMethod.ALL:
44-
return HttpMethod.ALL
45-
default:
46-
return HttpMethod.GET
47-
}
48-
}
49-
5026
override getFullMethodOptions(): IHttpMethodMetadata[] {
5127
const methodPropertyKeys: PropertyKey[] = Reflect.ownKeys(this.getClassWrapper().getTarget().prototype)
5228
.filter(propertyKey => propertyKey !== 'constructor')
@@ -62,7 +38,7 @@ export class NestJSControllerWrapper extends RestControllerWrapper implements IR
6238
continue
6339

6440
methodOptions.push({
65-
httpMethod: this._toHttpMethod(methodMetadata),
41+
httpMethod: this.getRestfulScanner().toHttpMethod(methodMetadata),
6642
path: pathMetadata,
6743
propertyKey,
6844
})
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,36 @@
11
import type { IRestMethodOperator } from '@unioc/web'
2+
import type { MiddlewareType } from '../middleware/middleware-customer'
23
import type { NestJSMethodWrapper } from './method-wrapper'
34
import { RestMethodOperator } from '@unioc/web'
45

56
export class NestJSMethodOperator extends RestMethodOperator implements IRestMethodOperator {
67
getMethodWrapper(): NestJSMethodWrapper {
78
return super.getMethodWrapper() as NestJSMethodWrapper
89
}
10+
11+
async getMatchedMiddlewares(): Promise<Set<MiddlewareType>> {
12+
const matchedMiddlewares: Set<MiddlewareType> = new Set<MiddlewareType>()
13+
const restfulScanner = this.getMethodWrapper()
14+
.getControllerOperator()
15+
.getControllerWrapper()
16+
.getRestfulScanner()
17+
18+
const resolvedMiddlewares = await restfulScanner.getMiddlewareResolver().resolveAll()
19+
20+
for (const [middleware, data] of resolvedMiddlewares.entries()) {
21+
for (const route of data.includedRoutes) {
22+
if (restfulScanner.isMatchRoute(this, this.getFullPath(), route)) {
23+
matchedMiddlewares.add(middleware)
24+
}
25+
}
26+
27+
for (const route of data.excludedRoutes) {
28+
if (restfulScanner.isMatchRoute(this, this.getFullPath(), route)) {
29+
matchedMiddlewares.delete(middleware)
30+
}
31+
}
32+
}
33+
34+
return matchedMiddlewares
35+
}
936
}

packages/adapter/adapter-nestjs/src/restful/restful-handler.ts

+28-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import type { NestMiddleware } from '@nestjs/common'
2+
import type { IClass } from '@unioc/shared'
13
import type { IMethodExecuteOptions, IRestfulConnect, IRestMethodOperator } from '@unioc/web'
24
import type { INestJSMethodParamMetadata, NestJSMethodWrapper } from './method-wrapper'
35
import type { NestJSRestfulScanner } from './restful-scanner'
46
import { UnauthorizedException } from '@nestjs/common'
7+
import { isClass } from '@unioc/shared'
58
import { ExecutionContextBuilder } from '../execution-context-builder'
69
import { EndingHandler } from './ending-handler'
710
import { NestJSMethodOperator } from './method-operator'
@@ -153,6 +156,28 @@ export class NestJSRestfulHandler implements IRestfulConnect.Handler {
153156
throw new UnauthorizedException()
154157
}
155158

159+
async initialize(ctx: IRestfulConnect.InitializeContext): Promise<void> {
160+
const map = await this.getRestfulScanner()
161+
.getMiddlewareResolver()
162+
.resolveOperatorToMiddleware()
163+
164+
for (const [methodOperator, matchedMiddlewares] of map) {
165+
for (const middleware of matchedMiddlewares) {
166+
if (isClass(middleware)) {
167+
const middlewareInstance = await this.getRestfulScanner()
168+
.getPluginContext()
169+
.createClass<IClass<NestMiddleware>>(middleware)
170+
.resolve()
171+
172+
ctx.addMiddleware(methodOperator.getFullPath(), middlewareInstance.use.bind(middlewareInstance))
173+
}
174+
else {
175+
ctx.addMiddleware(methodOperator.getFullPath(), middleware)
176+
}
177+
}
178+
}
179+
}
180+
156181
async handleConnectRequest(methodOperator: IRestMethodOperator, ctx: IRestfulConnect.WebContext): Promise<void> {
157182
if (!(methodOperator instanceof NestJSMethodOperator))
158183
throw new Error('Method operator is not a NestJSMethodOperator')
@@ -166,13 +191,11 @@ export class NestJSRestfulHandler implements IRestfulConnect.Handler {
166191
} as const
167192

168193
try {
169-
// 1. Execute middlewares
170-
await this.getRestfulScanner().executeMiddlewares(methodWrapper, ctx)
171-
// 2. Build params with pipes
194+
// 1. Build params with pipes
172195
methodArguments = await this.buildParams(methodWrapper, ctx)
173-
// 3. Execute guards
196+
// 2. Execute guards
174197
await this.executeGuards(methodWrapper, methodArguments)
175-
// 4. Execute the controller method
198+
// 3. Execute the controller method
176199
const result = await methodWrapper.execute(methodArguments, extraOptions)
177200
// 4. TODO: Execute the interceptors
178201
// 5. Send the ending response

packages/adapter/adapter-nestjs/src/restful/restful-scanner.ts

+63-66
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
import type { ArgumentMetadata, ArgumentsHost, CanActivate, ExceptionFilter, ExecutionContext, NestMiddleware, NestModule, PipeTransform } from '@nestjs/common'
1+
import type { ArgumentMetadata, ArgumentsHost, CanActivate, ExceptionFilter, ExecutionContext, PipeTransform } from '@nestjs/common'
22
import type { IArgument, IClassWrapper, IPluginContext } from '@unioc/core'
33
import type { IClass } from '@unioc/shared'
4-
import type { IHttpParam, IRestfulConnect, IRestfulScanner } from '@unioc/web'
5-
import type { IMiddlewareApplyData, MiddlewareType } from '../middleware/middleware-customer'
4+
import type { IHttpMethod, IHttpParam, IRestfulConnect, IRestfulScanner } from '@unioc/web'
5+
import type { ExcludeMiddlewareRoute, IncludeMiddlewareRoute } from '../middleware/middleware-customer'
66
import type { NestJSMethodOperator } from './method-operator'
7-
import type { NestJSMethodWrapper } from './method-wrapper'
7+
import { RequestMethod } from '@nestjs/common'
88
import { CONTROLLER_WATERMARK, FILTER_CATCH_EXCEPTIONS } from '@nestjs/common/constants.js'
9-
import { isClass } from '@unioc/shared'
10-
import { RestfulScanner } from '@unioc/web'
11-
import { MiddlewareCustomerBuilder } from '../middleware/middleware-customer'
9+
import { HttpMethod, RestfulScanner } from '@unioc/web'
10+
import { match } from 'path-to-regexp'
11+
import { MiddlewareResolver } from '../middleware/middleware-resolver'
1212
import { NestJSControllerWrapper } from './controller-wrapper'
1313
import { NestJSRestfulHandler } from './restful-handler'
1414

@@ -142,71 +142,68 @@ export class NestJSRestfulScanner extends RestfulScanner implements IRestfulScan
142142
return true
143143
}
144144

145-
private _middlewareCustomerMap: Map<MiddlewareType, IMiddlewareApplyData> = new Map<MiddlewareType, IMiddlewareApplyData>()
145+
private _middlewareResolver: MiddlewareResolver = new MiddlewareResolver(this, this._modules)
146146

147-
async resolveMiddlewares(): Promise<void> {
148-
for (const moduleWrapper of this._modules) {
149-
const resolveInstance: NestModule = await moduleWrapper.resolve()
150-
if (!('configure' in resolveInstance) || typeof resolveInstance.configure !== 'function')
151-
continue
152-
153-
const middlewareCustomer = new MiddlewareCustomerBuilder(this.getPluginContext())
154-
await resolveInstance.configure(middlewareCustomer)
155-
this._middlewareCustomerMap = middlewareCustomer.merge(this._middlewareCustomerMap)
156-
}
147+
getMiddlewareResolver(): MiddlewareResolver {
148+
return this._middlewareResolver
157149
}
158150

159-
async executeMiddlewares(methodWrapper: NestJSMethodWrapper, ctx: IRestfulConnect.WebContext): Promise<void> {
160-
const fullPaths = methodWrapper.findAll().map(o => o.getFullPath())
161-
const classTarget = methodWrapper.getControllerOperator().getControllerWrapper().getClassWrapper().getTarget()
162-
const matchedMiddlewares: MiddlewareType[] = []
163-
164-
for (const [middleware, data] of this._middlewareCustomerMap.entries()) {
165-
for (const route of data.includedRoutes) {
166-
if (isClass(route)) {
167-
if (classTarget === route) {
168-
matchedMiddlewares.push(middleware)
169-
}
170-
}
171-
else if (typeof route === 'string') {
172-
if (fullPaths.includes(route)) {
173-
matchedMiddlewares.push(middleware)
174-
}
175-
}
176-
else {
177-
if (fullPaths.includes(route.path)) {
178-
matchedMiddlewares.push(middleware)
179-
}
180-
}
181-
}
151+
override async resolveAll(): Promise<this> {
152+
await super.resolveAll()
153+
await this.getMiddlewareResolver().resolveAll()
154+
return this
155+
}
182156

183-
for (const route of data.excludedRoutes) {
184-
if (isClass(route)) {
185-
if (classTarget === route) {
186-
matchedMiddlewares.splice(matchedMiddlewares.indexOf(middleware), 1)
187-
}
188-
}
189-
else if (typeof route === 'string') {
190-
if (fullPaths.includes(route)) {
191-
matchedMiddlewares.splice(matchedMiddlewares.indexOf(middleware), 1)
192-
}
193-
}
194-
else {
195-
if (fullPaths.includes(route.path)) {
196-
matchedMiddlewares.splice(matchedMiddlewares.indexOf(middleware), 1)
197-
}
198-
}
199-
}
157+
/**
158+
* ### 🔄 Convert the `RequestMethod` to unioc `IHttpMethod`.
159+
*
160+
* @param method - The `RequestMethod` to convert.
161+
* @returns The converted `IHttpMethod`.
162+
*/
163+
toHttpMethod(method: RequestMethod): IHttpMethod {
164+
switch (method) {
165+
case RequestMethod.GET:
166+
return HttpMethod.GET
167+
case RequestMethod.POST:
168+
return HttpMethod.POST
169+
case RequestMethod.PUT:
170+
return HttpMethod.PUT
171+
case RequestMethod.DELETE:
172+
return HttpMethod.DELETE
173+
case RequestMethod.PATCH:
174+
return HttpMethod.PATCH
175+
case RequestMethod.OPTIONS:
176+
return HttpMethod.OPTIONS
177+
case RequestMethod.HEAD:
178+
return HttpMethod.HEAD
179+
case RequestMethod.ALL:
180+
return HttpMethod.ALL
181+
default:
182+
return HttpMethod.GET
200183
}
184+
}
201185

202-
for (const middleware of matchedMiddlewares) {
203-
if (isClass(middleware)) {
204-
const middlewareInstance = await this.getPluginContext().createClass<IClass<NestMiddleware>>(middleware).resolve()
205-
await middlewareInstance.use(ctx.request, ctx.response, ctx.next)
206-
}
207-
else {
208-
await middleware(ctx.request, ctx.response, ctx.next)
209-
}
186+
isMatch(url: string, fullPath: string): boolean {
187+
return match(fullPath)(url) !== false
188+
}
189+
190+
isMatchRoute(methodOperator: NestJSMethodOperator, url: string, route: IncludeMiddlewareRoute | ExcludeMiddlewareRoute): boolean {
191+
const methodWrapper = methodOperator.getMethodWrapper()
192+
const classTarget = methodWrapper
193+
.getControllerOperator()
194+
.getControllerWrapper()
195+
.getClassWrapper()
196+
.getTarget()
197+
198+
switch (typeof route) {
199+
case 'string':
200+
return this.isMatch(url, route)
201+
case 'object':
202+
return this.isMatch(url, route.path) && methodOperator.getHttpMethod() === this.toHttpMethod(route.method)
203+
case 'function':
204+
return route === classTarget
205+
default:
206+
return false
210207
}
211208
}
212209

packages/server/web-express/src/adapter.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { AsyncCallback } from '@unioc/shared'
1+
import type { AsyncCallback, Awaitable } from '@unioc/shared'
22
import type { IRestfulConnect, IRestMethodOperator } from '@unioc/web'
33
import type express from 'express'
44

@@ -11,6 +11,12 @@ export class ExpressAdapter implements IRestfulConnect.Adapter {
1111

1212
readonly adapterType = 'connect'
1313

14+
initialize(): Awaitable<IRestfulConnect.InitializeContext> {
15+
return {
16+
addMiddleware: this.getExpressApp().use.bind(this.getExpressApp()) as unknown as import('connect').Server['use'],
17+
}
18+
}
19+
1420
async handleConnectRequest(methodOperator: IRestMethodOperator, callback: AsyncCallback<unknown, [IRestfulConnect.WebContext]>): Promise<void> {
1521
const app = this.getExpressApp()
1622
const path = methodOperator.getFullPath()

packages/server/web/src/plugins/restful.ts

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ export function RestfulPlugin(): IPlugin {
2121
await adapter.handleNodeRequest(operator, async ctx => await nodeHandler?.handleNodeRequest(operator, ctx))
2222
break
2323
case 'connect':
24+
await connectHandler?.initialize?.(await adapter.initialize())
2425
await adapter.handleConnectRequest(operator, async ctx => await connectHandler?.handleConnectRequest(operator, ctx))
2526
break
2627
case 'context':

0 commit comments

Comments
 (0)