Skip to content

Commit 0a09a28

Browse files
committed
Merge branch 'dcreager/generic-constructor' into dcreager/legacy-typevar-instance
* dcreager/generic-constructor: Pull FunctionLiteral out into separate type
2 parents bcea973 + bfac80b commit 0a09a28

File tree

4 files changed

+208
-163
lines changed

4 files changed

+208
-163
lines changed

crates/red_knot_python_semantic/src/types.rs

Lines changed: 150 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,10 @@ impl<'db> Type<'db> {
759759
.expect("Expected a Type::Intersection variant")
760760
}
761761

762+
pub(crate) fn function_literal(db: &'db dyn Db, function: FunctionLiteral<'db>) -> Self {
763+
Self::FunctionLiteral(FunctionType::new(db, function, None, None))
764+
}
765+
762766
pub const fn into_function_literal(self) -> Option<FunctionType<'db>> {
763767
match self {
764768
Type::FunctionLiteral(function_type) => Some(function_type),
@@ -6059,6 +6063,25 @@ impl<'db> FunctionSignature<'db> {
60596063
pub(crate) fn iter(&self) -> Iter<Signature<'db>> {
60606064
self.as_slice().iter()
60616065
}
6066+
6067+
fn iter_mut(&mut self) -> impl Iterator<Item = &mut Signature<'db>> {
6068+
match self {
6069+
Self::Single(signature) => std::slice::from_mut(signature).iter_mut(),
6070+
Self::Overloaded(signatures, _) => signatures.iter_mut(),
6071+
}
6072+
}
6073+
6074+
fn set_inherited_generic_context(&mut self, inherited_generic_context: GenericContext<'db>) {
6075+
self.iter_mut().for_each(|signature| {
6076+
signature.set_inherited_generic_context(inherited_generic_context);
6077+
});
6078+
}
6079+
6080+
fn apply_specialization(&mut self, db: &'db dyn Db, specialization: Specialization<'db>) {
6081+
self.iter_mut().for_each(|signature| {
6082+
signature.apply_specialization(db, specialization);
6083+
});
6084+
}
60626085
}
60636086

60646087
impl<'db> IntoIterator for &'db FunctionSignature<'db> {
@@ -6072,22 +6095,7 @@ impl<'db> IntoIterator for &'db FunctionSignature<'db> {
60726095

60736096
#[salsa::interned(debug)]
60746097
pub struct FunctionType<'db> {
6075-
/// Name of the function at definition.
6076-
#[return_ref]
6077-
pub name: ast::name::Name,
6078-
6079-
/// Is this a function that we special-case somehow? If so, which one?
6080-
known: Option<KnownFunction>,
6081-
6082-
/// The scope that's created by the function, in which the function body is evaluated.
6083-
body_scope: ScopeId<'db>,
6084-
6085-
/// A set of special decorators that were applied to this function
6086-
decorators: FunctionDecorators,
6087-
6088-
/// The arguments to `dataclass_transformer`, if this function was annotated
6089-
/// with `@dataclass_transformer(...)`.
6090-
dataclass_transformer_params: Option<DataclassTransformerParams>,
6098+
function: FunctionLiteral<'db>,
60916099

