@@ -6,16 +6,16 @@ import type {
6
6
Expression ,
7
7
Identifier ,
8
8
ImportDeclaration ,
9
- ImportExpression ,
10
9
VariableDeclaration ,
11
10
} from 'estree'
12
11
import type { SourceMap } from 'magic-string'
12
+ import type { RollupAstNode } from 'rollup'
13
13
import type { Plugin , Rollup } from 'vite'
14
14
import type { Node , Positioned } from './esmWalker'
15
15
import { findNodeAround } from 'acorn-walk'
16
16
import MagicString from 'magic-string'
17
17
import { createFilter } from 'vite'
18
- import { esmWalker , getArbitraryModuleIdentifier } from './esmWalker'
18
+ import { esmWalker } from './esmWalker'
19
19
20
20
interface HoistMocksOptions {
21
21
/**
@@ -106,11 +106,14 @@ function isIdentifier(node: any): node is Positioned<Identifier> {
106
106
return node . type === 'Identifier'
107
107
}
108
108
109
- function getBetterEnd ( code : string , node : Node ) {
109
+ function getNodeTail ( code : string , node : Node ) {
110
110
let end = node . end
111
111
if ( code [ node . end ] === ';' ) {
112
112
end += 1
113
113
}
114
+ if ( code [ node . end ] === '\n' ) {
115
+ return end + 1
116
+ }
114
117
if ( code [ node . end + 1 ] === '\n' ) {
115
118
end += 1
116
119
}
@@ -163,45 +166,41 @@ export function hoistMocks(
163
166
hoistedModules = [ 'vitest' ] ,
164
167
} = options
165
168
166
- const hoistIndex = code . match ( hashbangRE ) ?. [ 0 ] . length ?? 0
169
+ // hoist at the start of the file, after the hashbang
170
+ let hoistIndex = hashbangRE . exec ( code ) ?. [ 0 ] . length ?? 0
167
171
168
172
let hoistedModuleImported = false
169
173
170
174
let uid = 0
171
175
const idToImportMap = new Map < string , string > ( )
172
176
177
+ const imports : {
178
+ node : RollupAstNode < ImportDeclaration >
179
+ id : string
180
+ } [ ] = [ ]
181
+
173
182
// this will transform import statements into dynamic ones, if there are imports
174
183
// it will keep the import as is, if we don't need to mock anything
175
184
// in browser environment it will wrap the module value with "vitest_wrap_module" function
176
185
// that returns a proxy to the module so that named exports can be mocked
177
- const transformImportDeclaration = ( node : ImportDeclaration ) => {
178
- const source = node . source . value as string
179
-
180
- const importId = `__vi_import_${ uid ++ } __`
181
- const hasSpecifiers = node . specifiers . length > 0
182
- const code = hasSpecifiers
183
- ? `const ${ importId } = await import('${ source } ')\n`
184
- : `await import('${ source } ')\n`
185
- return {
186
- code,
187
- id : importId ,
188
- }
189
- }
190
-
191
- function defineImport ( node : Positioned < ImportDeclaration > ) {
186
+ function defineImport (
187
+ index : number ,
188
+ importNode : ImportDeclaration & {
189
+ start : number
190
+ end : number
191
+ } ,
192
+ ) {
193
+ const source = importNode . source . value as string
192
194
// always hoist vitest import to top of the file, so
193
195
// "vi" helpers can access it
194
- if ( hoistedModules . includes ( node . source . value as string ) ) {
196
+ if ( hoistedModules . includes ( source ) ) {
195
197
hoistedModuleImported = true
196
198
return
197
199
}
200
+ const importId = `__vi_import_${ uid ++ } __`
201
+ imports . push ( { id : importId , node : importNode } )
198
202
199
- const declaration = transformImportDeclaration ( node )
200
- if ( ! declaration ) {
201
- return null
202
- }
203
- s . appendLeft ( hoistIndex , declaration . code )
204
- return declaration . id
203
+ return importId
205
204
}
206
205
207
206
// 1. check all import statements and record id -> importName map
@@ -210,17 +209,24 @@ export function hoistMocks(
210
209
// import { baz } from 'foo' --> baz -> __import_foo__.baz
211
210
// import * as ok from 'foo' --> ok -> __import_foo__
212
211
if ( node . type === 'ImportDeclaration' ) {
213
- const importId = defineImport ( node )
212
+ const importId = defineImport ( hoistIndex , node )
214
213
if ( ! importId ) {
215
214
continue
216
215
}
217
- s . remove ( node . start , getBetterEnd ( code , node ) )
218
216
for ( const spec of node . specifiers ) {
219
217
if ( spec . type === 'ImportSpecifier' ) {
220
- idToImportMap . set (
221
- spec . local . name ,
222
- `${ importId } .${ getArbitraryModuleIdentifier ( spec . imported ) } ` ,
223
- )
218
+ if ( spec . imported . type === 'Identifier' ) {
219
+ idToImportMap . set (
220
+ spec . local . name ,
221
+ `${ importId } .${ spec . imported . name } ` ,
222
+ )
223
+ }
224
+ else {
225
+ idToImportMap . set (
226
+ spec . local . name ,
227
+ `${ importId } [${ JSON . stringify ( spec . imported . value as string ) } ]` ,
228
+ )
229
+ }
224
230
}
225
231
else if ( spec . type === 'ImportDefaultSpecifier' ) {
226
232
idToImportMap . set ( spec . local . name , `${ importId } .default` )
@@ -235,7 +241,7 @@ export function hoistMocks(
235
241
236
242
const declaredConst = new Set < string > ( )
237
243
const hoistedNodes : Positioned <
238
- CallExpression | VariableDeclaration | AwaitExpression
244
+ CallExpression | VariableDeclaration | AwaitExpression
239
245
> [ ] = [ ]
240
246
241
247
function createSyntaxError ( node : Positioned < Node > , message : string ) {
@@ -347,6 +353,35 @@ export function hoistMocks(
347
353
`Cannot export the result of "${ method } ". Remove export declaration because "${ method } " doesn\'t return anything.` ,
348
354
)
349
355
}
356
+ // rewrite vi.mock(import('..')) into vi.mock('..')
357
+ if (
358
+ node . type === 'CallExpression'
359
+ && node . callee . type === 'MemberExpression'
360
+ && dynamicImportMockMethodNames . includes ( ( node . callee . property as Identifier ) . name )
361
+ ) {
362
+ const moduleInfo = node . arguments [ 0 ] as Positioned < Expression >
363
+ // vi.mock(import('./path')) -> vi.mock('./path')
364
+ if ( moduleInfo . type === 'ImportExpression' ) {
365
+ const source = moduleInfo . source as Positioned < Expression >
366
+ s . overwrite (
367
+ moduleInfo . start ,
368
+ moduleInfo . end ,
369
+ s . slice ( source . start , source . end ) ,
370
+ )
371
+ }
372
+ // vi.mock(await import('./path')) -> vi.mock('./path')
373
+ if (
374
+ moduleInfo . type === 'AwaitExpression'
375
+ && moduleInfo . argument . type === 'ImportExpression'
376
+ ) {
377
+ const source = moduleInfo . argument . source as Positioned < Expression >
378
+ s . overwrite (
379
+ moduleInfo . start ,
380
+ moduleInfo . end ,
381
+ s . slice ( source . start , source . end ) ,
382
+ )
383
+ }
384
+ }
350
385
hoistedNodes . push ( node )
351
386
}
352
387
// vi.doMock(import('./path')) -> vi.doMock('./path')
@@ -384,7 +419,6 @@ export function hoistMocks(
384
419
declarationNode ,
385
420
'Cannot export hoisted variable. You can control hoisting behavior by placing the import from this file first.' ,
386
421
)
387
- // hoist "const variable = vi.hoisted(() => {})"
388
422
hoistedNodes . push ( declarationNode )
389
423
}
390
424
else {
@@ -393,10 +427,8 @@ export function hoistMocks(
393
427
node . start ,
394
428
'AwaitExpression' ,
395
429
) ?. node as Positioned < AwaitExpression > | undefined
396
- // hoist "await vi.hoisted(async () => {})" or "vi.hoisted(() => {})"
397
- hoistedNodes . push (
398
- awaitedExpression ?. argument === node ? awaitedExpression : node ,
399
- )
430
+ const moveNode = awaitedExpression ?. argument === node ? awaitedExpression : node
431
+ hoistedNodes . push ( moveNode )
400
432
}
401
433
}
402
434
}
@@ -446,24 +478,6 @@ export function hoistMocks(
446
478
)
447
479
}
448
480
449
- function rewriteMockDynamicImport (
450
- nodeCode : string ,
451
- moduleInfo : Positioned < ImportExpression > ,
452
- expressionStart : number ,
453
- expressionEnd : number ,
454
- mockStart : number ,
455
- ) {
456
- const source = moduleInfo . source as Positioned < Expression >
457
- const importPath = s . slice ( source . start , source . end )
458
- const nodeCodeStart = expressionStart - mockStart
459
- const nodeCodeEnd = expressionEnd - mockStart
460
- return (
461
- nodeCode . slice ( 0 , nodeCodeStart )
462
- + importPath
463
- + nodeCode . slice ( nodeCodeEnd )
464
- )
465
- }
466
-
467
481
// validate hoistedNodes doesn't have nodes inside other nodes
468
482
for ( let i = 0 ; i < hoistedNodes . length ; i ++ ) {
469
483
const node = hoistedNodes [ i ]
@@ -479,61 +493,42 @@ export function hoistMocks(
479
493
}
480
494
}
481
495
482
- // Wait for imports to be hoisted and then hoist the mocks
483
- const hoistedCode = hoistedNodes
484
- . map ( ( node ) => {
485
- const end = getBetterEnd ( code , node )
486
- /**
487
- * In the following case, we need to change the `user` to user: __vi_import_x__.user
488
- * So we should get the latest code from `s`.
489
- *
490
- * import user from './user'
491
- * vi.mock('./mock.js', () => ({ getSession: vi.fn().mockImplementation(() => ({ user })) }))
492
- */
493
- let nodeCode = s . slice ( node . start , end )
494
-
495
- // rewrite vi.mock(import('..')) into vi.mock('..')
496
- if (
497
- node . type === 'CallExpression'
498
- && node . callee . type === 'MemberExpression'
499
- && dynamicImportMockMethodNames . includes ( ( node . callee . property as Identifier ) . name )
500
- ) {
501
- const moduleInfo = node . arguments [ 0 ] as Positioned < Expression >
502
- // vi.mock(import('./path')) -> vi.mock('./path')
503
- if ( moduleInfo . type === 'ImportExpression' ) {
504
- nodeCode = rewriteMockDynamicImport (
505
- nodeCode ,
506
- moduleInfo ,
507
- moduleInfo . start ,
508
- moduleInfo . end ,
509
- node . start ,
510
- )
511
- }
512
- // vi.mock(await import('./path')) -> vi.mock('./path')
513
- if (
514
- moduleInfo . type === 'AwaitExpression'
515
- && moduleInfo . argument . type === 'ImportExpression'
516
- ) {
517
- nodeCode = rewriteMockDynamicImport (
518
- nodeCode ,
519
- moduleInfo . argument as Positioned < ImportExpression > ,
520
- moduleInfo . start ,
521
- moduleInfo . end ,
522
- node . start ,
523
- )
524
- }
525
- }
496
+ // hoist vi.mock/vi.hoisted
497
+ for ( const node of hoistedNodes ) {
498
+ const end = getNodeTail ( code , node )
499
+ if ( hoistIndex === end ) {
500
+ hoistIndex = end
501
+ }
502
+ else if ( hoistIndex !== node . start ) {
503
+ s . move ( node . start , end , hoistIndex )
504
+ }
505
+ }
526
506
527
- s . remove ( node . start , end )
528
- return `${ nodeCode } ${ nodeCode . endsWith ( '\n' ) ? '' : '\n' } `
529
- } )
530
- . join ( '' )
507
+ // hoist actual dynamic imports last so they are inserted after all hoisted mocks
508
+ for ( const { node : importNode , id : importId } of imports ) {
509
+ const source = importNode . source . value as string
531
510
532
- if ( hoistedCode || hoistedModuleImported ) {
533
- s . prepend (
534
- ( ! hoistedModuleImported && hoistedCode ? API_NOT_FOUND_CHECK ( utilsObjectNames ) : '' )
535
- + hoistedCode ,
511
+ s . update (
512
+ importNode . start ,
513
+ importNode . end ,
514
+ `const ${ importId } = await import(${ JSON . stringify (
515
+ source ,
516
+ ) } );\n`,
536
517
)
518
+
519
+ if ( importNode . start === hoistIndex ) {
520
+ // no need to hoist, but update hoistIndex to keep the order
521
+ hoistIndex = importNode . end
522
+ }
523
+ else {
524
+ // There will be an error if the module is called before it is imported,
525
+ // so the module import statement is hoisted to the top
526
+ s . move ( importNode . start , importNode . end , hoistIndex )
527
+ }
528
+ }
529
+
530
+ if ( ! hoistedModuleImported && hoistedNodes . length ) {
531
+ s . prepend ( API_NOT_FOUND_CHECK ( utilsObjectNames ) )
537
532
}
538
533
539
534
return {
0 commit comments