Skip to content

Commit 5e4163e

Browse files
committed
support type narrowing on literal fields
Fix #526
1 parent ad91f6a commit 5e4163e

File tree

2 files changed

+182
-64
lines changed
  • crates/emmylua_code_analysis/src

2 files changed

+182
-64
lines changed

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,4 +822,32 @@ end
822822
"#,
823823
);
824824
}
825+
826+
#[test]
827+
fn test_issue_526() {
828+
let mut ws = VirtualWorkspace::new();
829+
830+
ws.def(
831+
r#"
832+
--- @alias A { kind: 'A'}
833+
--- @alias B { kind: 'B'}
834+
835+
local x --- @type A|B
836+
837+
if x.kind == 'A' then
838+
a = x
839+
return
840+
end
841+
842+
b = x
843+
"#,
844+
);
845+
846+
let a = ws.expr_ty("a");
847+
let a_expected = ws.ty("A");
848+
assert_eq!(a, a_expected);
849+
let b = ws.expr_ty("b");
850+
let b_expected = ws.ty("B");
851+
assert_eq!(b, b_expected);
852+
}
825853
}

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs

Lines changed: 154 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use emmylua_parser::{
2-
BinaryOperator, LuaBinaryExpr, LuaCallExpr, LuaChunk, LuaExpr, LuaLiteralToken,
2+
BinaryOperator, LuaBinaryExpr, LuaCallExpr, LuaChunk, LuaExpr, LuaIndexMemberExpr,
3+
LuaLiteralToken,
34
};
45

