Skip to content

Commit 53c35bd

Browse files
committed
feat(nestjs): add middleware support (#21)
1 parent 014df26 commit 53c35bd

File tree

6 files changed

+177
-95
lines changed

6 files changed

+177
-95
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,55 @@
1-
import type { MiddlewareConfigProxy, MiddlewareConsumer, NestMiddleware, RouteInfo, Type } from '@nestjs/common/interfaces'
2-
import type { IPluginContext } from '@unioc/core'
3-
import type { MiddlewareConsumerBuilder } from './middleware-customer'
4-
import { isClass } from '@unioc/shared'
5-
6-
export interface IResolvedMiddleware {
7-
handler(...args: any[]): any
8-
instance?: NestMiddleware
9-
routes: (string | Type<any> | RouteInfo)[]
10-
}
1+
import type { MiddlewareConfigProxy, MiddlewareConsumer, RouteInfo, Type } from '@nestjs/common/interfaces'
2+
import type { MiddlewareType } from './middleware-customer'
3+
import { MiddlewareCustomerBuilder } from './middleware-customer'
114

125
export class MiddlewareConfigProxyBuilder implements MiddlewareConfigProxy {
13-
private _excludedRoutes: (string | RouteInfo)[] = []
14-
private _routes: (string | Type<any> | RouteInfo)[] = []
15-
private _middleware: (Type<any> | ((...args: any[]) => any))[] = []
16-
17-
constructor(private readonly _middlewareConsumer: MiddlewareConsumerBuilder) {}
6+
constructor(
7+
private readonly _middlewareCustomerBuilder: MiddlewareCustomerBuilder,
8+
private readonly _middlewares: MiddlewareType[],
9+
) {}
1810

19-
protected getMiddlewareConsumer(): MiddlewareConsumer {
20-
this._middlewareConsumer.add(this)
21-
return this._middlewareConsumer
11+
getMiddlewareCustomerBuilder(): MiddlewareCustomerBuilder {
12+
return this._middlewareCustomerBuilder
2213
}
2314

2415
exclude(...routes: (string | RouteInfo)[]): MiddlewareConfigProxy {
25-
this._excludedRoutes = [...this._excludedRoutes, ...routes]
16+
for (const route of routes) {
17+
for (const middleware of this._middlewares) {
18+
const data = this.getMiddlewareCustomerBuilder().getMiddlewareMap().get(middleware)
19+
if (!data) {
20+
this.getMiddlewareCustomerBuilder().getMiddlewareMap().set(middleware, {
21+
includedRoutes: [],
22+
excludedRoutes: [route],
23+
})
24+
}
25+
else {
26+
data.excludedRoutes.push(route)
27+
}
28+
}
29+
}
30+
2631
return this
2732
}
2833

2934
forRoutes(...routes: (string | Type<any> | RouteInfo)[]): MiddlewareConsumer {
30-
this._routes = [...this._routes, ...routes]
31-
return this.getMiddlewareConsumer()
32-
}
33-
34-
public getExcludedRoutes(): (string | RouteInfo)[] {
35-
return this._excludedRoutes
36-
}
37-
38-
public getRoutes(): (string | Type<any> | RouteInfo)[] {
39-
return this._routes
40-
}
41-
42-
public setMiddleware(middleware: (Type<any> | ((...args: any[]) => any))[]): void {
43-
this._middleware = middleware
44-
}
45-
46-
public async resolveMiddlewares(ctx: IPluginContext): Promise<IResolvedMiddleware[]> {
47-
const handlers: IResolvedMiddleware[] = []
48-
49-
for (const middleware of this._middleware) {
50-
if (isClass(middleware)) {
51-
const instance: NestMiddleware = await ctx.createClass(middleware).resolve()
52-
if (instance && typeof instance === 'object' && 'use' in instance && typeof instance.use === 'function') {
53-
handlers.push({
54-
handler: instance.use,
55-
instance,
56-
routes: this._routes,
35+
for (const route of routes) {
36+
for (const middleware of this._middlewares) {
37+
const data = this.getMiddlewareCustomerBuilder().getMiddlewareMap().get(middleware)
38+
if (!data) {
39+
this.getMiddlewareCustomerBuilder().getMiddlewareMap().set(middleware, {
40+
includedRoutes: [],
41+
excludedRoutes: [route],
5742
})
5843
}
59-
}
60-
else {
61-
handlers.push({
62-
handler: middleware,
63-
instance: undefined,
64-
routes: this._routes,
65-
})
44+
else {
45+
data.excludedRoutes.push(route)
46+
}
6647
}
6748
}
6849

69-
return handlers
50+
return new MiddlewareCustomerBuilder(
51+
this.getMiddlewareCustomerBuilder().getPluginContext(),
52+
this.getMiddlewareCustomerBuilder().getMiddlewareMap(),
53+
)
7054
}
7155
}
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,52 @@
11
import type { MiddlewareConsumer, Type } from '@nestjs/common'
2-
import type { MiddlewareConfigProxy } from '@nestjs/common/interfaces'
2+
import type { MiddlewareConfigProxy, RouteInfo } from '@nestjs/common/interfaces'
3+
import type { IPluginContext } from '@unioc/core'
4+
import type { Awaitable } from '@unioc/shared'
35
import { MiddlewareConfigProxyBuilder } from './middleware-config-proxy'
46

5-
export class MiddlewareConsumerBuilder implements MiddlewareConsumer {
6-
private _middlewareConfigs: MiddlewareConfigProxyBuilder[] = []
7+
export type MiddlewareType = Type | ((...args: any[]) => Awaitable<unknown>)
78

8-
apply(...middleware: (Type<any> | ((...args: any[]) => any))[]): MiddlewareConfigProxy {
9-
const config = new MiddlewareConfigProxyBuilder(this)
10-
config.setMiddleware(middleware)
11-
return config
9+
export interface IMiddlewareApplyData {
10+
includedRoutes: (string | RouteInfo)[]
11+
excludedRoutes: (string | Type<any> | RouteInfo)[]
12+
}
13+
14+
export class MiddlewareCustomerBuilder implements MiddlewareConsumer {
15+
constructor(
16+
private readonly _pluginContext: IPluginContext,
17+
private readonly _middlewareMap: Map<MiddlewareType, IMiddlewareApplyData> = new Map<MiddlewareType, IMiddlewareApplyData>(),
18+
) {}
19+
20+
getPluginContext(): IPluginContext {
21+
return this._pluginContext
22+
}
23+
24+
getMiddlewareMap(): Map<MiddlewareType, IMiddlewareApplyData> {
25+
return this._middlewareMap
1226
}
1327

14-
add(...middlewareConfigs: MiddlewareConfigProxyBuilder[]): MiddlewareConsumer {
15-
this._middlewareConfigs = [...this._middlewareConfigs, ...middlewareConfigs]
16-
return this
28+
apply(...middlewares: MiddlewareType[]): MiddlewareConfigProxy {
29+
for (const middleware of middlewares) {
30+
if (!this.getMiddlewareMap().has(middleware)) {
31+
this.getMiddlewareMap().set(middleware, {
32+
includedRoutes: [],
33+
excludedRoutes: [],
34+
})
35+
}
36+
}
37+
return new MiddlewareConfigProxyBuilder(this, middlewares)
1738
}
1839

19-
public getMiddlewareConfigs(): MiddlewareConfigProxyBuilder[] {
20-
return this._middlewareConfigs
40+
merge(preMerge: Map<MiddlewareType, IMiddlewareApplyData>): Map<MiddlewareType, IMiddlewareApplyData> {
41+
for (const [preMergeMiddleware, preMergeData] of preMerge.entries()) {
42+
for (const [middleware, data] of this.getMiddlewareMap().entries()) {
43+
if (preMergeMiddleware === middleware) {
44+
data.includedRoutes = [...data.includedRoutes, ...preMergeData.includedRoutes]
45+
data.excludedRoutes = [...data.excludedRoutes, ...preMergeData.excludedRoutes]
46+
this.getMiddlewareMap().set(middleware, data)
47+
}
48+
}
49+
}
50+
return this.getMiddlewareMap()
2151
}
2252
}

packages/adapter/adapter-nestjs/src/nest-resolver.ts

+9
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ export class NestJSResolver {
99
return this._pluginContext
1010
}
1111

12+
isResolved: boolean = false
13+
1214
async resolveModule(module: ModuleMetadata): Promise<this> {
1315
await this.resolveImports(module.imports)
1416
await this.resolveProviders(module.providers)
1517
await this.resolveControllers(module.controllers)
18+
this.isResolved = true
1619
return this
1720
}
1821

@@ -21,14 +24,20 @@ export class NestJSResolver {
2124
private readonly _controllers = new Set<IClassWrapper>()
2225

2326
getResolvedModules(): IClassWrapper[] {
27+
if (!this.isResolved)
28+
throw new Error('NestJSResolver is not resolved yet, please call resolveModule first.')
2429
return Array.from(this._modules)
2530
}
2631

2732
getResolvedProviders(): IClassWrapper[] {
33+
if (!this.isResolved)
34+
throw new Error('NestJSResolver is not resolved yet, please call resolveModule first.')
2835
return Array.from(this._providers)
2936
}
3037

3138
getResolvedControllers(): IClassWrapper[] {
39+
if (!this.isResolved)
40+
throw new Error('NestJSResolver is not resolved yet, please call resolveModule first.')
3241
return Array.from(this._controllers)
3342
}
3443

packages/adapter/adapter-nestjs/src/plugin.ts

+2-21
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import type { InjectionToken, ModuleMetadata, NestModule } from '@nestjs/common'
1+
import type { InjectionToken, ModuleMetadata } from '@nestjs/common'
22
import type { IPlugin } from '@unioc/core'
33
import type { IClass } from '@unioc/shared'
44
import { PROPERTY_DEPS_METADATA, SELF_DECLARED_DEPS_METADATA } from '@nestjs/common/constants.js'
55
import { AbstractRestfulBootstrap } from '@unioc/web'
66
import { ErrorCollector } from './error-collector'
7-
import { MiddlewareConsumerBuilder } from './middleware/middleware-customer'
87
import { NestJSResolver } from './nest-resolver'
98
import { NestJSRestfulHandler } from './restful/restful-handler'
109
import { NestJSRestfulScanner } from './restful/restful-scanner'
@@ -44,28 +43,10 @@ export function NestJS(options: NestJS.Options = {}): IPlugin {
4443
return
4544
}
4645

47-
const scanner = new NestJSRestfulScanner(this)
46+
const scanner = new NestJSRestfulScanner(this, nestJSResolver.getResolvedModules())
4847
bootstrap.addRestfulScanner(scanner)
4948
bootstrap.createValue(scanner, NestJSRestfulScanner)
5049
await scanner.resolveAll()
51-
52-
// apply restful middlewares
53-
const modules = nestJSResolver.getResolvedModules()
54-
const middlewareConsumer = new MiddlewareConsumerBuilder()
55-
56-
for (const moduleWrapper of modules) {
57-
const resolvedInstance: NestModule = await moduleWrapper.resolve()
58-
if (resolvedInstance && typeof resolvedInstance === 'object' && 'configure' in resolvedInstance && typeof resolvedInstance.configure === 'function') {
59-
await resolvedInstance.configure(middlewareConsumer)
60-
}
61-
}
62-
63-
// apply middlewares
64-
const middlewareConfigs = middlewareConsumer.getMiddlewareConfigs()
65-
for (const middlewareConfig of middlewareConfigs) {
66-
const _middlewares = await middlewareConfig.resolveMiddlewares(this)
67-
// TODO
68-
}
6950
},
7051

7152
async resolve(ctx) {

packages/adapter/adapter-nestjs/src/restful/restful-handler.ts

+8-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export class NestJSRestfulHandler implements IRestfulConnect.Handler {
1313
return this._restfulScanner
1414
}
1515

16-
private _getParamValue(currentMetadata: INestJSMethodParamMetadata, ctx: IRestfulConnect.WebContext): unknown {
16+
protected getParamValue(currentMetadata: INestJSMethodParamMetadata, ctx: IRestfulConnect.WebContext): unknown {
1717
switch (currentMetadata.paramType) {
1818
case 'body':
1919
if (typeof ctx.body === 'object' && ctx.body !== null)
@@ -86,7 +86,7 @@ export class NestJSRestfulHandler implements IRestfulConnect.Handler {
8686
if (!currentMetadata)
8787
continue
8888
if (currentMetadata.paramType !== 'body' && currentMetadata.paramType !== 'query' && currentMetadata.paramType !== 'param' && currentMetadata.paramType !== 'custom') {
89-
params[i] = this._getParamValue(currentMetadata, ctx)
89+
params[i] = this.getParamValue(currentMetadata, ctx)
9090
continue
9191
}
9292

@@ -95,7 +95,7 @@ export class NestJSRestfulHandler implements IRestfulConnect.Handler {
9595
.getRestfulScanner()
9696
.executePipe(pipes, {
9797
type: currentMetadata.paramType,
98-
value: this._getParamValue(currentMetadata, ctx),
98+
value: this.getParamValue(currentMetadata, ctx),
9999
data: currentMetadata.data as string,
100100
metatype: currentParamType,
101101
...ctx,
@@ -166,11 +166,13 @@ export class NestJSRestfulHandler implements IRestfulConnect.Handler {
166166
} as const
167167

168168
try {
169-
// 1. Build params with pipes
169+
// 1. Execute middlewares
170+
await this.getRestfulScanner().executeMiddlewares(methodWrapper, ctx)
171+
// 2. Build params with pipes
170172
methodArguments = await this.buildParams(methodWrapper, ctx)
171-
// 2. Execute guards
173+
// 3. Execute guards
172174
await this.executeGuards(methodWrapper, methodArguments)
173-
// 3. Execute the controller method
175+
// 4. Execute the controller method
174176
const result = await methodWrapper.execute(methodArguments, extraOptions)
175177
// 4. TODO: Execute the interceptors
176178
// 5. Send the ending response

packages/adapter/adapter-nestjs/src/restful/restful-scanner.ts

+78-2
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
1-
import type { ArgumentMetadata, ArgumentsHost, CanActivate, ExceptionFilter, ExecutionContext, PipeTransform } from '@nestjs/common'
2-
import type { IArgument, IPluginContext } from '@unioc/core'
1+
import type { ArgumentMetadata, ArgumentsHost, CanActivate, ExceptionFilter, ExecutionContext, NestMiddleware, NestModule, PipeTransform } from '@nestjs/common'
2+
import type { IArgument, IClassWrapper, IPluginContext } from '@unioc/core'
33
import type { IClass } from '@unioc/shared'
44
import type { IHttpParam, IRestfulConnect, IRestfulScanner } from '@unioc/web'
5+
import type { IMiddlewareApplyData, MiddlewareType } from '../middleware/middleware-customer'
56
import type { NestJSMethodOperator } from './method-operator'
7+
import type { NestJSMethodWrapper } from './method-wrapper'
68
import { CONTROLLER_WATERMARK, FILTER_CATCH_EXCEPTIONS } from '@nestjs/common/constants.js'
9+
import { isClass } from '@unioc/shared'
710
import { RestfulScanner } from '@unioc/web'
11+
import { MiddlewareCustomerBuilder } from '../middleware/middleware-customer'
812
import { NestJSControllerWrapper } from './controller-wrapper'
913
import { NestJSRestfulHandler } from './restful-handler'
1014

1115
export interface INestJSPipeArgument extends IArgument, ArgumentMetadata, Partial<Record<IHttpParam, unknown>> {}
1216
export type INestJSFilterCatchType = 'done' | 'no-match'
1317

1418
export class NestJSRestfulScanner extends RestfulScanner implements IRestfulScanner {
19+
constructor(pluginContext: IPluginContext, private readonly _modules: IClassWrapper[]) {
20+
super(pluginContext)
21+
}
22+
1523
override getPluginContext(): IPluginContext {
1624
return super.getPluginContext()
1725
}
@@ -134,6 +142,74 @@ export class NestJSRestfulScanner extends RestfulScanner implements IRestfulScan
134142
return true
135143
}
136144

145+
private _middlewareCustomerMap: Map<MiddlewareType, IMiddlewareApplyData> = new Map<MiddlewareType, IMiddlewareApplyData>()
146+
147+
async resolveMiddlewares(): Promise<void> {
148+
for (const moduleWrapper of this._modules) {
149+
const resolveInstance: NestModule = await moduleWrapper.resolve()
150+
if (!('configure' in resolveInstance) || typeof resolveInstance.configure !== 'function')
151+
continue
152+
153+
const middlewareCustomer = new MiddlewareCustomerBuilder(this.getPluginContext())
154+
await resolveInstance.configure(middlewareCustomer)
155+
this._middlewareCustomerMap = middlewareCustomer.merge(this._middlewareCustomerMap)
156+
}
157+
}
158+
159+
async executeMiddlewares(methodWrapper: NestJSMethodWrapper, ctx: IRestfulConnect.WebContext): Promise<void> {
160+
const fullPaths = methodWrapper.findAll().map(o => o.getFullPath())
161+
const classTarget = methodWrapper.getControllerOperator().getControllerWrapper().getClassWrapper().getTarget()
162+
const matchedMiddlewares: MiddlewareType[] = []
163+
164+
for (const [middleware, data] of this._middlewareCustomerMap.entries()) {
165+
for (const route of data.includedRoutes) {
166+
if (isClass(route)) {
167+
if (classTarget === route) {
168+
matchedMiddlewares.push(middleware)
169+
}
170+
}
171+
else if (typeof route === 'string') {
172+
if (fullPaths.includes(route)) {
173+
matchedMiddlewares.push(middleware)
174+
}
175+
}
176+
else {
177+
if (fullPaths.includes(route.path)) {
178+
matchedMiddlewares.push(middleware)
179+
}
180+
}
181+
}
182+
183+
for (const route of data.excludedRoutes) {
184+
if (isClass(route)) {
185+
if (classTarget === route) {
186+
matchedMiddlewares.splice(matchedMiddlewares.indexOf(middleware), 1)
187+
}
188+
}
189+
else if (typeof route === 'string') {
190+
if (fullPaths.includes(route)) {
191+
matchedMiddlewares.splice(matchedMiddlewares.indexOf(middleware), 1)
192+
}
193+
}
194+
else {
195+
if (fullPaths.includes(route.path)) {
196+
matchedMiddlewares.splice(matchedMiddlewares.indexOf(middleware), 1)
197+
}
198+
}
199+
}
200+
}
201+
202+
for (const middleware of matchedMiddlewares) {
203+
if (isClass(middleware)) {
204+
const middlewareInstance = await this.getPluginContext().createClass<IClass<NestMiddleware>>(middleware).resolve()
205+
await middlewareInstance.use(ctx.request, ctx.response, ctx.next)
206+
}
207+
else {
208+
await middleware(ctx.request, ctx.response, ctx.next)
209+
}
210+
}
211+
}
212+
137213
resolveConnectHandler(): IRestfulConnect.Handler {
138214
return new NestJSRestfulHandler(this)
139215
}

0 commit comments

Comments
 (0)