Skip to content

Commit 16a926d

Browse files
authored
[red-knot] infer int literal types (#11623)
## Summary Give red-knot the ability to infer int literal types. This is quick and easy, mostly because these types are a convenient way to observe control-flow handling with simple assignments. ## Test Plan Added test.
1 parent 05566c6 commit 16a926d

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

crates/red_knot/src/types.rs

+6
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub enum Type {
3636
Instance(ClassTypeId),
3737
Union(UnionTypeId),
3838
Intersection(IntersectionTypeId),
39+
IntLiteral(i64),
3940
// TODO protocols, callable types, overloads, generics, type vars
4041
}
4142

@@ -78,6 +79,10 @@ impl Type {
7879
// TODO return the intersection of those results
7980
todo!("attribute lookup on Intersection type")
8081
}
82+
Type::IntLiteral(_) => {
83+
// TODO raise error
84+
Ok(Some(Type::Unknown))
85+
}
8186
}
8287
}
8388
}
@@ -616,6 +621,7 @@ impl std::fmt::Display for DisplayType<'_> {
616621
.get_module(int_id.file_id)
617622
.get_intersection(int_id.intersection_id)
618623
.display(f, self.store),
624+
Type::IntLiteral(n) => write!(f, "Literal[{n}]"),
619625
}
620626
}
621627
}

crates/red_knot/src/types/infer.rs

+40
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu
145145
// TODO cache the resolution of the type on the node
146146
let symbols = symbol_table(db, file_id)?;
147147
match expr {
148+
ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value, .. }) => {
149+
match value {
150+
ast::Number::Int(n) => {
151+
// TODO support big int literals
152+
Ok(n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown))
153+
}
154+
// TODO builtins.float or builtins.complex
155+
_ => Ok(Type::Unknown),
156+
}
157+
}
148158
ast::Expr::Name(name) => {
149159
// TODO look up in the correct scope, don't assume global
150160
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
@@ -348,4 +358,34 @@ mod tests {
348358
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
349359
Ok(())
350360
}
361+
362+
#[test]
363+
fn resolve_literal() -> anyhow::Result<()> {
364+
let case = create_test()?;
365+
let db = &case.db;
366+
367+
let path = case.src.path().join("a.py");
368+
std::fs::write(path, "x = 1")?;
369+
let file = resolve_module(db, ModuleName::new("a"))?
370+
.expect("module should be found")
371+
.path(db)?
372+
.file();
373+
let syms = symbol_table(db, file)?;
374+
let x_sym = syms
375+
.root_symbol_id_by_name("x")
376+
.expect("x symbol should be found");
377+
378+
let ty = infer_symbol_type(
379+
db,
380+
GlobalSymbolId {
381+
file_id: file,
382+
symbol_id: x_sym,
383+
},
384+
)?;
385+
386+
let jar = HasJar::<SemanticJar>::jar(db)?;
387+
assert!(matches!(ty, Type::IntLiteral(_)));
388+
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]");
389+
Ok(())
390+
}
351391
}

0 commit comments

Comments
 (0)