1
1
use emmylua_parser:: {
2
- BinaryOperator , LuaBinaryExpr , LuaCallExpr , LuaChunk , LuaExpr , LuaLiteralToken ,
2
+ BinaryOperator , LuaBinaryExpr , LuaCallExpr , LuaChunk , LuaExpr , LuaIndexMemberExpr ,
3
+ LuaLiteralToken ,
3
4
} ;
4
5
5
6
use crate :: {
6
7
infer_expr,
7
8
semantic:: infer:: {
9
+ infer_index:: infer_member_by_member_key,
8
10
narrow:: {
9
11
condition_flow:: { call_flow:: get_type_at_call_expr, InferConditionFlow } ,
10
12
get_single_antecedent,
@@ -15,7 +17,8 @@ use crate::{
15
17
} ,
16
18
VarRefId ,
17
19
} ,
18
- DbIndex , FlowNode , FlowTree , InferFailReason , LuaInferCache , LuaType , TypeOps ,
20
+ DbIndex , FlowNode , FlowTree , InferFailReason , InferGuard , LuaInferCache , LuaType , LuaUnionType ,
21
+ TypeOps ,
19
22
} ;
20
23
21
24
pub fn get_type_at_binary_expr (
@@ -36,73 +39,56 @@ pub fn get_type_at_binary_expr(
36
39
return Ok ( ResultTypeOrContinue :: Continue ) ;
37
40
} ;
38
41
39
- match op_token. get_op ( ) {
40
- BinaryOperator :: OpLt
41
- | BinaryOperator :: OpLe
42
- | BinaryOperator :: OpGt
43
- | BinaryOperator :: OpGe => {
44
- // todo check number range
42
+ let condition_flow = match op_token. get_op ( ) {
43
+ BinaryOperator :: OpEq => condition_flow,
44
+ BinaryOperator :: OpNe => condition_flow. get_negated ( ) ,
45
+ _ => {
46
+ return Ok ( ResultTypeOrContinue :: Continue ) ;
45
47
}
46
- BinaryOperator :: OpEq => {
47
- let result_type = maybe_type_guard_binary (
48
- db,
49
- tree,
50
- cache,
51
- root,
52
- var_ref_id,
53
- flow_node,
54
- left_expr. clone ( ) ,
55
- right_expr. clone ( ) ,
56
- condition_flow,
57
- ) ?;
58
- if let ResultTypeOrContinue :: Result ( result_type) = result_type {
59
- return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
60
- }
48
+ } ;
61
49
62
- return maybe_var_eq_narrow (
63
- db,
64
- tree,
65
- cache,
66
- root,
67
- var_ref_id,
68
- flow_node,
69
- left_expr,
70
- right_expr,
71
- condition_flow,
72
- ) ;
73
- }
74
- BinaryOperator :: OpNe => {
75
- let result_type = maybe_type_guard_binary (
76
- db,
77
- tree,
78
- cache,
79
- root,
80
- var_ref_id,
81
- flow_node,
82
- left_expr. clone ( ) ,
83
- right_expr. clone ( ) ,
84
- condition_flow. get_negated ( ) ,
85
- ) ?;
86
- if let ResultTypeOrContinue :: Result ( result_type) = result_type {
87
- return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
88
- }
50
+ let mut result_type = maybe_type_guard_binary (
51
+ db,
52
+ tree,
53
+ cache,
54
+ root,
55
+ var_ref_id,
56
+ flow_node,
57
+ left_expr. clone ( ) ,
58
+ right_expr. clone ( ) ,
59
+ condition_flow,
60
+ ) ?;
61
+ if let ResultTypeOrContinue :: Result ( result_type) = result_type {
62
+ return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
63
+ }
89
64
90
- return maybe_var_eq_narrow (
91
- db,
92
- tree,
93
- cache,
94
- root,
95
- var_ref_id,
96
- flow_node,
97
- left_expr,
98
- right_expr,
99
- condition_flow. get_negated ( ) ,
100
- ) ;
101
- }
102
- _ => { }
65
+ result_type = maybe_field_literal_eq_narrow (
66
+ db,
67
+ tree,
68
+ cache,
69
+ root,
70
+ var_ref_id,
71
+ flow_node,
72
+ left_expr. clone ( ) ,
73
+ right_expr. clone ( ) ,
74
+ condition_flow,
75
+ ) ?;
76
+
77
+ if let ResultTypeOrContinue :: Result ( result_type) = result_type {
78
+ return Ok ( ResultTypeOrContinue :: Result ( result_type) ) ;
103
79
}
104
80
105
- Ok ( ResultTypeOrContinue :: Continue )
81
+ return maybe_var_eq_narrow (
82
+ db,
83
+ tree,
84
+ cache,
85
+ root,
86
+ var_ref_id,
87
+ flow_node,
88
+ left_expr,
89
+ right_expr,
90
+ condition_flow,
91
+ ) ;
106
92
}
107
93
108
94
fn maybe_type_guard_binary (
@@ -296,3 +282,107 @@ fn maybe_var_eq_narrow(
296
282
}
297
283
}
298
284
}
285
+
286
+ fn maybe_field_literal_eq_narrow (
287
+ db : & DbIndex ,
288
+ tree : & FlowTree ,
289
+ cache : & mut LuaInferCache ,
290
+ root : & LuaChunk ,
291
+ var_ref_id : & VarRefId ,
292
+ flow_node : & FlowNode ,
293
+ left_expr : LuaExpr ,
294
+ right_expr : LuaExpr ,
295
+ condition_flow : InferConditionFlow ,
296
+ ) -> Result < ResultTypeOrContinue , InferFailReason > {
297
+ // only check left as need narrow
298
+ let ( index_expr, literal_expr) = match ( left_expr, right_expr) {
299
+ ( LuaExpr :: IndexExpr ( index_expr) , LuaExpr :: LiteralExpr ( literal_expr) ) => {
300
+ ( index_expr, literal_expr)
301
+ }
302
+ ( LuaExpr :: LiteralExpr ( literal_expr) , LuaExpr :: IndexExpr ( index_expr) ) => {
303
+ ( index_expr, literal_expr)
304
+ }
305
+ _ => return Ok ( ResultTypeOrContinue :: Continue ) ,
306
+ } ;
307
+
308
+ let Some ( prefix_expr) = index_expr. get_prefix_expr ( ) else {
309
+ return Ok ( ResultTypeOrContinue :: Continue ) ;
310
+ } ;
311
+
312
+ let Some ( maybe_var_ref_id) = get_var_expr_var_ref_id ( db, cache, prefix_expr. clone ( ) ) else {
313
+ // If we cannot find a reference declaration ID, we cannot narrow it
314
+ return Ok ( ResultTypeOrContinue :: Continue ) ;
315
+ } ;
316
+
317
+ if maybe_var_ref_id != * var_ref_id {
318
+ return Ok ( ResultTypeOrContinue :: Continue ) ;
319
+ }
320
+
321
+ let antecedent_flow_id = get_single_antecedent ( tree, flow_node) ?;
322
+ let left_type = get_type_at_flow ( db, tree, cache, root, & var_ref_id, antecedent_flow_id) ?;
323
+ let LuaType :: Union ( union_type) = left_type else {
324
+ return Ok ( ResultTypeOrContinue :: Continue ) ;
325
+ } ;
326
+
327
+ let right_type = infer_expr ( db, cache, LuaExpr :: LiteralExpr ( literal_expr) ) ?;
328
+ let mut guard = InferGuard :: new ( ) ;
329
+ let index = LuaIndexMemberExpr :: IndexExpr ( index_expr) ;
330
+ let mut opt_result = None ;
331
+ let mut union_types = union_type. get_types ( ) ;
332
+ for ( i, sub_type) in union_types. iter ( ) . enumerate ( ) {
333
+ let member_type =
334
+ match infer_member_by_member_key ( db, cache, & sub_type, index. clone ( ) , & mut guard) {
335
+ Ok ( member_type) => member_type,
336
+ Err ( _) => continue , // If we cannot infer the member type, skip this type
337
+ } ;
338
+ if const_type_eq ( & member_type, & right_type) {
339
+ // If the right type matches the member type, we can narrow it
340
+ opt_result = Some ( i) ;
341
+ }
342
+ }
343
+
344
+ match condition_flow {
345
+ InferConditionFlow :: TrueCondition => {
346
+ if let Some ( i) = opt_result {
347
+ return Ok ( ResultTypeOrContinue :: Result ( union_types[ i] . clone ( ) ) ) ;
348
+ }
349
+ }
350
+ InferConditionFlow :: FalseCondition => {
351
+ if let Some ( i) = opt_result {
352
+ union_types. remove ( i) ;
353
+ match union_types. len ( ) {
354
+ 0 => return Ok ( ResultTypeOrContinue :: Result ( LuaType :: Unknown ) ) ,
355
+ 1 => return Ok ( ResultTypeOrContinue :: Result ( union_types[ 0 ] . clone ( ) ) ) ,
356
+ _ => {
357
+ let union_type = LuaUnionType :: new ( union_types) ;
358
+ return Ok ( ResultTypeOrContinue :: Result ( LuaType :: Union (
359
+ union_type. into ( ) ,
360
+ ) ) ) ;
361
+ }
362
+ }
363
+ }
364
+ }
365
+ }
366
+
367
+ Ok ( ResultTypeOrContinue :: Continue )
368
+ }
369
+
370
+ fn const_type_eq ( left_type : & LuaType , right_type : & LuaType ) -> bool {
371
+ if left_type == right_type {
372
+ return true ;
373
+ }
374
+
375
+ match ( left_type, right_type) {
376
+ (
377
+ LuaType :: StringConst ( l) | LuaType :: DocStringConst ( l) ,
378
+ LuaType :: StringConst ( r) | LuaType :: DocStringConst ( r) ,
379
+ ) => l == r,
380
+ ( LuaType :: FloatConst ( l) , LuaType :: FloatConst ( r) ) => l == r,
381
+ ( LuaType :: BooleanConst ( l) , LuaType :: BooleanConst ( r) ) => l == r,
382
+ (
383
+ LuaType :: IntegerConst ( l) | LuaType :: DocIntegerConst ( l) ,
384
+ LuaType :: IntegerConst ( r) | LuaType :: DocIntegerConst ( r) ,
385
+ ) => l == r,
386
+ _ => false ,
387
+ }
388
+ }
0 commit comments