60926100
/// The inherited generic context, if this function is a class method being used to infer the
60936101
/// specialization of its generic class. If the method is itself generic, this is in addition
@@ -6103,7 +6111,7 @@ pub struct FunctionType<'db> {
61036111
#[salsa::tracked]
61046112
impl<'db> FunctionType<'db> {
61056113
pub(crate) fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool {
6106-
self.decorators(db).contains(decorator)
6114+
self.function(db).has_known_decorator(db, decorator)
61076115
}
61086116

61096117
/// Convert the `FunctionType` into a [`Type::Callable`].
@@ -6123,25 +6131,29 @@ impl<'db> FunctionType<'db> {
61236131
Type::BoundMethod(BoundMethodType::new(db, self, self_instance))
61246132
}
61256133

6134+
pub fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {
6135+
self.function(db).name(db)
6136+
}
6137+
61266138
/// Returns the [`FileRange`] of the function's name.
61276139
pub fn focus_range(self, db: &dyn Db) -> FileRange {
6140+
let body_scope = self.function(db).body_scope(db);
61286141
FileRange::new(
6129-
self.body_scope(db).file(db),
6130-
self.body_scope(db).node(db).expect_function().name.range,
6142+
body_scope.file(db),
6143+
body_scope.node(db).expect_function().name.range,
61316144
)
61326145
}
61336146

61346147
pub fn full_range(self, db: &dyn Db) -> FileRange {
6148+
let body_scope = self.function(db).body_scope(db);
61356149
FileRange::new(
6136-
self.body_scope(db).file(db),
6137-
self.body_scope(db).node(db).expect_function().range,
6150+
body_scope.file(db),
6151+
body_scope.node(db).expect_function().range,
61386152
)
61396153
}
61406154

61416155
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
6142-
let body_scope = self.body_scope(db);
6143-
let index = semantic_index(db, body_scope.file(db));
6144-
index.expect_single_definition(body_scope.node(db).expect_function())
6156+
self.function(db).definition(db)
61456157
}
61466158

61476159
/// Typed externally-visible signature for this function.
@@ -6158,6 +6170,114 @@ impl<'db> FunctionType<'db> {
61586170
/// would depend on the function's AST and rerun for every change in that file.
61596171
#[salsa::tracked(return_ref)]
61606172
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
6173+
let mut signature = self.function(db).signature(db).clone();
6174+
if let Some(inherited_generic_context) = self.inherited_generic_context(db) {
6175+
signature.set_inherited_generic_context(inherited_generic_context);
6176+
}
6177+
if let Some(specialization) = self.specialization(db) {
6178+
signature.apply_specialization(db, specialization);
6179+
}
6180+
signature
6181+
}
6182+
6183+
pub(crate) fn known(self, db: &'db dyn Db) -> Option<KnownFunction> {
6184+
self.function(db).known(db)
6185+
}
6186+
6187+
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
6188+
self.known(db) == Some(known_function)
6189+
}
6190+
6191+
fn with_dataclass_transformer_params(
6192+
self,
6193+
db: &'db dyn Db,
6194+
params: DataclassTransformerParams,
6195+
) -> Self {
6196+
let function = self
6197+
.function(db)
6198+
.with_dataclass_transformer_params(db, params);
6199+
Self::new(
6200+
db,
6201+
function,
6202+
self.inherited_generic_context(db),
6203+
self.specialization(db),
6204+
)
6205+
}
6206+
6207+
fn with_inherited_generic_context(
6208+
self,
6209+
db: &'db dyn Db,
6210+
inherited_generic_context: GenericContext<'db>,
6211+
) -> Self {
6212+
// A function cannot inherit more than one generic context from its containing class.
6213+
debug_assert!(self.inherited_generic_context(db).is_none());
6214+
Self::new(
6215+
db,
6216+
self.function(db),
6217+
Some(inherited_generic_context),
6218+
self.specialization(db),
6219+
)
6220+
}
6221+
6222+
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
6223+
let specialization = match self.specialization(db) {
6224+
Some(existing) => existing.apply_specialization(db, specialization),
6225+
None => specialization,
6226+
};
6227+
Self::new(
6228+
db,
6229+
self.function(db),
6230+
self.inherited_generic_context(db),
6231+
Some(specialization),
6232+
)
6233+
}
6234+
6235+
fn find_legacy_typevars(
6236+
self,
6237+
db: &'db dyn Db,
6238+
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
6239+
) {
6240+
let signatures = self.signature(db);
6241+
for signature in signatures {
6242+
signature.find_legacy_typevars(db, typevars);
6243+
}
6244+
}
6245+
}
6246+
6247+
#[salsa::interned(debug)]
6248+
pub struct FunctionLiteral<'db> {
6249+
/// Name of the function at definition.
6250+
#[return_ref]
6251+
pub name: ast::name::Name,
6252+
6253+
/// Is this a function that we special-case somehow? If so, which one?
6254+
known: Option<KnownFunction>,
6255+
6256+
/// The scope that's created by the function, in which the function body is evaluated.
6257+
body_scope: ScopeId<'db>,
6258+
6259+
/// A set of special decorators that were applied to this function
6260+
decorators: FunctionDecorators,
6261+
6262+
/// The arguments to `dataclass_transformer`, if this function was annotated
6263+
/// with `@dataclass_transformer(...)`.
6264+
dataclass_transformer_params: Option<DataclassTransformerParams>,
6265+
}
6266+
6267+
#[salsa::tracked]
6268+
impl<'db> FunctionLiteral<'db> {
6269+
fn has_known_decorator(self, db: &dyn Db, decorator: FunctionDecorators) -> bool {
6270+
self.decorators(db).contains(decorator)
6271+
}
6272+
6273+
fn definition(self, db: &'db dyn Db) -> Definition<'db> {
6274+
let body_scope = self.body_scope(db);
6275+
let index = semantic_index(db, body_scope.file(db));
6276+
index.expect_single_definition(body_scope.node(db).expect_function())
6277+
}
6278+
6279+
#[salsa::tracked(return_ref)]
6280+
fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
61616281
let internal_signature = self.internal_signature(db);
61626282

61636283
// The semantic model records a use for each function on the name node. This is used here
@@ -6222,21 +6342,7 @@ impl<'db> FunctionType<'db> {
62226342
let index = semantic_index(db, scope.file(db));
62236343
GenericContext::from_type_params(db, index, type_params)
62246344
});
6225-
let mut signature = Signature::from_function(
6226-
db,
6227-
generic_context,
6228-
self.inherited_generic_context(db),
6229-
definition,
6230-
function_stmt_node,
6231-
);
6232-
if let Some(specialization) = self.specialization(db) {
6233-
signature = signature.apply_specialization(db, specialization);
6234-
}
6235-
signature
6236-
}
6237-
6238-
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
6239-
self.known(db) == Some(known_function)
6345+
Signature::from_function(db, generic_context, definition, function_stmt_node)
62406346
}
62416347

