7
7
*/
8
8
9
9
import type { NodePath , Visitor } from '@babel/traverse' ;
10
- import type { Identifier } from '@babel/types' ;
10
+ import {
11
+ Identifier ,
12
+ Node ,
13
+ Expression ,
14
+ isImportDeclaration ,
15
+ } from '@babel/types' ;
16
+ import sortBy = require( 'lodash.sortby' ) ;
11
17
12
18
// We allow `jest`, `expect`, `require`, all default Node.js globals and all
13
19
// ES2015 built-ins to be used inside of a `jest.mock` factory.
@@ -70,7 +76,10 @@ const WHITELISTED_IDENTIFIERS = new Set<string>(
70
76
] . sort ( ) ,
71
77
) ;
72
78
73
- const JEST_GLOBAL = { name : 'jest' } ;
79
+ const JEST_GLOBAL_NAME = 'jest' ;
80
+ const JEST_GLOBALS_MODULE_NAME = '@jest/globals' ;
81
+ const JEST_GLOBALS_MODULE_JEST_EXPORT_NAME = 'jest' ;
82
+
74
83
// TODO: Should be Visitor<{ids: Set<NodePath<Identifier>>}>, but `ReferencedIdentifier` doesn't exist
75
84
const IDVisitor = {
76
85
ReferencedIdentifier ( path : NodePath < Identifier > ) {
@@ -82,7 +91,7 @@ const IDVisitor = {
82
91
83
92
const FUNCTIONS : Record <
84
93
string ,
85
- ( args : Array < NodePath > ) => boolean
94
+ < T extends Node > ( args : Array < NodePath < T > > ) => boolean
86
95
> = Object . create ( null ) ;
87
96
88
97
FUNCTIONS . mock = args => {
@@ -152,72 +161,95 @@ FUNCTIONS.deepUnmock = args => args.length === 1 && args[0].isStringLiteral();
152
161
FUNCTIONS . disableAutomock = FUNCTIONS . enableAutomock = args =>
153
162
args . length === 0 ;
154
163
155
- export default ( ) : { visitor : Visitor } => {
156
- const shouldHoistExpression = ( expr : NodePath ) : boolean => {
157
- if ( ! expr . isCallExpression ( ) ) {
158
- return false ;
159
- }
164
+ const isIdentifierJestObject = ( identifier : NodePath < Identifier > ) : boolean => {
165
+ // global
166
+ if (
167
+ identifier . node . name === JEST_GLOBAL_NAME &&
168
+ ! identifier . scope . hasBinding ( JEST_GLOBAL_NAME )
169
+ ) {
170
+ return true ;
171
+ }
172
+ // import from '@jest/globals'
173
+ if (
174
+ identifier . referencesImport (
175
+ JEST_GLOBALS_MODULE_NAME ,
176
+ JEST_GLOBALS_MODULE_JEST_EXPORT_NAME ,
177
+ )
178
+ ) {
179
+ return true ;
180
+ }
160
181
161
- // TODO: avoid type casts - the types can be arrays (is it possible to ignore that without casting?)
162
- const callee = expr . get ( 'callee' ) as NodePath ;
163
- const expressionArguments = expr . get ( 'arguments' ) ;
164
- const object = callee . get ( 'object' ) as NodePath ;
165
- const property = callee . get ( 'property' ) as NodePath ;
166
- return (
167
- property . isIdentifier ( ) &&
168
- FUNCTIONS [ property . node . name ] &&
169
- ( object . isIdentifier ( JEST_GLOBAL ) ||
170
- ( callee . isMemberExpression ( ) && shouldHoistExpression ( object ) ) ) &&
171
- FUNCTIONS [ property . node . name ] (
172
- Array . isArray ( expressionArguments )
173
- ? expressionArguments
174
- : [ expressionArguments ] ,
175
- )
176
- ) ;
177
- } ;
182
+ return false ;
183
+ } ;
178
184
179
- const visitor : Visitor = {
180
- ExpressionStatement ( path ) {
181
- if ( shouldHoistExpression ( path . get ( 'expression' ) as NodePath ) ) {
182
- // @ts -ignore: private, magical property
183
- path . node . _blockHoist = Infinity ;
184
- }
185
- } ,
186
- ImportDeclaration ( path ) {
187
- if ( path . node . source . value === '@jest/globals' ) {
188
- // @ts -ignore: private, magical property
189
- path . node . _blockHoist = Infinity ;
190
- }
191
- } ,
192
- VariableDeclaration ( path ) {
193
- const declarations = path . get ( 'declarations' ) ;
185
+ const shouldHoistExpression = < T extends Node > ( expr : NodePath < T > ) : boolean => {
186
+ if ( ! expr . isCallExpression ( ) ) {
187
+ return false ;
188
+ }
189
+
190
+ const callee = expr . get < 'callee' > ( 'callee' ) ;
191
+ const args = expr . get < 'arguments' > ( 'arguments' ) ;
194
192
195
- if ( declarations . length === 1 ) {
196
- const declarationInit = declarations [ 0 ] . get ( 'init' ) ;
193
+ if ( ! callee . isMemberExpression ( ) ) {
194
+ return false ;
195
+ }
196
+
197
+ const object = callee . get < 'object' > ( 'object' ) ;
198
+ const property = callee . get < 'property' > ( 'property' ) as
199
+ | NodePath < Expression >
200
+ | NodePath < Identifier > ;
197
201
198
- if ( declarationInit . isCallExpression ( ) ) {
199
- const callee = declarationInit . get ( 'callee' ) as NodePath ;
200
- const callArguments = declarationInit . get ( 'arguments' ) as Array <
201
- NodePath
202
- > ;
202
+ if ( ! property . isIdentifier ( ) ) {
203
+ return false ;
204
+ }
205
+ const propertyName = property . node . name ;
206
+
207
+ const objectIsJest =
208
+ ( object . isIdentifier ( ) && isIdentifierJestObject ( object ) ) ||
209
+ // The Jest object could be returned from another call since the functions are all chainable.
210
+ shouldHoistExpression ( object ) ;
211
+ if ( ! objectIsJest ) {
212
+ return false ;
213
+ }
214
+
215
+ // Important: Call the function check last
216
+ // It might throw an error to display to the user,
217
+ // which should only happen if we're already sure it's a call on the Jest object.
218
+ const functionLooksHoistable =
219
+ FUNCTIONS [ propertyName ] && FUNCTIONS [ propertyName ] ( args ) ;
220
+
221
+ return functionLooksHoistable ;
222
+ } ;
223
+
224
+ // TODO `require`s
225
+ export default ( ) : { visitor : Visitor } => {
226
+ const visitor : Visitor = {
227
+ Program : {
228
+ enter ( path ) {
229
+ path . node . body = sortBy ( path . node . body , node => {
230
+ console . log ( require ( '@babel/generator' ) . default ( node ) ) ;
203
231
204
232
if (
205
- callee . isIdentifier ( ) &&
206
- callee . node . name === 'require' &&
207
- callArguments . length === 1
233
+ isImportDeclaration ( node ) &&
234
+ node . source . value === JEST_GLOBALS_MODULE_NAME
208
235
) {
209
- const [ argument ] = callArguments ;
210
-
211
- if (
212
- argument . isStringLiteral ( ) &&
213
- argument . node . value === '@jest/globals'
214
- ) {
215
- // @ts -ignore: private, magical property
216
- path . node . _blockHoist = Infinity ;
217
- }
236
+ console . log ( 'found import' ) ;
237
+ return 0 ;
218
238
}
219
- }
220
- }
239
+ const nodePath = path . get ( 'body' ) . find ( p => p . node === node ) ;
240
+ if (
241
+ nodePath &&
242
+ nodePath . isExpressionStatement ( ) &&
243
+ shouldHoistExpression ( nodePath . get < 'expression' > ( 'expression' ) )
244
+ ) {
245
+ console . log ( 'found stmt' ) ;
246
+ return 1 ;
247
+ }
248
+
249
+ console . log ( 'nope' ) ;
250
+ return 2 ;
251
+ } ) ;
252
+ } ,
221
253
} ,
222
254
} ;
223
255
0 commit comments