56
use crate::{
67
infer_expr,
78
semantic::infer::{
9+
infer_index::infer_member_by_member_key,
810
narrow::{
911
condition_flow::{call_flow::get_type_at_call_expr, InferConditionFlow},
1012
get_single_antecedent,
@@ -15,7 +17,8 @@ use crate::{
1517
},
1618
VarRefId,
1719
},
18-
DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaType, TypeOps,
20+
DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaInferCache, LuaType, LuaUnionType,
21+
TypeOps,
1922
};
2023

2124
pub fn get_type_at_binary_expr(
@@ -36,73 +39,56 @@ pub fn get_type_at_binary_expr(
3639
return Ok(ResultTypeOrContinue::Continue);
3740
};
3841

39-
match op_token.get_op() {
40-
BinaryOperator::OpLt
41-
| BinaryOperator::OpLe
42-
| BinaryOperator::OpGt
43-
| BinaryOperator::OpGe => {
44-
// todo check number range
42+
let condition_flow = match op_token.get_op() {
43+
BinaryOperator::OpEq => condition_flow,
44+
BinaryOperator::OpNe => condition_flow.get_negated(),
45+
_ => {
46+
return Ok(ResultTypeOrContinue::Continue);
4547
}
46-
BinaryOperator::OpEq => {
47-
let result_type = maybe_type_guard_binary(
48-
db,
49-
tree,
50-
cache,
51-
root,
52-
var_ref_id,
53-
flow_node,
54-
left_expr.clone(),
55-
right_expr.clone(),
56-
condition_flow,
57-
)?;
58-
if let ResultTypeOrContinue::Result(result_type) = result_type {
59-
return Ok(ResultTypeOrContinue::Result(result_type));
60-
}
48+
};
6149

62-
return maybe_var_eq_narrow(
63-
db,
64-
tree,
65-
cache,
66-
root,
67-
var_ref_id,
68-
flow_node,
69-
left_expr,
70-
right_expr,
71-
condition_flow,
72-
);
73-
}
74-
BinaryOperator::OpNe => {
75-
let result_type = maybe_type_guard_binary(
76-
db,
77-
tree,
78-
cache,
79-
root,
80-
var_ref_id,
81-
flow_node,
82-
left_expr.clone(),
83-
right_expr.clone(),
84-
condition_flow.get_negated(),
85-
)?;
86-
if let ResultTypeOrContinue::Result(result_type) = result_type {
87-
return Ok(ResultTypeOrContinue::Result(result_type));
88-
}
50+
let mut result_type = maybe_type_guard_binary(
51+
db,
52+
tree,
53+
cache,
54+
root,
55+
var_ref_id,
56+
flow_node,
57+
left_expr.clone(),
58+
right_expr.clone(),
59+
condition_flow,
60+
)?;
61+
if let ResultTypeOrContinue::Result(result_type) = result_type {
62+
return Ok(ResultTypeOrContinue::Result(result_type));
63+
}
8964

90-
return maybe_var_eq_narrow(
91-
db,
92-
tree,
93-
cache,
94-
root,
95-
var_ref_id,
96-
flow_node,
97-
left_expr,
98-
right_expr,
99-
condition_flow.get_negated(),
100-
);
101-
}
102-
_ => {}
65+
result_type = maybe_field_literal_eq_narrow(
66+
db,
67+
tree,
68+
cache,
69+
root,
70+
var_ref_id,
71+
flow_node,
72+
left_expr.clone(),
73+
right_expr.clone(),
74+
condition_flow,
75+
)?;
76+
77+
if let ResultTypeOrContinue::Result(result_type) = result_type {
78+
return Ok(ResultTypeOrContinue::Result(result_type));
10379
}
10480

105-
Ok(ResultTypeOrContinue::Continue)
81+
return maybe_var_eq_narrow(
82+
db,
83+
tree,
84+
cache,
85+
root,
86+
var_ref_id,
87+
flow_node,
88+
left_expr,
89+
right_expr,
90+
condition_flow,
91+
);
10692
}
10793

10894
fn maybe_type_guard_binary(
@@ -296,3 +282,107 @@ fn maybe_var_eq_narrow(
296282
}
297283
}
298284
}
285+
286+
fn maybe_field_literal_eq_narrow(
287+
db: &DbIndex,
288+
tree: &FlowTree,
289+
cache: &mut LuaInferCache,
290+
root: &LuaChunk,
291+
var_ref_id: &VarRefId,
292+
flow_node: &FlowNode,
293+
left_expr: LuaExpr,
294+
right_expr: LuaExpr,
295+
condition_flow: InferConditionFlow,
296+
) -> Result<ResultTypeOrContinue, InferFailReason> {
297+
// only check left as need narrow
298+
let (index_expr, literal_expr) = match (left_expr, right_expr) {
299+
(LuaExpr::IndexExpr(index_expr), LuaExpr::LiteralExpr(literal_expr)) => {
300+
(index_expr, literal_expr)
301+
}
302+
(LuaExpr::LiteralExpr(literal_expr), LuaExpr::IndexExpr(index_expr)) => {
303+
(index_expr, literal_expr)
304+
}
305+
_ => return Ok(ResultTypeOrContinue::Continue),
306+
};
307+
308+
let Some(prefix_expr) = index_expr.get_prefix_expr() else {
309+
return Ok(ResultTypeOrContinue::Continue);
310+
};
311+
312+
let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, prefix_expr.clone()) else {
313+
// If we cannot find a reference declaration ID, we cannot narrow it
314+
return Ok(ResultTypeOrContinue::Continue);
315+
};
316+
317+
if maybe_var_ref_id != *var_ref_id {
318+
return Ok(ResultTypeOrContinue::Continue);
319+
}
320+
321+
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
322+
let left_type = get_type_at_flow(db, tree, cache, root, &var_ref_id, antecedent_flow_id)?;
323+
let LuaType::Union(union_type) = left_type else {
324+
return Ok(ResultTypeOrContinue::Continue);
325+
};
326+
327+
let right_type = infer_expr(db, cache, LuaExpr::LiteralExpr(literal_expr))?;
328+
let mut guard = InferGuard::new();
329+
let index = LuaIndexMemberExpr::IndexExpr(index_expr);
330+
let mut opt_result = None;
331+
let mut union_types = union_type.get_types();
332+
for (i, sub_type) in union_types.iter().enumerate() {
333+
let member_type =
334+
match infer_member_by_member_key(db, cache, &sub_type, index.clone(), &mut guard) {
335+
Ok(member_type) => member_type,
336+
Err(_) => continue, // If we cannot infer the member type, skip this type
337+
};
338+
if const_type_eq(&member_type, &right_type) {
339+
// If the right type matches the member type, we can narrow it
340+
opt_result = Some(i);
341+
}
342+
}
343+
344+
match condition_flow {
345+
InferConditionFlow::TrueCondition => {
346+
if let Some(i) = opt_result {
347+
return Ok(ResultTypeOrContinue::Result(union_types[i].clone()));
348+
}
349+
}
350+
InferConditionFlow::FalseCondition => {
351+
if let Some(i) = opt_result {
352+
union_types.remove(i);
353+
match union_types.len() {
354+
0 => return Ok(ResultTypeOrContinue::Result(LuaType::Unknown)),
355+
1 => return Ok(ResultTypeOrContinue::Result(union_types[0].clone())),
356+
_ => {
357+
let union_type = LuaUnionType::new(union_types);
358+
return Ok(ResultTypeOrContinue::Result(LuaType::Union(
359+
union_type.into(),
360+
)));
361+
}
362+
}
363+
}
364+
}
365+
}
366+
367+
Ok(ResultTypeOrContinue::Continue)
368+
}
369+
370+
fn const_type_eq(left_type: &LuaType, right_type: &LuaType) -> bool {
371+
if left_type == right_type {
372+
return true;
373+
}
374+
375+
match (left_type, right_type) {
376+
(
377+
LuaType::StringConst(l) | LuaType::DocStringConst(l),
378+
LuaType::StringConst(r) | LuaType::DocStringConst(r),
379+
) => l == r,
380+
(LuaType::FloatConst(l), LuaType::FloatConst(r)) => l == r,
381+
(LuaType::BooleanConst(l), LuaType::BooleanConst(r)) => l == r,
382+
(
383+
LuaType::IntegerConst(l) | LuaType::DocIntegerConst(l),
384+
LuaType::IntegerConst(r) | LuaType::DocIntegerConst(r),
385+
) => l == r,
386+
_ => false,
387+
}
388+
}

0 commit comments

Comments
 (0)