62426348
fn with_dataclass_transformer_params(
@@ -6251,57 +6357,8 @@ impl<'db> FunctionType<'db> {
62516357
self.body_scope(db),
62526358
self.decorators(db),
62536359
Some(params),
6254-
self.inherited_generic_context(db),
6255-
self.specialization(db),
6256-
)
6257-
}
6258-
6259-
fn with_inherited_generic_context(
6260-
self,
6261-
db: &'db dyn Db,
6262-
inherited_generic_context: GenericContext<'db>,
6263-
) -> Self {
6264-
// A function cannot inherit more than one generic context from its containing class.
6265-
debug_assert!(self.inherited_generic_context(db).is_none());
6266-
Self::new(
6267-
db,
6268-
self.name(db).clone(),
6269-
self.known(db),
6270-
self.body_scope(db),
6271-
self.decorators(db),
6272-
self.dataclass_transformer_params(db),
6273-
Some(inherited_generic_context),
6274-
self.specialization(db),
6275-
)
6276-
}
6277-
6278-
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
6279-
let specialization = match self.specialization(db) {
6280-
Some(existing) => existing.apply_specialization(db, specialization),
6281-
None => specialization,
6282-
};
6283-
Self::new(
6284-
db,
6285-
self.name(db).clone(),
6286-
self.known(db),
6287-
self.body_scope(db),
6288-
self.decorators(db),
6289-
self.dataclass_transformer_params(db),
6290-
self.inherited_generic_context(db),
6291-
Some(specialization),
62926360
)
62936361
}
6294-
6295-
fn find_legacy_typevars(
6296-
self,
6297-
db: &'db dyn Db,
6298-
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
6299-
) {
6300-
let signatures = self.signature(db);
6301-
for signature in signatures {
6302-
signature.find_legacy_typevars(db, typevars);
6303-
}
6304-
}
63056362
}
63066363

63076364
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
@@ -6519,9 +6576,11 @@ impl<'db> CallableType<'db> {
65196576
fn apply_specialization(self, db: &'db dyn Db, specialization: Specialization<'db>) -> Self {
65206577
CallableType::from_overloads(
65216578
db,
6522-
self.signatures(db)
6523-
.iter()
6524-
.map(|signature| signature.apply_specialization(db, specialization)),
6579+
self.signatures(db).iter().map(|signature| {
6580+
let mut signature = signature.clone();
6581+
signature.apply_specialization(db, specialization);
6582+
signature
6583+
}),
65256584
)
65266585
}
65276586

crates/red_knot_python_semantic/src/types/call/bind.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,9 @@ impl<'db> Bindings<'db> {
705705
}
706706

707707
_ => {
708-
if let Some(params) = function_type.dataclass_transformer_params(db) {
708+
if let Some(params) =
709+
function_type.function(db).dataclass_transformer_params(db)
710+
{
709711
// This is a call to a custom function that was decorated with `@dataclass_transformer`.
710712
// If this function was called with a keyword argument like `order=False`, we extract
711713
// the argument type and overwrite the corresponding flag in `dataclass_params` after
@@ -1486,7 +1488,7 @@ impl<'db> BindingError<'db> {
14861488
) -> Option<(Span, Span)> {
14871489
match callable_ty {
14881490
Type::FunctionLiteral(function) => {
1489-
let function_scope = function.body_scope(db);
1491+
let function_scope = function.function(db).body_scope(db);
14901492
let span = Span::from(function_scope.file(db));
14911493
let node = function_scope.node(db);
14921494
if let Some(func_def) = node.as_function() {

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ use crate::types::mro::MroErrorKind;
8282
use crate::types::unpacker::{UnpackResult, Unpacker};
8383
use crate::types::{
8484
binding_type, todo_type, CallDunderError, CallableSignature, CallableType, Class,
85-
ClassLiteralType, ClassType, DataclassParams, DynamicType, FunctionDecorators, FunctionType,
85+
ClassLiteralType, ClassType, DataclassParams, DynamicType, FunctionDecorators, FunctionLiteral,
8686
GenericAlias, GenericClass, IntersectionBuilder, IntersectionType, KnownClass, KnownFunction,
8787
KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter,
8888
ParameterForm, Parameters, Signature, Signatures, SliceLiteralType, StringLiteralType,
@@ -1576,19 +1576,17 @@ impl<'db> TypeInferenceBuilder<'db> {
15761576
.node_scope(NodeWithScopeRef::Function(function))
15771577
.to_scope_id(self.db(), self.file());
15781578

1579-
let inherited_generic_context = None;
1580-
let specialization = None;
1581-
1582-
let mut inferred_ty = Type::FunctionLiteral(FunctionType::new(
1579+
let mut inferred_ty = Type::function_literal(
15831580
self.db(),
1584-
&name.id,
1585-
function_kind,
1586-
body_scope,
1587-
function_decorators,
1588-
dataclass_transformer_params,
1589-
inherited_generic_context,
1590-
specialization,
1591-
));
1581+
FunctionLiteral::new(
1582+
self.db(),
1583+
&name.id,
1584+
function_kind,
1585+
body_scope,
1586+
function_decorators,
1587+
dataclass_transformer_params,
1588+
),
1589+
);
15921590

15931591
for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() {
15941592
inferred_ty = match decorator_ty
@@ -1842,7 +1840,10 @@ impl<'db> TypeInferenceBuilder<'db> {
18421840
}
18431841

18441842
if let Type::FunctionLiteral(f) = decorator_ty {
1845-
if let Some(params) = f.dataclass_transformer_params(self.db()) {
1843+
if let Some(params) = f
1844+
.function(self.db())
1845+
.dataclass_transformer_params(self.db())
1846+
{
18461847
dataclass_params = Some(params.into());
18471848
continue;
18481849
}

0 commit comments

Comments
 (0)