diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md index 5c701b22a64849..40fddeecddf7bb 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/instances.md @@ -363,7 +363,7 @@ reveal_type(X() + Y()) # revealed: int ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 a = NotBoolable() diff --git a/crates/red_knot_python_semantic/resources/mdtest/call/callable_instance.md b/crates/red_knot_python_semantic/resources/mdtest/call/callable_instance.md index 5b6bb368797bbd..4a6b2b1ad1c6d1 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/call/callable_instance.md +++ b/crates/red_knot_python_semantic/resources/mdtest/call/callable_instance.md @@ -71,7 +71,7 @@ def _(flag: bool): a = NonCallable() # error: [call-non-callable] "Object of type `Literal[1]` is not callable" - reveal_type(a()) # revealed: int | Unknown + reveal_type(a()) # revealed: Unknown | int ``` ## Call binding errors diff --git a/crates/red_knot_python_semantic/resources/mdtest/call/union.md b/crates/red_knot_python_semantic/resources/mdtest/call/union.md index 31240b0c1d9ba5..c72196c6954af5 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/call/union.md +++ b/crates/red_knot_python_semantic/resources/mdtest/call/union.md @@ -40,7 +40,7 @@ def _(flag: bool): def f() -> int: return 1 x = f() # error: [call-non-callable] "Object of type `Literal[1]` is not callable" - reveal_type(x) # revealed: int | Unknown + reveal_type(x) # revealed: Unknown | int ``` ## Multiple non-callable elements in a union @@ -58,7 +58,7 @@ def _(flag: bool, flag2: bool): return 1 # TODO we should mention all non-callable elements of the union # error: [call-non-callable] "Object of type `Literal[1]` is not callable" - # revealed: int | Unknown + # revealed: Unknown | int reveal_type(f()) ``` @@ -148,3 +148,16 @@ def _(flag: bool): x = f(3) reveal_type(x) # revealed: Unknown ``` + +## Union including a special-cased function + +```py +def _(flag: bool): + if flag: + f = str + else: + f = repr + reveal_type(str("string")) # revealed: Literal["string"] + reveal_type(repr("string")) # revealed: Literal["'string'"] + reveal_type(f("string")) # revealed: Literal["string", "'string'"] +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md index 4b1617a979fcff..b28d0d04fa9bf4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/membership_test.md @@ -191,7 +191,7 @@ It may also be more appropriate to use `unsupported-operator` as the error code. ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 class WithContains: def __contains__(self, item) -> NotBoolable: diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md index a0c6680c610c3b..f6fda97032a892 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/instances/rich_comparison.md @@ -355,7 +355,7 @@ element) of a chained comparison. ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 class Comparable: def __lt__(self, item) -> NotBoolable: diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md index 557791790d5114..d4cd80765c7e1b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -355,7 +355,7 @@ def compute_chained_comparison(): ```py class NotBoolable: - __bool__ = 5 + __bool__: int = 5 class Comparable: def __lt__(self, other) -> NotBoolable: @@ -387,7 +387,7 @@ class A: return NotBoolable() class NotBoolable: - __bool__ = None + __bool__: None = None # error: [unsupported-bool-conversion] (A(),) == (A(),) diff --git a/crates/red_knot_python_semantic/resources/mdtest/conditional/if_expression.md b/crates/red_knot_python_semantic/resources/mdtest/conditional/if_expression.md index 47696c065840b6..b14d358ea04bd7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/conditional/if_expression.md +++ b/crates/red_knot_python_semantic/resources/mdtest/conditional/if_expression.md @@ -40,7 +40,7 @@ def _(flag: bool): ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; its `__bool__` method isn't callable" 3 if NotBoolable() else 4 diff --git a/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md b/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md index fff101842773f3..9a3fc4f8f4c897 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md +++ b/crates/red_knot_python_semantic/resources/mdtest/conditional/if_statement.md @@ -152,7 +152,7 @@ def _(flag: bool): ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; its `__bool__` method isn't callable" if NotBoolable(): diff --git a/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md b/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md index 3fe4956b3468a6..2f0ad24fe1e3e0 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md +++ b/crates/red_knot_python_semantic/resources/mdtest/conditional/match.md @@ -48,7 +48,7 @@ def _(target: int): ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 def _(target: int, flag: NotBoolable): y = 1 diff --git a/crates/red_knot_python_semantic/resources/mdtest/expression/assert.md b/crates/red_knot_python_semantic/resources/mdtest/expression/assert.md index f7e1715246a6d1..54073f9170fe33 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/expression/assert.md +++ b/crates/red_knot_python_semantic/resources/mdtest/expression/assert.md @@ -2,7 +2,7 @@ ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; its `__bool__` method isn't callable" assert NotBoolable() diff --git a/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md b/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md index ccedbac6f4b068..ce3363636d7403 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md +++ b/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md @@ -121,7 +121,7 @@ if NotBoolable(): ```py class NotBoolable: - __bool__ = None + __bool__: None = None # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; its `__bool__` method isn't callable" if NotBoolable(): @@ -133,9 +133,9 @@ if NotBoolable(): ```py def test(cond: bool): class NotBoolable: - __bool__ = None if cond else 3 + __bool__: int | None = None if cond else 3 - # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; it incorrectly implements `__bool__`" + # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; its `__bool__` method isn't callable" if NotBoolable(): ... ``` @@ -145,7 +145,7 @@ def test(cond: bool): ```py def test(cond: bool): class NotBoolable: - __bool__ = None + __bool__: None = None a = 10 if cond else NotBoolable() diff --git a/crates/red_knot_python_semantic/resources/mdtest/loops/while_loop.md b/crates/red_knot_python_semantic/resources/mdtest/loops/while_loop.md index c3da62e064ec4f..397a06b742dacf 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/loops/while_loop.md +++ b/crates/red_knot_python_semantic/resources/mdtest/loops/while_loop.md @@ -121,7 +121,7 @@ def _(flag: bool, flag2: bool): ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 # error: [unsupported-bool-conversion] "Boolean conversion is unsupported for type `NotBoolable`; its `__bool__` method isn't callable" while NotBoolable(): diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/for.md_-_For_loops_-_Possibly_unbound_`__iter__`_and_possibly_invalid_`__getitem__`.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/for.md_-_For_loops_-_Possibly_unbound_`__iter__`_and_possibly_invalid_`__getitem__`.snap index 95953b915a172b..357f420c5627ef 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/for.md_-_For_loops_-_Possibly_unbound_`__iter__`_and_possibly_invalid_`__getitem__`.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/for.md_-_For_loops_-_Possibly_unbound_`__iter__`_and_possibly_invalid_`__getitem__`.snap @@ -86,8 +86,7 @@ error: lint:not-iterable | 35 | # error: [not-iterable] 36 | for y in Iterable2(): - | ^^^^^^^^^^^ Object of type `Iterable2` may not be iterable because it may not have an `__iter__` method and its `__getitem__` method (with type ` | `) - may have an incorrect signature for the old-style iteration protocol (expected a signature at least as permissive as `def __getitem__(self, key: int): ...`) + | ^^^^^^^^^^^ Object of type `Iterable2` may not be iterable because it may not have an `__iter__` method and its `__getitem__` method (with type ` | `) may have an incorrect signature for the old-style iteration protocol (expected a signature at least as permissive as `def __getitem__(self, key: int): ...`) 37 | reveal_type(y) # revealed: bytes | str | int | diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/instances.md_-_Binary_operations_on_instances_-_Operations_involving_types_with_invalid_`__bool__`_methods.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/instances.md_-_Binary_operations_on_instances_-_Operations_involving_types_with_invalid_`__bool__`_methods.snap index cd7e6f94b9c9f8..cebcc8765538fb 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/instances.md_-_Binary_operations_on_instances_-_Operations_involving_types_with_invalid_`__bool__`_methods.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/instances.md_-_Binary_operations_on_instances_-_Operations_involving_types_with_invalid_`__bool__`_methods.snap @@ -13,7 +13,7 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/binary/instances.m ``` 1 | class NotBoolable: -2 | __bool__ = 3 +2 | __bool__: int = 3 3 | 4 | a = NotBoolable() 5 | diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/membership_test.md_-_Comparison___Membership_Test_-_Return_type_that_doesn't_implement_`__bool__`_correctly.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/membership_test.md_-_Comparison___Membership_Test_-_Return_type_that_doesn't_implement_`__bool__`_correctly.snap index d714fd2c18a009..c811afea2fa0df 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/membership_test.md_-_Comparison___Membership_Test_-_Return_type_that_doesn't_implement_`__bool__`_correctly.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/membership_test.md_-_Comparison___Membership_Test_-_Return_type_that_doesn't_implement_`__bool__`_correctly.snap @@ -13,7 +13,7 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/comparison/instanc ``` 1 | class NotBoolable: - 2 | __bool__ = 3 + 2 | __bool__: int = 3 3 | 4 | class WithContains: 5 | def __contains__(self, item) -> NotBoolable: diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/not.md_-_Unary_not_-_Object_that_implements_`__bool__`_incorrectly.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/not.md_-_Unary_not_-_Object_that_implements_`__bool__`_incorrectly.snap index 8471ca5c59e0f8..3482463acd311c 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/not.md_-_Unary_not_-_Object_that_implements_`__bool__`_incorrectly.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/not.md_-_Unary_not_-_Object_that_implements_`__bool__`_incorrectly.snap @@ -13,7 +13,7 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/unary/not.md ``` 1 | class NotBoolable: -2 | __bool__ = 3 +2 | __bool__: int = 3 3 | 4 | # error: [unsupported-bool-conversion] 5 | not NotBoolable() diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/rich_comparison.md_-_Comparison___Rich_Comparison_-_Chained_comparisons_with_objects_that_don't_implement_`__bool__`_correctly.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/rich_comparison.md_-_Comparison___Rich_Comparison_-_Chained_comparisons_with_objects_that_don't_implement_`__bool__`_correctly.snap index 87779b02dc3db4..c0004ad58d0f09 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/rich_comparison.md_-_Comparison___Rich_Comparison_-_Chained_comparisons_with_objects_that_don't_implement_`__bool__`_correctly.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/rich_comparison.md_-_Comparison___Rich_Comparison_-_Chained_comparisons_with_objects_that_don't_implement_`__bool__`_correctly.snap @@ -13,7 +13,7 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/comparison/instanc ``` 1 | class NotBoolable: - 2 | __bool__ = 3 + 2 | __bool__: int = 3 3 | 4 | class Comparable: 5 | def __lt__(self, item) -> NotBoolable: diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Chained_comparisons_with_elements_that_incorrectly_implement_`__bool__`.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Chained_comparisons_with_elements_that_incorrectly_implement_`__bool__`.snap index f0694a0fdace64..b741702c188372 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Chained_comparisons_with_elements_that_incorrectly_implement_`__bool__`.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Chained_comparisons_with_elements_that_incorrectly_implement_`__bool__`.snap @@ -13,7 +13,7 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/comparison/tuples. ``` 1 | class NotBoolable: - 2 | __bool__ = 5 + 2 | __bool__: int = 5 3 | 4 | class Comparable: 5 | def __lt__(self, other) -> NotBoolable: diff --git a/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Equality_with_elements_that_incorrectly_implement_`__bool__`.snap b/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Equality_with_elements_that_incorrectly_implement_`__bool__`.snap index 386af0045668dc..55e6c8fa677890 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Equality_with_elements_that_incorrectly_implement_`__bool__`.snap +++ b/crates/red_knot_python_semantic/resources/mdtest/snapshots/tuples.md_-_Comparison___Tuples_-_Equality_with_elements_that_incorrectly_implement_`__bool__`.snap @@ -17,7 +17,7 @@ mdtest path: crates/red_knot_python_semantic/resources/mdtest/comparison/tuples. 3 | return NotBoolable() 4 | 5 | class NotBoolable: -6 | __bool__ = None +6 | __bool__: None = None 7 | 8 | # error: [unsupported-bool-conversion] 9 | (A(),) == (A(),) diff --git a/crates/red_knot_python_semantic/resources/mdtest/unary/not.md b/crates/red_knot_python_semantic/resources/mdtest/unary/not.md index b3b75678f9da90..82f589517af48e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/unary/not.md +++ b/crates/red_knot_python_semantic/resources/mdtest/unary/not.md @@ -210,7 +210,7 @@ reveal_type(not PossiblyUnboundBool()) ```py class NotBoolable: - __bool__ = 3 + __bool__: int = 3 # error: [unsupported-bool-conversion] not NotBoolable() diff --git a/crates/red_knot_python_semantic/src/symbol.rs b/crates/red_knot_python_semantic/src/symbol.rs index 3ba9bc36920de5..53b2c88b155915 100644 --- a/crates/red_knot_python_semantic/src/symbol.rs +++ b/crates/red_knot_python_semantic/src/symbol.rs @@ -15,7 +15,7 @@ use crate::{resolve_module, Db, KnownModule, Module, Program}; pub(crate) use implicit_globals::module_type_implicit_global_symbol; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub(crate) enum Boundness { Bound, PossiblyUnbound, diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 578c0860d7fddc..05e95722ea5cf8 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -2,7 +2,7 @@ use std::hash::Hash; use std::str::FromStr; use bitflags::bitflags; -use call::{CallDunderError, CallError}; +use call::{CallDunderError, CallError, CallErrorKind}; use context::InferContext; use diagnostic::{INVALID_CONTEXT_MANAGER, NOT_ITERABLE}; use ruff_db::files::File; @@ -20,7 +20,7 @@ pub(crate) use self::infer::{ infer_scope_types, }; pub use self::narrow::KnownConstraintFunction; -pub(crate) use self::signatures::{CallableSignature, Signature}; +pub(crate) use self::signatures::{CallableSignature, Signature, Signatures}; pub use self::subclass_of::SubclassOfType; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module, KnownModule}; @@ -30,7 +30,7 @@ use crate::semantic_index::symbol::ScopeId; use crate::semantic_index::{imported_modules, semantic_index}; use crate::suppression::check_suppressions; use crate::symbol::{imported_symbol, Boundness, Symbol, SymbolAndQualifiers}; -use crate::types::call::{bind_call, CallArguments, CallOutcome, UnionCallError}; +use crate::types::call::{Bindings, CallArguments}; use crate::types::class_base::ClassBase; use crate::types::diagnostic::{INVALID_TYPE_FORM, UNSUPPORTED_BOOL_CONVERSION}; use crate::types::infer::infer_unpack_types; @@ -1690,11 +1690,11 @@ impl<'db> Type<'db> { if let Symbol::Type(descr_get, descr_get_boundness) = descr_get { let return_ty = descr_get .try_call(db, &CallArguments::positional([self, instance, owner])) - .map(|outcome| { + .map(|bindings| { if descr_get_boundness == Boundness::Bound { - outcome.return_type(db) + bindings.return_type(db) } else { - UnionType::from_elements(db, [outcome.return_type(db), self]) + UnionType::from_elements(db, [bindings.return_type(db), self]) } }) .ok()?; @@ -2163,67 +2163,50 @@ impl<'db> Type<'db> { }; match self.try_call_dunder(db, "__bool__", &CallArguments::none()) { - ref result @ (Ok(ref outcome) - | Err(CallDunderError::PossiblyUnbound(ref outcome))) => { + Ok(outcome) => { let return_type = outcome.return_type(db); + if !return_type.is_assignable_to(db, KnownClass::Bool.to_instance(db)) { + // The type has a `__bool__` method, but it doesn't return a + // boolean. + return Err(BoolError::IncorrectReturnType { + return_type, + not_boolable_type: *instance_ty, + }); + } + type_to_truthiness(return_type) + } - // The type has a `__bool__` method, but it doesn't return a boolean. + Err(CallDunderError::PossiblyUnbound(outcome)) => { + let return_type = outcome.return_type(db); if !return_type.is_assignable_to(db, KnownClass::Bool.to_instance(db)) { + // The type has a `__bool__` method, but it doesn't return a + // boolean. return Err(BoolError::IncorrectReturnType { return_type: outcome.return_type(db), not_boolable_type: *instance_ty, }); } - if result.is_ok() { - type_to_truthiness(return_type) - } else { - // Don't trust possibly unbound `__bool__` method. - Truthiness::Ambiguous - } + // Don't trust possibly unbound `__bool__` method. + Truthiness::Ambiguous } - Err(CallDunderError::MethodNotAvailable) => Truthiness::Ambiguous, - Err(CallDunderError::Call(err)) => { - let err = match err { - // Unwrap call errors where only a single variant isn't callable. - // E.g. in the case of `Unknown & T` - // TODO: Improve handling of unions. While this improves messages overall, - // it still results in loosing information. Or should the information - // be recomputed when rendering the diagnostic? - CallError::Union(union_error) => { - if let Type::Union(_) = union_error.called_type { - if union_error.errors.len() == 1 { - union_error.errors.into_vec().pop().unwrap() - } else { - CallError::Union(union_error) - } - } else { - CallError::Union(union_error) - } - } - err => err, - }; - - match err { - CallError::BindingError { binding } => { - return Err(BoolError::IncorrectArguments { - truthiness: type_to_truthiness(binding.return_type()), - not_boolable_type: *instance_ty, - }); - } - CallError::NotCallable { .. } => { - return Err(BoolError::NotCallable { - not_boolable_type: *instance_ty, - }); - } - CallError::PossiblyUnboundDunderCall { .. } - | CallError::Union(..) => { - return Err(BoolError::Other { - not_boolable_type: *self, - }) - } - } + Err(CallDunderError::MethodNotAvailable) => Truthiness::Ambiguous, + Err(CallDunderError::CallError(CallErrorKind::BindingError, bindings)) => { + return Err(BoolError::IncorrectArguments { + truthiness: type_to_truthiness(bindings.return_type(db)), + not_boolable_type: *instance_ty, + }); + } + Err(CallDunderError::CallError(CallErrorKind::NotCallable, _)) => { + return Err(BoolError::NotCallable { + not_boolable_type: *instance_ty, + }); + } + Err(CallDunderError::CallError(CallErrorKind::PossiblyNotCallable, _)) => { + return Err(BoolError::Other { + not_boolable_type: *self, + }) } } } @@ -2318,36 +2301,36 @@ impl<'db> Type<'db> { } let return_ty = match self.try_call_dunder(db, "__len__", &CallArguments::none()) { - Ok(outcome) | Err(CallDunderError::PossiblyUnbound(outcome)) => outcome.return_type(db), + Ok(bindings) => bindings.return_type(db), + Err(CallDunderError::PossiblyUnbound(bindings)) => bindings.return_type(db), // TODO: emit a diagnostic - Err(err) => err.return_type(db)?, + Err(CallDunderError::MethodNotAvailable) => return None, + Err(CallDunderError::CallError(_, bindings)) => bindings.return_type(db), }; non_negative_int_literal(db, return_ty) } - /// Calls `self` + /// Returns the call signatures of a type. /// - /// Returns `Ok` if the call with the given arguments is successful and `Err` otherwise. - fn try_call( - self, - db: &'db dyn Db, - arguments: &CallArguments<'_, 'db>, - ) -> Result, CallError<'db>> { + /// Note that all types have a valid [`Signatures`], even if the type is not callable. + /// Moreover, "callable" can be subtle for a union type, since some union elements might be + /// callable and some not. A union is callable if every element type is callable — and even + /// then, the elements might be inconsistent, such that there's no argument list that's valid + /// for all elements. It's usually best to only worry about "callability" relative to a + /// particular argument list, via [`try_call`][Self::try_call] and + /// [`CallErrorKind::NotCallable`]. + fn signatures(self, db: &'db dyn Db) -> Signatures<'db> { match self { Type::Callable(CallableType::BoundMethod(bound_method)) => { - let instance = bound_method.self_instance(db); - let arguments = arguments.with_self(instance); - let binding = bind_call( - db, - &arguments, - bound_method.function(db).signature(db), - self, - ); - binding.into_outcome() + let signature = bound_method.function(db).signature(db); + let signature = CallableSignature::single(self, signature.clone()) + .with_bound_type(bound_method.self_instance(db)); + Signatures::single(signature) } - Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { + + Type::Callable(CallableType::MethodWrapperDunderGet(_)) => { // Here, we dynamically model the overloaded function signature of `types.FunctionType.__get__`. // This is required because we need to return more precise types than what the signature in // typeshed provides: @@ -2361,10 +2344,10 @@ impl<'db> Type<'db> { // def __get__(self, instance: object, owner: type | None = None, /) -> MethodType: ... // ``` - #[salsa::tracked(return_ref)] - fn overloads<'db>(db: &'db dyn Db) -> CallableSignature<'db> { - let not_none = Type::none(db).negate(db); - CallableSignature::from_overloads([ + let not_none = Type::none(db).negate(db); + let signature = CallableSignature::from_overloads( + self, + [ Signature::new( Parameters::new([ Parameter::new( @@ -2400,40 +2383,11 @@ impl<'db> Type<'db> { ]), None, ), - ]) - } - - let mut binding = bind_call(db, arguments, overloads(db), self); - let Some((_, overload)) = binding.matching_overload_mut() else { - return Err(CallError::BindingError { binding }); - }; - - if function.has_known_class_decorator(db, KnownClass::Classmethod) - && function.decorators(db).len() == 1 - { - if let Some(owner) = arguments.second_argument() { - overload.set_return_type(Type::Callable(CallableType::BoundMethod( - BoundMethodType::new(db, function, owner), - ))); - } else if let Some(instance) = arguments.first_argument() { - overload.set_return_type(Type::Callable(CallableType::BoundMethod( - BoundMethodType::new(db, function, instance.to_meta_type(db)), - ))); - } - } else { - if let Some(first) = arguments.first_argument() { - if first.is_none(db) { - overload.set_return_type(Type::FunctionLiteral(function)); - } else { - overload.set_return_type(Type::Callable(CallableType::BoundMethod( - BoundMethodType::new(db, function, first), - ))); - } - } - } - - binding.into_outcome() + ], + ); + Signatures::single(signature) } + Type::Callable(CallableType::WrapperDescriptorDunderGet) => { // Here, we also model `types.FunctionType.__get__`, but now we consider a call to // this as a function, i.e. we also expect the `self` argument to be passed in. @@ -2441,10 +2395,10 @@ impl<'db> Type<'db> { // TODO: Consider merging this signature with the one in the previous match clause, // since the previous one is just this signature with the `self` parameters // removed. - #[salsa::tracked(return_ref)] - fn overloads<'db>(db: &'db dyn Db) -> CallableSignature<'db> { - let not_none = Type::none(db).negate(db); - CallableSignature::from_overloads([ + let not_none = Type::none(db).negate(db); + let signature = CallableSignature::from_overloads( + self, + [ Signature::new( Parameters::new([ Parameter::new( @@ -2490,85 +2444,314 @@ impl<'db> Type<'db> { ]), None, ), - ]) + ], + ); + Signatures::single(signature) + } + + Type::FunctionLiteral(function_type) => Signatures::single(CallableSignature::single( + self, + function_type.signature(db).clone(), + )), + + Type::ClassLiteral(ClassLiteralType { class }) => match class.known(db) { + Some(KnownClass::Bool) => { + // ```py + // class bool(int): + // def __new__(cls, o: object = ..., /) -> Self: ... + // ``` + let signature = CallableSignature::single( + self, + Signature::new( + Parameters::new([Parameter::new( + Some(Name::new_static("o")), + Some(Type::any()), + ParameterKind::PositionalOnly { + default_ty: Some(Type::BooleanLiteral(false)), + }, + )]), + Some(KnownClass::Bool.to_instance(db)), + ), + ); + Signatures::single(signature) } - let mut binding = bind_call(db, arguments, overloads(db), self); - let Some((_, overload)) = binding.matching_overload_mut() else { - return Err(CallError::BindingError { binding }); - }; + Some(KnownClass::Str) => { + // ```py + // class str(Sequence[str]): + // @overload + // def __new__(cls, object: object = ...) -> Self: ... + // @overload + // def __new__(cls, object: ReadableBuffer, encoding: str = ..., errors: str = ...) -> Self: ... + // ``` + let signature = CallableSignature::from_overloads( + self, + [ + Signature::new( + Parameters::new([Parameter::new( + Some(Name::new_static("o")), + Some(Type::any()), + ParameterKind::PositionalOnly { + default_ty: Some(Type::string_literal(db, "")), + }, + )]), + Some(KnownClass::Str.to_instance(db)), + ), + Signature::new( + Parameters::new([ + Parameter::new( + Some(Name::new_static("o")), + Some(Type::any()), // TODO: ReadableBuffer + ParameterKind::PositionalOnly { default_ty: None }, + ), + Parameter::new( + Some(Name::new_static("encoding")), + Some(KnownClass::Str.to_instance(db)), + ParameterKind::PositionalOnly { default_ty: None }, + ), + Parameter::new( + Some(Name::new_static("errors")), + Some(KnownClass::Str.to_instance(db)), + ParameterKind::PositionalOnly { default_ty: None }, + ), + ]), + Some(KnownClass::Str.to_instance(db)), + ), + ], + ); + Signatures::single(signature) + } - if let Some(function_ty @ Type::FunctionLiteral(function)) = - arguments.first_argument() + Some(KnownClass::Type) => { + // ```py + // class type: + // @overload + // def __init__(self, o: object, /) -> None: ... + // @overload + // def __init__(self, name: str, bases: tuple[type, ...], dict: dict[str, Any], /, **kwds: Any) -> None: ... + // ``` + let signature = CallableSignature::from_overloads( + self, + [ + Signature::new( + Parameters::new([Parameter::new( + Some(Name::new_static("o")), + Some(Type::any()), + ParameterKind::PositionalOnly { default_ty: None }, + )]), + Some(KnownClass::Type.to_instance(db)), + ), + Signature::new( + Parameters::new([ + Parameter::new( + Some(Name::new_static("o")), + Some(Type::any()), + ParameterKind::PositionalOnly { default_ty: None }, + ), + Parameter::new( + Some(Name::new_static("bases")), + Some(Type::any()), + ParameterKind::PositionalOnly { default_ty: None }, + ), + Parameter::new( + Some(Name::new_static("dict")), + Some(Type::any()), + ParameterKind::PositionalOnly { default_ty: None }, + ), + ]), + Some(KnownClass::Type.to_instance(db)), + ), + ], + ); + Signatures::single(signature) + } + + // TODO annotated return type on `__new__` or metaclass `__call__` + // TODO check call vs signatures of `__new__` and/or `__init__` + _ => { + let signature = CallableSignature::single( + self, + Signature::new(Parameters::gradual_form(), self.to_instance(db)), + ); + Signatures::single(signature) + } + }, + + Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() { + ClassBase::Dynamic(dynamic_type) => Type::Dynamic(dynamic_type).signatures(db), + ClassBase::Class(class) => Type::class_literal(class).signatures(db), + }, + + Type::Instance(_) => { + // Note that for objects that have a (possibly not callable!) `__call__` attribute, + // we will get the signature of the `__call__` attribute, but will pass in the type + // of the original object as the "callable type". That ensures that we get errors + // like "`X` is not callable" instead of "`` is not + // callable". + match self + .member_lookup_with_policy( + db, + Name::new_static("__call__"), + MemberLookupPolicy::NoInstanceFallback, + ) + .symbol { + Symbol::Type(dunder_callable, boundness) => { + let mut signatures = dunder_callable.signatures(db).clone(); + signatures.replace_callable_type(dunder_callable, self); + if boundness == Boundness::PossiblyUnbound { + signatures.set_dunder_call_is_possibly_unbound(); + } + signatures + } + Symbol::Unbound => Signatures::not_callable(self), + } + } + + // Dynamic types are callable, and the return type is the same dynamic type. Similarly, + // `Never` is always callable and returns `Never`. + Type::Dynamic(_) | Type::Never => Signatures::single(CallableSignature::dynamic(self)), + + // Note that this correctly returns `None` if none of the union elements are callable. + Type::Union(union) => Signatures::from_union( + self, + union + .elements(db) + .iter() + .map(|element| element.signatures(db)), + ), + + Type::Intersection(_) => { + Signatures::single(CallableSignature::todo("Type::Intersection.call()")) + } + + _ => Signatures::not_callable(self), + } + } + + /// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if + /// the arguments are not compatible with the formal parameters. + /// + /// You get back a [`Bindings`] for both successful and unsuccessful calls. + /// It contains information about which formal parameters each argument was matched to, + /// and about any errors matching arguments and parameters. + fn try_call( + self, + db: &'db dyn Db, + arguments: &CallArguments<'_, 'db>, + ) -> Result, CallError<'db>> { + let signatures = self.signatures(db); + let mut bindings = Bindings::bind(db, &signatures, arguments)?; + for binding in &mut bindings { + // For certain known callables, we have special-case logic to determine the return type + // in a way that isn't directly expressible in the type system. Each special case + // listed here should have a corresponding clause above in `signatures`. + let binding_type = binding.callable_type; + let Some((overload_index, overload)) = binding.matching_overload_mut() else { + continue; + }; + + match binding_type { + Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { if function.has_known_class_decorator(db, KnownClass::Classmethod) && function.decorators(db).len() == 1 { - if let Some(owner) = arguments.third_argument() { + if let Some(owner) = arguments.second_argument() { overload.set_return_type(Type::Callable(CallableType::BoundMethod( BoundMethodType::new(db, function, owner), ))); - } else if let Some(instance) = arguments.second_argument() { + } else if let Some(instance) = arguments.first_argument() { overload.set_return_type(Type::Callable(CallableType::BoundMethod( BoundMethodType::new(db, function, instance.to_meta_type(db)), ))); } - } else { - match (arguments.second_argument(), arguments.third_argument()) { - (Some(instance), _) if instance.is_none(db) => { - overload.set_return_type(function_ty); - } - - ( - Some(Type::KnownInstance(KnownInstanceType::TypeAliasType( - type_alias, - ))), - Some(Type::ClassLiteral(ClassLiteralType { class })), - ) if class.is_known(db, KnownClass::TypeAliasType) - && function.name(db) == "__name__" => - { - overload - .set_return_type(Type::string_literal(db, type_alias.name(db))); - } - - ( - Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), - Some(Type::ClassLiteral(ClassLiteralType { class })), - ) if class.is_known(db, KnownClass::TypeVar) - && function.name(db) == "__name__" => - { - overload - .set_return_type(Type::string_literal(db, typevar.name(db))); - } - - (Some(_), _) - if function.has_known_class_decorator(db, KnownClass::Property) => - { - overload.set_return_type(todo_type!("@property")); - } + } else if let Some(first) = arguments.first_argument() { + if first.is_none(db) { + overload.set_return_type(Type::FunctionLiteral(function)); + } else { + overload.set_return_type(Type::Callable(CallableType::BoundMethod( + BoundMethodType::new(db, function, first), + ))); + } + } + } - (Some(instance), _) => { + Type::Callable(CallableType::WrapperDescriptorDunderGet) => { + if let Some(function_ty @ Type::FunctionLiteral(function)) = + arguments.first_argument() + { + if function.has_known_class_decorator(db, KnownClass::Classmethod) + && function.decorators(db).len() == 1 + { + if let Some(owner) = arguments.third_argument() { + overload.set_return_type(Type::Callable( + CallableType::BoundMethod(BoundMethodType::new( + db, function, owner, + )), + )); + } else if let Some(instance) = arguments.second_argument() { overload.set_return_type(Type::Callable( CallableType::BoundMethod(BoundMethodType::new( - db, function, instance, + db, + function, + instance.to_meta_type(db), )), )); } + } else { + match (arguments.second_argument(), arguments.third_argument()) { + (Some(instance), _) if instance.is_none(db) => { + overload.set_return_type(function_ty); + } + + ( + Some(Type::KnownInstance(KnownInstanceType::TypeAliasType( + type_alias, + ))), + Some(Type::ClassLiteral(ClassLiteralType { class })), + ) if class.is_known(db, KnownClass::TypeAliasType) + && function.name(db) == "__name__" => + { + overload.set_return_type(Type::string_literal( + db, + type_alias.name(db), + )); + } + + ( + Some(Type::KnownInstance(KnownInstanceType::TypeVar(typevar))), + Some(Type::ClassLiteral(ClassLiteralType { class })), + ) if class.is_known(db, KnownClass::TypeVar) + && function.name(db) == "__name__" => + { + overload.set_return_type(Type::string_literal( + db, + typevar.name(db), + )); + } + + (Some(_), _) + if function + .has_known_class_decorator(db, KnownClass::Property) => + { + overload.set_return_type(todo_type!("@property")); + } - (None, _) => {} + (Some(instance), _) => { + overload.set_return_type(Type::Callable( + CallableType::BoundMethod(BoundMethodType::new( + db, function, instance, + )), + )); + } + + (None, _) => {} + } } } } - binding.into_outcome() - } - Type::FunctionLiteral(function_type) => { - let mut binding = bind_call(db, arguments, function_type.signature(db), self); - let Some((_, overload)) = binding.matching_overload_mut() else { - return Err(CallError::BindingError { binding }); - }; - - match function_type.known(db) { + Type::FunctionLiteral(function_type) => match function_type.known(db) { Some(KnownFunction::IsEquivalentTo) => { if let [ty_a, ty_b] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral( @@ -2576,6 +2759,7 @@ impl<'db> Type<'db> { )); } } + Some(KnownFunction::IsSubtypeOf) => { if let [ty_a, ty_b] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral( @@ -2583,6 +2767,7 @@ impl<'db> Type<'db> { )); } } + Some(KnownFunction::IsAssignableTo) => { if let [ty_a, ty_b] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral( @@ -2590,6 +2775,7 @@ impl<'db> Type<'db> { )); } } + Some(KnownFunction::IsDisjointFrom) => { if let [ty_a, ty_b] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral( @@ -2597,6 +2783,7 @@ impl<'db> Type<'db> { )); } } + Some(KnownFunction::IsGradualEquivalentTo) => { if let [ty_a, ty_b] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral( @@ -2604,16 +2791,19 @@ impl<'db> Type<'db> { )); } } + Some(KnownFunction::IsFullyStatic) => { if let [ty] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral(ty.is_fully_static(db))); } } + Some(KnownFunction::IsSingleton) => { if let [ty] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral(ty.is_singleton(db))); } } + Some(KnownFunction::IsSingleValued) => { if let [ty] = overload.parameter_types() { overload.set_return_type(Type::BooleanLiteral(ty.is_single_valued(db))); @@ -2649,11 +2839,11 @@ impl<'db> Type<'db> { Some(KnownFunction::GetattrStatic) => { let [instance_ty, attr_name, default] = overload.parameter_types() else { - return binding.into_outcome(); + continue; }; let Some(attr_name) = attr_name.into_string_literal() else { - return binding.into_outcome(); + continue; }; let default = if default.is_unknown() { @@ -2688,238 +2878,42 @@ impl<'db> Type<'db> { } _ => {} - }; - - binding.into_outcome() - } - - Type::ClassLiteral(ClassLiteralType { class }) - if class.is_known(db, KnownClass::Bool) => - { - // ```py - // class bool(int): - // def __new__(cls, o: object = ..., /) -> Self: ... - // ``` - #[salsa::tracked(return_ref)] - fn overloads<'db>(db: &'db dyn Db) -> CallableSignature<'db> { - Signature::new( - Parameters::new([Parameter::new( - Some(Name::new_static("o")), - Some(Type::any()), - ParameterKind::PositionalOnly { - default_ty: Some(Type::BooleanLiteral(false)), - }, - )]), - Some(KnownClass::Bool.to_instance(db)), - ) - .into() - } - - let mut binding = bind_call(db, arguments, overloads(db), self); - let Some((_, overload)) = binding.matching_overload_mut() else { - return Err(CallError::BindingError { binding }); - }; - overload.set_return_type( - arguments - .first_argument() - .map(|arg| arg.bool(db).into_type(db)) - .unwrap_or(Type::BooleanLiteral(false)), - ); - binding.into_outcome() - } - - Type::ClassLiteral(ClassLiteralType { class }) - if class.is_known(db, KnownClass::Str) => - { - // ```py - // class str(Sequence[str]): - // @overload - // def __new__(cls, object: object = ...) -> Self: ... - // @overload - // def __new__(cls, object: ReadableBuffer, encoding: str = ..., errors: str = ...) -> Self: ... - // ``` - #[salsa::tracked(return_ref)] - fn overloads<'db>(db: &'db dyn Db) -> CallableSignature<'db> { - CallableSignature::from_overloads([ - Signature::new( - Parameters::new([Parameter::new( - Some(Name::new_static("o")), - Some(Type::any()), - ParameterKind::PositionalOnly { - default_ty: Some(Type::string_literal(db, "")), - }, - )]), - Some(KnownClass::Str.to_instance(db)), - ), - Signature::new( - Parameters::new([ - Parameter::new( - Some(Name::new_static("o")), - Some(Type::any()), // TODO: ReadableBuffer - ParameterKind::PositionalOnly { default_ty: None }, - ), - Parameter::new( - Some(Name::new_static("encoding")), - Some(KnownClass::Str.to_instance(db)), - ParameterKind::PositionalOnly { default_ty: None }, - ), - Parameter::new( - Some(Name::new_static("errors")), - Some(KnownClass::Str.to_instance(db)), - ParameterKind::PositionalOnly { default_ty: None }, - ), - ]), - Some(KnownClass::Str.to_instance(db)), - ), - ]) - } - - let mut binding = bind_call(db, arguments, overloads(db), self); - let Some((index, overload)) = binding.matching_overload_mut() else { - return Err(CallError::BindingError { binding }); - }; - if index == 0 { - overload.set_return_type( - arguments - .first_argument() - .map(|arg| arg.str(db)) - .unwrap_or_else(|| Type::string_literal(db, "")), - ); - } - binding.into_outcome() - } - - Type::ClassLiteral(ClassLiteralType { class }) - if class.is_known(db, KnownClass::Type) => - { - // ```py - // class type: - // @overload - // def __init__(self, o: object, /) -> None: ... - // @overload - // def __init__(self, name: str, bases: tuple[type, ...], dict: dict[str, Any], /, **kwds: Any) -> None: ... - // ``` - #[salsa::tracked(return_ref)] - fn overloads<'db>(db: &'db dyn Db) -> CallableSignature<'db> { - CallableSignature::from_overloads([ - Signature::new( - Parameters::new([Parameter::new( - Some(Name::new_static("o")), - Some(Type::any()), - ParameterKind::PositionalOnly { default_ty: None }, - )]), - Some(KnownClass::Type.to_instance(db)), - ), - Signature::new( - Parameters::new([ - Parameter::new( - Some(Name::new_static("o")), - Some(Type::any()), - ParameterKind::PositionalOnly { default_ty: None }, - ), - Parameter::new( - Some(Name::new_static("bases")), - Some(Type::any()), - ParameterKind::PositionalOnly { default_ty: None }, - ), - Parameter::new( - Some(Name::new_static("dict")), - Some(Type::any()), - ParameterKind::PositionalOnly { default_ty: None }, - ), - ]), - Some(KnownClass::Type.to_instance(db)), - ), - ]) - } + }, - let mut binding = bind_call(db, arguments, overloads(db), self); - let Some((index, overload)) = binding.matching_overload_mut() else { - return Err(CallError::BindingError { binding }); - }; - if index == 0 { - if let Some(arg) = arguments.first_argument() { - overload.set_return_type(arg.to_meta_type(db)); + Type::ClassLiteral(ClassLiteralType { class }) => match class.known(db) { + Some(KnownClass::Bool) => { + overload.set_return_type( + arguments + .first_argument() + .map(|arg| arg.bool(db).into_type(db)) + .unwrap_or(Type::BooleanLiteral(false)), + ); } - } - binding.into_outcome() - } - // TODO annotated return type on `__new__` or metaclass `__call__` - // TODO check call vs signatures of `__new__` and/or `__init__` - Type::ClassLiteral(ClassLiteralType { .. }) => { - let signature = Signature::new(Parameters::gradual_form(), self.to_instance(db)); - let binding = bind_call(db, arguments, &signature.into(), self); - binding.into_outcome() - } - - Type::SubclassOf(subclass_of_type) => match subclass_of_type.subclass_of() { - ClassBase::Dynamic(dynamic_type) => { - Type::Dynamic(dynamic_type).try_call(db, arguments) - } - ClassBase::Class(class) => Type::class_literal(class).try_call(db, arguments), - }, + Some(KnownClass::Str) if overload_index == 0 => { + overload.set_return_type( + arguments + .first_argument() + .map(|arg| arg.str(db)) + .unwrap_or_else(|| Type::string_literal(db, "")), + ); + } - instance_ty @ Type::Instance(_) => { - instance_ty - .try_call_dunder(db, "__call__", arguments) - .map_err(|err| match err { - CallDunderError::Call(CallError::NotCallable { .. }) => { - // Turn "`` not callable" into - // "`X` not callable" - CallError::NotCallable { - not_callable_type: self, - } - } - CallDunderError::Call(CallError::Union(UnionCallError { - called_type: _, - bindings, - errors, - })) => CallError::Union(UnionCallError { - called_type: self, - bindings, - errors, - }), - CallDunderError::Call(error) => error, - // Turn "possibly unbound object of type `Literal['__call__']`" - // into "`X` not callable (possibly unbound `__call__` method)" - CallDunderError::PossiblyUnbound(outcome) => { - CallError::PossiblyUnboundDunderCall { - called_type: self, - outcome: Box::new(outcome), - } - } - CallDunderError::MethodNotAvailable => { - // Turn "`X.__call__` unbound" into "`X` not callable" - CallError::NotCallable { - not_callable_type: self, - } + Some(KnownClass::Type) if overload_index == 0 => { + if let Some(arg) = arguments.first_argument() { + overload.set_return_type(arg.to_meta_type(db)); } - }) - } - - // Dynamic types are callable, and the return type is the same dynamic type. Similarly, - // `Never` is always callable and returns `Never`. - Type::Dynamic(_) | Type::Never => { - let overloads = CallableSignature::dynamic(self); - let binding = bind_call(db, arguments, &overloads, self); - binding.into_outcome() - } + } - Type::Union(union) => { - CallOutcome::try_call_union(db, union, |element| element.try_call(db, arguments)) - } + _ => {} + }, - Type::Intersection(_) => { - let overloads = CallableSignature::todo("Type::Intersection.call()"); - let binding = bind_call(db, arguments, &overloads, self); - binding.into_outcome() + // Not a special case + _ => {} } - - _ => Err(CallError::NotCallable { - not_callable_type: self, - }), } + + Ok(bindings) } /// Look up a dunder method on the meta-type of `self` and call it. @@ -2931,19 +2925,18 @@ impl<'db> Type<'db> { db: &'db dyn Db, name: &str, arguments: &CallArguments<'_, 'db>, - ) -> Result, CallDunderError<'db>> { + ) -> Result, CallDunderError<'db>> { match self .member_lookup_with_policy(db, name.into(), MemberLookupPolicy::NoInstanceFallback) .symbol { - Symbol::Type(dunder_callbable, boundness) => { - let result = dunder_callbable.try_call(db, arguments)?; - - if boundness == Boundness::Bound { - Ok(result) - } else { - Err(CallDunderError::PossiblyUnbound(result)) + Symbol::Type(dunder_callable, boundness) => { + let signatures = dunder_callable.signatures(db); + let bindings = Bindings::bind(db, &signatures, arguments)?; + if boundness == Boundness::PossiblyUnbound { + return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } + Ok(bindings) } Symbol::Unbound => Err(CallDunderError::MethodNotAvailable), } @@ -3038,8 +3031,8 @@ impl<'db> Type<'db> { } // `__iter__` is definitely bound but it can't be called with the expected arguments - Err(CallDunderError::Call(dunder_iter_call_error)) => { - Err(IterationErrorKind::IterCallError(dunder_iter_call_error)) + Err(CallDunderError::CallError(kind, bindings)) => { + Err(IterationErrorKind::IterCallError(kind, bindings)) } // There's no `__iter__` method. Try `__getitem__` instead... @@ -3702,7 +3695,8 @@ impl<'db> ContextManagerErrorKind<'db> { CallDunderError::PossiblyUnbound(call_outcome) => { Some(call_outcome.return_type(db)) } - CallDunderError::Call(call_error) => call_error.return_type(db), + CallDunderError::CallError(CallErrorKind::NotCallable, _) => None, + CallDunderError::CallError(_, bindings) => Some(bindings.return_type(db)), CallDunderError::MethodNotAvailable => None, }, } @@ -3723,7 +3717,9 @@ impl<'db> ContextManagerErrorKind<'db> { // TODO: Use more specific error messages for the different error cases. // E.g. hint toward the union variant that doesn't correctly implement enter, // distinguish between a not callable `__enter__` attribute and a wrong signature. - CallDunderError::Call(_) => format!("it does not correctly implement `{name}`"), + CallDunderError::CallError(_, _) => { + format!("it does not correctly implement `{name}`") + } } }; @@ -3738,7 +3734,7 @@ impl<'db> ContextManagerErrorKind<'db> { (CallDunderError::MethodNotAvailable, CallDunderError::MethodNotAvailable) => { format!("it does not implement `{name_a}` and `{name_b}`") } - (CallDunderError::Call(_), CallDunderError::Call(_)) => { + (CallDunderError::CallError(_, _), CallDunderError::CallError(_, _)) => { format!("it does not correctly implement `{name_a}` or `{name_b}`") } (_, _) => format!( @@ -3805,7 +3801,7 @@ impl<'db> IterationError<'db> { enum IterationErrorKind<'db> { /// The object being iterated over has a bound `__iter__` method, /// but calling it with the expected arguments results in an error. - IterCallError(CallError<'db>), + IterCallError(CallErrorKind, Box>), /// The object being iterated over has a bound `__iter__` method that can be called /// with the expected types, but it returns an object that is not a valid iterator. @@ -3845,8 +3841,8 @@ impl<'db> IterationErrorKind<'db> { dunder_next_error, .. } => dunder_next_error.return_type(db), - Self::IterCallError(dunder_iter_call_error) => dunder_iter_call_error - .fallback_return_type(db) + Self::IterCallError(_, dunder_iter_bindings) => dunder_iter_bindings + .return_type(db) .try_call_dunder(db, "__next__", &CallArguments::none()) .map(|dunder_next_outcome| Some(dunder_next_outcome.return_type(db))) .unwrap_or_else(|dunder_next_call_error| dunder_next_call_error.return_type(db)), @@ -3862,15 +3858,14 @@ impl<'db> IterationErrorKind<'db> { [*dunder_next_return, dunder_getitem_outcome.return_type(db)], )) } - CallDunderError::Call(dunder_getitem_call_error) => Some( - dunder_getitem_call_error - .return_type(db) - .map(|dunder_getitem_return| { - let elements = [*dunder_next_return, dunder_getitem_return]; - UnionType::from_elements(db, elements) - }) - .unwrap_or(*dunder_next_return), - ), + CallDunderError::CallError(CallErrorKind::NotCallable, _) => { + Some(*dunder_next_return) + } + CallDunderError::CallError(_, dunder_getitem_bindings) => { + let dunder_getitem_return = dunder_getitem_bindings.return_type(db); + let elements = [*dunder_next_return, dunder_getitem_return]; + Some(UnionType::from_elements(db, elements)) + } }, Self::UnboundIterAndGetitemError { @@ -3897,46 +3892,44 @@ impl<'db> IterationErrorKind<'db> { // or similar, rather than as part of the same sentence as the error message. match self { - Self::IterCallError(dunder_iter_call_error) => match dunder_iter_call_error { - CallError::NotCallable { not_callable_type } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` is not iterable \ - because its `__iter__` attribute has type `{dunder_iter_type}`, \ - which is not callable", - iterable_type = iterable_type.display(db), - dunder_iter_type = not_callable_type.display(db), - )), - CallError::PossiblyUnboundDunderCall { called_type, .. } => { - report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because its `__iter__` attribute (with type `{dunder_iter_type}`) \ - may not be callable", - iterable_type = iterable_type.display(db), - dunder_iter_type = called_type.display(db), - )); - } - CallError::Union(union_call_error) if union_call_error.indicates_type_possibly_not_callable() => { - report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because its `__iter__` attribute (with type `{dunder_iter_type}`) \ - may not be callable", - iterable_type = iterable_type.display(db), - dunder_iter_type = union_call_error.called_type.display(db), - )); - } - CallError::BindingError { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` is not iterable \ - because its `__iter__` method has an invalid signature \ - (expected `def __iter__(self): ...`)", - iterable_type = iterable_type.display(db), - )), - CallError::Union(UnionCallError { called_type, .. }) => report_not_iterable(format_args!( + Self::IterCallError(CallErrorKind::NotCallable, bindings) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` is not iterable \ + because its `__iter__` attribute has type `{dunder_iter_type}`, \ + which is not callable", + iterable_type = iterable_type.display(db), + dunder_iter_type = bindings.callable_type.display(db), + )), + Self::IterCallError(CallErrorKind::PossiblyNotCallable, bindings) if bindings.is_single() => { + report_not_iterable(format_args!( "Object of type `{iterable_type}` may not be iterable \ - because its `__iter__` method (with type `{dunder_iter_type}`) \ - may have an invalid signature (expected `def __iter__(self): ...`)", + because its `__iter__` attribute (with type `{dunder_iter_type}`) \ + may not be callable", iterable_type = iterable_type.display(db), - dunder_iter_type = called_type.display(db), - )), + dunder_iter_type = bindings.callable_type.display(db), + )); } + Self::IterCallError(CallErrorKind::PossiblyNotCallable, bindings) => { + report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because its `__iter__` attribute (with type `{dunder_iter_type}`) \ + may not be callable", + iterable_type = iterable_type.display(db), + dunder_iter_type = bindings.callable_type.display(db), + )); + } + Self::IterCallError(CallErrorKind::BindingError, bindings) if bindings.is_single() => report_not_iterable(format_args!( + "Object of type `{iterable_type}` is not iterable \ + because its `__iter__` method has an invalid signature \ + (expected `def __iter__(self): ...`)", + iterable_type = iterable_type.display(db), + )), + Self::IterCallError(CallErrorKind::BindingError, bindings) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because its `__iter__` method (with type `{dunder_iter_type}`) \ + may have an invalid signature (expected `def __iter__(self): ...`)", + iterable_type = iterable_type.display(db), + dunder_iter_type = bindings.callable_type.display(db), + )), Self::IterReturnsInvalidIterator { iterator, @@ -3956,45 +3949,34 @@ impl<'db> IterationErrorKind<'db> { iterable_type = iterable_type.display(db), iterator_type = iterator.display(db), )), - CallDunderError::Call(dunder_next_call_error) => match dunder_next_call_error { - CallError::NotCallable { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` is not iterable \ - because its `__iter__` method returns an object of type `{iterator_type}`, \ - which has a `__next__` attribute that is not callable", - iterable_type = iterable_type.display(db), - iterator_type = iterator.display(db), - )), - CallError::PossiblyUnboundDunderCall { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because its `__iter__` method returns an object of type `{iterator_type}`, \ - which has a `__next__` attribute that may not be callable", - iterable_type = iterable_type.display(db), - iterator_type = iterator.display(db), - )), - CallError::Union(union_call_error) if union_call_error.indicates_type_possibly_not_callable() => { - report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because its `__iter__` method returns an object of type `{iterator_type}`, \ - which has a `__next__` attribute that may not be callable", - iterable_type = iterable_type.display(db), - iterator_type = iterator.display(db), - )); - } - CallError::BindingError { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` is not iterable \ - because its `__iter__` method returns an object of type `{iterator_type}`, \ - which has an invalid `__next__` method (expected `def __next__(self): ...`)", - iterable_type = iterable_type.display(db), - iterator_type = iterator.display(db), - )), - CallError::Union(_) => report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because its `__iter__` method returns an object of type `{iterator_type}`, \ - which may have an invalid `__next__` method (expected `def __next__(self): ...`)", - iterable_type = iterable_type.display(db), - iterator_type = iterator.display(db), - )), - } + CallDunderError::CallError(CallErrorKind::NotCallable, _) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` is not iterable \ + because its `__iter__` method returns an object of type `{iterator_type}`, \ + which has a `__next__` attribute that is not callable", + iterable_type = iterable_type.display(db), + iterator_type = iterator.display(db), + )), + CallDunderError::CallError(CallErrorKind::PossiblyNotCallable, _) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because its `__iter__` method returns an object of type `{iterator_type}`, \ + which has a `__next__` attribute that may not be callable", + iterable_type = iterable_type.display(db), + iterator_type = iterator.display(db), + )), + CallDunderError::CallError(CallErrorKind::BindingError, bindings) if bindings.is_single() => report_not_iterable(format_args!( + "Object of type `{iterable_type}` is not iterable \ + because its `__iter__` method returns an object of type `{iterator_type}`, \ + which has an invalid `__next__` method (expected `def __next__(self): ...`)", + iterable_type = iterable_type.display(db), + iterator_type = iterator.display(db), + )), + CallDunderError::CallError(CallErrorKind::BindingError, _) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because its `__iter__` method returns an object of type `{iterator_type}`, \ + which may have an invalid `__next__` method (expected `def __next__(self): ...`)", + iterable_type = iterable_type.display(db), + iterator_type = iterator.display(db), + )), } Self::PossiblyUnboundIterAndGetitemError { @@ -4011,51 +3993,49 @@ impl<'db> IterationErrorKind<'db> { because it may not have an `__iter__` method or a `__getitem__` method", iterable_type.display(db) )), - CallDunderError::Call(dunder_getitem_call_error) => match dunder_getitem_call_error { - CallError::NotCallable { not_callable_type } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because it may not have an `__iter__` method \ - and its `__getitem__` attribute has type `{dunder_getitem_type}`, \ - which is not callable", - iterable_type = iterable_type.display(db), - dunder_getitem_type = not_callable_type.display(db), - )), - CallError::PossiblyUnboundDunderCall { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because it may not have an `__iter__` method \ - and its `__getitem__` attribute may not be callable", - iterable_type = iterable_type.display(db), - )), - CallError::Union(union_call_error) if union_call_error.indicates_type_possibly_not_callable() => { - report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because it may not have an `__iter__` method \ - and its `__getitem__` attribute (with type `{dunder_getitem_type}`) \ - may not be callable", - iterable_type = iterable_type.display(db), - dunder_getitem_type = union_call_error.called_type.display(db), - )); - } - CallError::BindingError { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because it may not have an `__iter__` method \ - and its `__getitem__` method has an incorrect signature \ - for the old-style iteration protocol \ - (expected a signature at least as permissive as \ - `def __getitem__(self, key: int): ...`)", - iterable_type = iterable_type.display(db), - )), - CallError::Union(UnionCallError {called_type, ..})=> report_not_iterable(format_args!( + CallDunderError::CallError(CallErrorKind::NotCallable, bindings) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because it may not have an `__iter__` method \ + and its `__getitem__` attribute has type `{dunder_getitem_type}`, \ + which is not callable", + iterable_type = iterable_type.display(db), + dunder_getitem_type = bindings.callable_type.display(db), + )), + CallDunderError::CallError(CallErrorKind::PossiblyNotCallable, bindings) if bindings.is_single() => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because it may not have an `__iter__` method \ + and its `__getitem__` attribute may not be callable", + iterable_type = iterable_type.display(db), + )), + CallDunderError::CallError(CallErrorKind::PossiblyNotCallable, bindings) => { + report_not_iterable(format_args!( "Object of type `{iterable_type}` may not be iterable \ because it may not have an `__iter__` method \ - and its `__getitem__` method (with type `{dunder_getitem_type}`) - may have an incorrect signature for the old-style iteration protocol \ - (expected a signature at least as permissive as \ - `def __getitem__(self, key: int): ...`)", + and its `__getitem__` attribute (with type `{dunder_getitem_type}`) \ + may not be callable", iterable_type = iterable_type.display(db), - dunder_getitem_type = called_type.display(db), - )), + dunder_getitem_type = bindings.callable_type.display(db), + )); } + CallDunderError::CallError(CallErrorKind::BindingError, bindings) if bindings.is_single() => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because it may not have an `__iter__` method \ + and its `__getitem__` method has an incorrect signature \ + for the old-style iteration protocol \ + (expected a signature at least as permissive as \ + `def __getitem__(self, key: int): ...`)", + iterable_type = iterable_type.display(db), + )), + CallDunderError::CallError(CallErrorKind::BindingError, bindings) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because it may not have an `__iter__` method \ + and its `__getitem__` method (with type `{dunder_getitem_type}`) \ + may have an incorrect signature for the old-style iteration protocol \ + (expected a signature at least as permissive as \ + `def __getitem__(self, key: int): ...`)", + iterable_type = iterable_type.display(db), + dunder_getitem_type = bindings.callable_type.display(db), + )), } Self::UnboundIterAndGetitemError { dunder_getitem_error } => match dunder_getitem_error { @@ -4069,50 +4049,48 @@ impl<'db> IterationErrorKind<'db> { and it may not have a `__getitem__` method", iterable_type.display(db) )), - CallDunderError::Call(dunder_getitem_call_error) => match dunder_getitem_call_error { - CallError::NotCallable { not_callable_type } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` is not iterable \ - because it has no `__iter__` method and \ - its `__getitem__` attribute has type `{dunder_getitem_type}`, \ - which is not callable", - iterable_type = iterable_type.display(db), - dunder_getitem_type = not_callable_type.display(db), - )), - CallError::PossiblyUnboundDunderCall { .. } => report_not_iterable(format_args!( + CallDunderError::CallError(CallErrorKind::NotCallable, bindings) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` is not iterable \ + because it has no `__iter__` method and \ + its `__getitem__` attribute has type `{dunder_getitem_type}`, \ + which is not callable", + iterable_type = iterable_type.display(db), + dunder_getitem_type = bindings.callable_type.display(db), + )), + CallDunderError::CallError(CallErrorKind::PossiblyNotCallable, bindings) if bindings.is_single() => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because it has no `__iter__` method and its `__getitem__` attribute \ + may not be callable", + iterable_type = iterable_type.display(db), + )), + CallDunderError::CallError(CallErrorKind::PossiblyNotCallable, bindings) => { + report_not_iterable(format_args!( "Object of type `{iterable_type}` may not be iterable \ because it has no `__iter__` method and its `__getitem__` attribute \ - may not be callable", + (with type `{dunder_getitem_type}`) may not be callable", iterable_type = iterable_type.display(db), - )), - CallError::Union(union_call_error) if union_call_error.indicates_type_possibly_not_callable() => { - report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because it has no `__iter__` method and its `__getitem__` attribute \ - (with type `{dunder_getitem_type}`) may not be callable", - iterable_type = iterable_type.display(db), - dunder_getitem_type = union_call_error.called_type.display(db), - )); - } - CallError::BindingError { .. } => report_not_iterable(format_args!( - "Object of type `{iterable_type}` is not iterable \ - because it has no `__iter__` method and \ - its `__getitem__` method has an incorrect signature \ - for the old-style iteration protocol \ - (expected a signature at least as permissive as \ - `def __getitem__(self, key: int): ...`)", - iterable_type = iterable_type.display(db), - )), - CallError::Union(UnionCallError { called_type, .. }) => report_not_iterable(format_args!( - "Object of type `{iterable_type}` may not be iterable \ - because it has no `__iter__` method and \ - its `__getitem__` method (with type `{dunder_getitem_type}`) \ - may have an incorrect signature for the old-style iteration protocol \ - (expected a signature at least as permissive as \ - `def __getitem__(self, key: int): ...`)", - iterable_type = iterable_type.display(db), - dunder_getitem_type = called_type.display(db), - )), + dunder_getitem_type = bindings.callable_type.display(db), + )); } + CallDunderError::CallError(CallErrorKind::BindingError, bindings) if bindings.is_single() => report_not_iterable(format_args!( + "Object of type `{iterable_type}` is not iterable \ + because it has no `__iter__` method and \ + its `__getitem__` method has an incorrect signature \ + for the old-style iteration protocol \ + (expected a signature at least as permissive as \ + `def __getitem__(self, key: int): ...`)", + iterable_type = iterable_type.display(db), + )), + CallDunderError::CallError(CallErrorKind::BindingError, bindings) => report_not_iterable(format_args!( + "Object of type `{iterable_type}` may not be iterable \ + because it has no `__iter__` method and \ + its `__getitem__` method (with type `{dunder_getitem_type}`) \ + may have an incorrect signature for the old-style iteration protocol \ + (expected a signature at least as permissive as \ + `def __getitem__(self, key: int): ...`)", + iterable_type = iterable_type.display(db), + dunder_getitem_type = bindings.callable_type.display(db), + )), } } } @@ -4343,10 +4321,11 @@ impl<'db> FunctionType<'db> { /// /// Returns `None` if the function is overloaded. This powers the `CallableTypeFromFunction` /// special form from the `knot_extensions` module. - pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Option> { - // TODO: Add support for overloaded callables; return `Type`, not `Option`. - Some(Type::Callable(CallableType::General( - GeneralCallableType::new(db, self.signature(db).as_single()?.clone()), + pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { + // TODO: Add support for overloaded callables + Type::Callable(CallableType::General(GeneralCallableType::new( + db, + self.signature(db).clone(), ))) } @@ -4363,8 +4342,8 @@ impl<'db> FunctionType<'db> { /// Were this not a salsa query, then the calling query /// would depend on the function's AST and rerun for every change in that file. #[salsa::tracked(return_ref)] - pub fn signature(self, db: &'db dyn Db) -> CallableSignature<'db> { - let internal_signature = self.internal_signature(db).into(); + pub fn signature(self, db: &'db dyn Db) -> Signature<'db> { + let internal_signature = self.internal_signature(db); let decorators = self.decorators(db); let mut decorators = decorators.iter(); @@ -4376,7 +4355,7 @@ impl<'db> FunctionType<'db> { { internal_signature } else { - CallableSignature::todo("return type of decorated function") + Signature::todo("return type of decorated function") } } else { internal_signature diff --git a/crates/red_knot_python_semantic/src/types/call.rs b/crates/red_knot_python_semantic/src/types/call.rs index 475326f95792c5..84f64bcc4958ca 100644 --- a/crates/red_knot_python_semantic/src/types/call.rs +++ b/crates/red_knot_python_semantic/src/types/call.rs @@ -1,197 +1,35 @@ use super::context::InferContext; -use super::{CallableSignature, Signature, Type}; -use crate::types::UnionType; +use super::{CallableSignature, Signature, Signatures, Type}; use crate::Db; mod arguments; mod bind; pub(super) use arguments::{Argument, CallArguments}; -pub(super) use bind::{bind_call, CallBinding}; +pub(super) use bind::Bindings; -/// A successfully bound call where all arguments are valid. +/// Wraps a [`Bindings`] for an unsuccessful call with information about why the call was +/// unsuccessful. /// -/// It's guaranteed that the wrapped bindings have no errors. +/// The bindings are boxed so that we do not pass around large `Err` variants on the stack. #[derive(Debug, Clone, PartialEq, Eq)] -pub(super) enum CallOutcome<'db> { - /// The call resolves to exactly one binding. - Single(CallBinding<'db>), - - /// The call resolves to multiple bindings. - Union(Box<[CallBinding<'db>]>), -} - -impl<'db> CallOutcome<'db> { - /// Calls each union element using the provided `call` function. - /// - /// Returns `Ok` if all variants can be called without error according to the callback and `Err` otherwise. - pub(super) fn try_call_union( - db: &'db dyn Db, - union: UnionType<'db>, - call: F, - ) -> Result> - where - F: Fn(Type<'db>) -> Result>, - { - let elements = union.elements(db); - let mut bindings = Vec::with_capacity(elements.len()); - let mut errors = Vec::new(); - let mut all_errors_not_callable = true; - - for element in elements { - match call(*element) { - Ok(CallOutcome::Single(binding)) => bindings.push(binding), - Ok(CallOutcome::Union(inner_bindings)) => { - bindings.extend(inner_bindings); - } - Err(error) => { - all_errors_not_callable &= error.is_not_callable(); - errors.push(error); - } - } - } - - if errors.is_empty() { - Ok(CallOutcome::Union(bindings.into())) - } else if bindings.is_empty() && all_errors_not_callable { - Err(CallError::NotCallable { - not_callable_type: Type::Union(union), - }) - } else { - Err(CallError::Union(UnionCallError { - errors: errors.into(), - bindings: bindings.into(), - called_type: Type::Union(union), - })) - } - } - - /// The type returned by this call. - pub(super) fn return_type(&self, db: &'db dyn Db) -> Type<'db> { - match self { - Self::Single(binding) => binding.return_type(), - Self::Union(bindings) => { - UnionType::from_elements(db, bindings.iter().map(CallBinding::return_type)) - } - } - } - - pub(super) fn bindings(&self) -> &[CallBinding<'db>] { - match self { - Self::Single(binding) => std::slice::from_ref(binding), - Self::Union(bindings) => bindings, - } - } -} +pub(crate) struct CallError<'db>(pub(crate) CallErrorKind, pub(crate) Box>); /// The reason why calling a type failed. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(super) enum CallError<'db> { - /// The type is not callable. - NotCallable { - /// The type that can't be called. - not_callable_type: Type<'db>, - }, - - /// A call to a union failed because at least one variant - /// can't be called with the given arguments. - /// - /// A union where all variants are not callable is represented as a `NotCallable` error. - Union(UnionCallError<'db>), - - /// The type has a `__call__` method but it isn't always bound. - PossiblyUnboundDunderCall { - called_type: Type<'db>, - outcome: Box>, - }, - - /// The type is callable but not with the given arguments. - BindingError { binding: CallBinding<'db> }, -} - -impl<'db> CallError<'db> { - /// Returns a fallback return type to use that best approximates the return type of the call. - /// - /// Returns `None` if the type isn't callable. - pub(super) fn return_type(&self, db: &'db dyn Db) -> Option> { - match self { - CallError::NotCallable { .. } => None, - // If some variants are callable, and some are not, return the union of the return types of the callable variants - // combined with `Type::Unknown` - CallError::Union(UnionCallError { - bindings, errors, .. - }) => Some(UnionType::from_elements( - db, - bindings - .iter() - .map(CallBinding::return_type) - .chain(errors.iter().map(|err| err.fallback_return_type(db))), - )), - Self::PossiblyUnboundDunderCall { outcome, .. } => Some(outcome.return_type(db)), - Self::BindingError { binding } => Some(binding.return_type()), - } - } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum CallErrorKind { + /// The type is not callable. For a union type, _none_ of the union elements are callable. + NotCallable, - /// Returns the return type of the call or a fallback that - /// represents the best guess of the return type (e.g. the actual return type even if the - /// dunder is possibly unbound). + /// The type is not callable with the given arguments. /// - /// If the type is not callable, returns `Type::Unknown`. - pub(super) fn fallback_return_type(&self, db: &'db dyn Db) -> Type<'db> { - self.return_type(db).unwrap_or(Type::unknown()) - } - - /// The resolved type that was not callable. - /// - /// For unions, returns the union type itself, which may contain a mix of callable and - /// non-callable types. - pub(super) fn called_type(&self) -> Type<'db> { - match self { - Self::NotCallable { - not_callable_type, .. - } => *not_callable_type, - Self::Union(UnionCallError { called_type, .. }) - | Self::PossiblyUnboundDunderCall { called_type, .. } => *called_type, - Self::BindingError { binding } => binding.callable_type(), - } - } - - pub(super) const fn is_not_callable(&self) -> bool { - matches!(self, Self::NotCallable { .. }) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(super) struct UnionCallError<'db> { - /// The variants that can't be called with the given arguments. - pub(super) errors: Box<[CallError<'db>]>, - - /// The bindings for the callable variants (that have no binding errors). - pub(super) bindings: Box<[CallBinding<'db>]>, - - /// The union type that we tried calling. - pub(super) called_type: Type<'db>, -} - -impl UnionCallError<'_> { - /// Return `true` if this `UnionCallError` indicates that the union might not be callable at all. - /// Otherwise, return `false`. - /// - /// For example, the union type `Callable[[int], int] | None` may not be callable at all, - /// because the `None` element in this union has no `__call__` method. Calling an object that - /// inhabited this union type would lead to a `UnionCallError` that would indicate that the - /// union might not be callable at all. - /// - /// On the other hand, the union type `Callable[[int], int] | Callable[[str], str]` is always - /// *callable*, but it would still lead to a `UnionCallError` if an inhabitant of this type was - /// called with a single `int` argument passed in. That's because the second element in the - /// union doesn't accept an `int` when it's called: it only accepts a `str`. - pub(crate) fn indicates_type_possibly_not_callable(&self) -> bool { - self.errors.iter().any(|error| match error { - CallError::BindingError { .. } => false, - CallError::NotCallable { .. } | CallError::PossiblyUnboundDunderCall { .. } => true, - CallError::Union(union_error) => union_error.indicates_type_possibly_not_callable(), - }) - } + /// `BindingError` takes precedence over `PossiblyNotCallable`: for a union type, there might + /// be some union elements that are not callable at all, but the call arguments are not + /// compatible with at least one of the callable elements. + BindingError, + + /// Not all of the elements of a union type are callable, but the call arguments are compatible + /// with all of the callable elements. + PossiblyNotCallable, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -199,12 +37,12 @@ pub(super) enum CallDunderError<'db> { /// The dunder attribute exists but it can't be called with the given arguments. /// /// This includes non-callable dunder attributes that are possibly unbound. - Call(CallError<'db>), + CallError(CallErrorKind, Box>), /// The type has the specified dunder method and it is callable /// with the specified arguments without any binding errors /// but it is possibly unbound. - PossiblyUnbound(CallOutcome<'db>), + PossiblyUnbound(Box>), /// The dunder method with the specified name is missing. MethodNotAvailable, @@ -213,9 +51,9 @@ pub(super) enum CallDunderError<'db> { impl<'db> CallDunderError<'db> { pub(super) fn return_type(&self, db: &'db dyn Db) -> Option> { match self { - Self::Call(error) => error.return_type(db), - Self::PossiblyUnbound(call_outcome) => Some(call_outcome.return_type(db)), - Self::MethodNotAvailable => None, + Self::MethodNotAvailable | Self::CallError(CallErrorKind::NotCallable, _) => None, + Self::CallError(_, bindings) => Some(bindings.return_type(db)), + Self::PossiblyUnbound(bindings) => Some(bindings.return_type(db)), } } @@ -225,7 +63,7 @@ impl<'db> CallDunderError<'db> { } impl<'db> From> for CallDunderError<'db> { - fn from(error: CallError<'db>) -> Self { - Self::Call(error) + fn from(CallError(kind, bindings): CallError<'db>) -> Self { + Self::CallError(kind, bindings) } } diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index 20ff424ac354e4..3b25fe89451e91 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -1,11 +1,20 @@ +//! When analyzing a call site, we create _bindings_, which match and type-check the actual +//! arguments against the parameters of the callable. Like with +//! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a +//! union of types, each of which might contain multiple overloads. + +use std::borrow::Cow; + +use smallvec::SmallVec; + use super::{ - Argument, CallArguments, CallError, CallOutcome, CallableSignature, InferContext, Signature, - Type, + Argument, CallArguments, CallError, CallErrorKind, CallableSignature, InferContext, Signature, + Signatures, Type, }; use crate::db::Db; use crate::types::diagnostic::{ - INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT, NO_MATCHING_OVERLOAD, PARAMETER_ALREADY_ASSIGNED, - TOO_MANY_POSITIONAL_ARGUMENTS, UNKNOWN_ARGUMENT, + CALL_NON_CALLABLE, INVALID_ARGUMENT_TYPE, MISSING_ARGUMENT, NO_MATCHING_OVERLOAD, + PARAMETER_ALREADY_ASSIGNED, TOO_MANY_POSITIONAL_ARGUMENTS, UNKNOWN_ARGUMENT, }; use crate::types::signatures::Parameter; use crate::types::{CallableType, UnionType}; @@ -13,162 +22,141 @@ use ruff_db::diagnostic::{OldSecondaryDiagnosticMessage, Span}; use ruff_python_ast as ast; use ruff_text_size::Ranged; -/// Bind a [`CallArguments`] against a [`CallableSignature`]. +/// Binding information for a possible union of callables. At a call site, the arguments must be +/// compatible with _all_ of the types in the union for the call to be valid. /// -/// The returned [`CallBinding`] provides the return type of the call, the bound types for all -/// parameters, and any errors resulting from binding the call. -pub(crate) fn bind_call<'db>( - db: &'db dyn Db, - arguments: &CallArguments<'_, 'db>, - overloads: &CallableSignature<'db>, - callable_ty: Type<'db>, -) -> CallBinding<'db> { - // TODO: This checks every overload. In the proposed more detailed call checking spec [1], - // arguments are checked for arity first, and are only checked for type assignability against - // the matching overloads. Make sure to implement that as part of separating call binding into - // two phases. - // - // [1] https://github.com/python/typing/pull/1839 - let overloads = overloads - .iter() - .map(|signature| bind_overload(db, arguments, signature)) - .collect::>() - .into_boxed_slice(); - CallBinding { - callable_ty, - overloads, - } +/// It's guaranteed that the wrapped bindings have no errors. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct Bindings<'db> { + pub(crate) callable_type: Type<'db>, + /// By using `SmallVec`, we avoid an extra heap allocation for the common case of a non-union + /// type. + elements: SmallVec<[CallableBinding<'db>; 1]>, } -fn bind_overload<'db>( - db: &'db dyn Db, - arguments: &CallArguments<'_, 'db>, - signature: &Signature<'db>, -) -> OverloadBinding<'db> { - let parameters = signature.parameters(); - // The type assigned to each parameter at this call site. - let mut parameter_tys = vec![None; parameters.len()]; - let mut errors = vec![]; - let mut next_positional = 0; - let mut first_excess_positional = None; - let mut num_synthetic_args = 0; - let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { - if argument_index >= num_synthetic_args { - // Adjust the argument index to skip synthetic args, which don't appear at the call - // site and thus won't be in the Call node arguments list. - Some(argument_index - num_synthetic_args) - } else { - // we are erroring on a synthetic argument, we'll just emit the diagnostic on the - // entire Call node, since there's no argument node for this argument at the call site - None +impl<'db> Bindings<'db> { + /// Binds the arguments of a call site against a signature. + /// + /// The returned bindings provide the return type of the call, the bound types for all + /// parameters, and any errors resulting from binding the call, all for each union element and + /// overload (if any). + pub(crate) fn bind( + db: &'db dyn Db, + signatures: &Signatures<'db>, + arguments: &CallArguments<'_, 'db>, + ) -> Result> { + let elements: SmallVec<[CallableBinding<'db>; 1]> = signatures + .into_iter() + .map(|signature| CallableBinding::bind(db, signature, arguments)) + .collect(); + + // In order of precedence: + // + // - If every union element is Ok, then the union is too. + // - If any element has a BindingError, the union has a BindingError. + // - If every element is NotCallable, then the union is also NotCallable. + // - Otherwise, the elements are some mixture of Ok, NotCallable, and PossiblyNotCallable. + // The union as a whole is PossiblyNotCallable. + // + // For example, the union type `Callable[[int], int] | None` may not be callable at all, + // because the `None` element in this union has no `__call__` method. + // + // On the other hand, the union type `Callable[[int], int] | Callable[[str], str]` is + // always *callable*, but it would produce a `BindingError` if an inhabitant of this type + // was called with a single `int` argument passed in. That's because the second element in + // the union doesn't accept an `int` when it's called: it only accepts a `str`. + let mut all_ok = true; + let mut any_binding_error = false; + let mut all_not_callable = true; + for binding in &elements { + let result = binding.as_result(); + all_ok &= result.is_ok(); + any_binding_error |= matches!(result, Err(CallErrorKind::BindingError)); + all_not_callable &= matches!(result, Err(CallErrorKind::NotCallable)); } - }; - for (argument_index, argument) in arguments.iter().enumerate() { - let (index, parameter, argument_ty, positional) = match argument { - Argument::Positional(ty) | Argument::Synthetic(ty) => { - if matches!(argument, Argument::Synthetic(_)) { - num_synthetic_args += 1; - } - let Some((index, parameter)) = parameters - .get_positional(next_positional) - .map(|param| (next_positional, param)) - .or_else(|| parameters.variadic()) - else { - first_excess_positional.get_or_insert(argument_index); - next_positional += 1; - continue; - }; - next_positional += 1; - (index, parameter, ty, !parameter.is_variadic()) - } - Argument::Keyword { name, ty } => { - let Some((index, parameter)) = parameters - .keyword_by_name(name) - .or_else(|| parameters.keyword_variadic()) - else { - errors.push(CallBindingError::UnknownArgument { - argument_name: ast::name::Name::new(name), - argument_index: get_argument_index(argument_index, num_synthetic_args), - }); - continue; - }; - (index, parameter, ty, false) - } - Argument::Variadic(_) | Argument::Keywords(_) => { - // TODO - continue; - } + let bindings = Bindings { + callable_type: signatures.callable_type, + elements, }; - if let Some(expected_ty) = parameter.annotated_type() { - if !argument_ty.is_assignable_to(db, expected_ty) { - errors.push(CallBindingError::InvalidArgumentType { - parameter: ParameterContext::new(parameter, index, positional), - argument_index: get_argument_index(argument_index, num_synthetic_args), - expected_ty, - provided_ty: *argument_ty, - }); - } - } - if let Some(existing) = parameter_tys[index].replace(*argument_ty) { - if parameter.is_variadic() || parameter.is_keyword_variadic() { - let union = UnionType::from_elements(db, [existing, *argument_ty]); - parameter_tys[index].replace(union); - } else { - errors.push(CallBindingError::ParameterAlreadyAssigned { - argument_index: get_argument_index(argument_index, num_synthetic_args), - parameter: ParameterContext::new(parameter, index, positional), - }); - } + + if all_ok { + Ok(bindings) + } else if any_binding_error { + Err(CallError(CallErrorKind::BindingError, Box::new(bindings))) + } else if all_not_callable { + Err(CallError(CallErrorKind::NotCallable, Box::new(bindings))) + } else { + Err(CallError( + CallErrorKind::PossiblyNotCallable, + Box::new(bindings), + )) } } - if let Some(first_excess_argument_index) = first_excess_positional { - errors.push(CallBindingError::TooManyPositionalArguments { - first_excess_argument_index: get_argument_index( - first_excess_argument_index, - num_synthetic_args, - ), - expected_positional_count: parameters.positional().count(), - provided_positional_count: next_positional, - }); + + pub(crate) fn is_single(&self) -> bool { + self.elements.len() == 1 } - let mut missing = vec![]; - for (index, bound_ty) in parameter_tys.iter().enumerate() { - if bound_ty.is_none() { - let param = ¶meters[index]; - if param.is_variadic() || param.is_keyword_variadic() || param.default_type().is_some() - { - // variadic/keywords and defaulted arguments are not required - continue; - } - missing.push(ParameterContext::new(param, index, false)); + + /// Returns the return type of the call. For successful calls, this is the actual return type. + /// For calls with binding errors, this is a type that best approximates the return type. For + /// types that are not callable, returns `Type::Unknown`. + pub(crate) fn return_type(&self, db: &'db dyn Db) -> Type<'db> { + if let [binding] = self.elements.as_slice() { + return binding.return_type(); } + UnionType::from_elements(db, self.into_iter().map(CallableBinding::return_type)) } - if !missing.is_empty() { - errors.push(CallBindingError::MissingArguments { - parameters: ParameterContexts(missing), - }); + /// Report diagnostics for all of the errors that occurred when trying to match actual + /// arguments to formal parameters. If the callable is a union, or has multiple overloads, we + /// report a single diagnostic if we couldn't match any union element or overload. + /// TODO: Update this to add subdiagnostics about how we failed to match each union element and + /// overload. + pub(crate) fn report_diagnostics(&self, context: &InferContext<'db>, node: ast::AnyNodeRef) { + // If all union elements are not callable, report that the union as a whole is not + // callable. + if self.into_iter().all(|b| !b.is_callable()) { + context.report_lint( + &CALL_NON_CALLABLE, + node, + format_args!( + "Object of type `{}` is not callable", + self.callable_type.display(context.db()) + ), + ); + return; + } + + // TODO: We currently only report errors for the first union element. Ideally, we'd report + // an error saying that the union type can't be called, followed by subdiagnostics + // explaining why. + if let Some(first) = self.into_iter().find(|b| b.as_result().is_err()) { + first.report_diagnostics(context, node); + } } +} - OverloadBinding { - return_ty: signature.return_ty.unwrap_or(Type::unknown()), - parameter_tys: parameter_tys - .into_iter() - .map(|opt_ty| opt_ty.unwrap_or(Type::unknown())) - .collect(), - errors, +impl<'a, 'db> IntoIterator for &'a Bindings<'db> { + type Item = &'a CallableBinding<'db>; + type IntoIter = std::slice::Iter<'a, CallableBinding<'db>>; + + fn into_iter(self) -> Self::IntoIter { + self.elements.iter() } } -/// Describes a callable for the purposes of diagnostics. -#[derive(Debug)] -pub(crate) struct CallableDescriptor<'a> { - name: &'a str, - kind: &'a str, +impl<'a, 'db> IntoIterator for &'a mut Bindings<'db> { + type Item = &'a mut CallableBinding<'db>; + type IntoIter = std::slice::IterMut<'a, CallableBinding<'db>>; + + fn into_iter(self) -> Self::IntoIter { + self.elements.iter_mut() + } } -/// Binding information for a call site. +/// Binding information for a single callable. If the callable is overloaded, there is a separate +/// [`Binding`] for each overload. /// /// For a successful binding, each argument is mapped to one of the callable's formal parameters. /// If the callable has multiple overloads, the first one that matches is used as the overall @@ -183,23 +171,72 @@ pub(crate) struct CallableDescriptor<'a> { /// /// [overloads]: https://github.com/python/typing/pull/1839 #[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct CallBinding<'db> { - /// Type of the callable object (function, class...) - callable_ty: Type<'db>, - - overloads: Box<[OverloadBinding<'db>]>, +pub(crate) struct CallableBinding<'db> { + pub(crate) callable_type: Type<'db>, + pub(crate) signature_type: Type<'db>, + pub(crate) dunder_call_is_possibly_unbound: bool, + + /// The bindings of each overload of this callable. Will be empty if the type is not callable. + /// + /// By using `SmallVec`, we avoid an extra heap allocation for the common case of a + /// non-overloaded callable. + overloads: SmallVec<[Binding<'db>; 1]>, } -impl<'db> CallBinding<'db> { - pub(crate) fn into_outcome(self) -> Result, CallError<'db>> { +impl<'db> CallableBinding<'db> { + /// Bind a [`CallArguments`] against a [`CallableSignature`]. + /// + /// The returned [`CallableBinding`] provides the return type of the call, the bound types for + /// all parameters, and any errors resulting from binding the call. + fn bind( + db: &'db dyn Db, + signature: &CallableSignature<'db>, + arguments: &CallArguments<'_, 'db>, + ) -> Self { + // If this callable is a bound method, prepend the self instance onto the arguments list + // before checking. + let arguments = if let Some(bound_type) = signature.bound_type { + Cow::Owned(arguments.with_self(bound_type)) + } else { + Cow::Borrowed(arguments) + }; + + // TODO: This checks every overload. In the proposed more detailed call checking spec [1], + // arguments are checked for arity first, and are only checked for type assignability against + // the matching overloads. Make sure to implement that as part of separating call binding into + // two phases. + // + // [1] https://github.com/python/typing/pull/1839 + let overloads = signature + .into_iter() + .map(|signature| Binding::bind(db, signature, arguments.as_ref())) + .collect(); + CallableBinding { + callable_type: signature.callable_type, + signature_type: signature.signature_type, + dunder_call_is_possibly_unbound: signature.dunder_call_is_possibly_unbound, + overloads, + } + } + + fn as_result(&self) -> Result<(), CallErrorKind> { + if !self.is_callable() { + return Err(CallErrorKind::NotCallable); + } + if self.has_binding_errors() { - return Err(CallError::BindingError { binding: self }); + return Err(CallErrorKind::BindingError); + } + + if self.dunder_call_is_possibly_unbound { + return Err(CallErrorKind::PossiblyNotCallable); } - Ok(CallOutcome::Single(self)) + + Ok(()) } - pub(crate) fn callable_type(&self) -> Type<'db> { - self.callable_ty + fn is_callable(&self) -> bool { + !self.overloads.is_empty() } /// Returns whether there were any errors binding this call site. If the callable has multiple @@ -210,20 +247,20 @@ impl<'db> CallBinding<'db> { /// Returns the overload that matched for this call binding. Returns `None` if none of the /// overloads matched. - pub(crate) fn matching_overload(&self) -> Option<(usize, &OverloadBinding<'db>)> { + pub(crate) fn matching_overload(&self) -> Option<(usize, &Binding<'db>)> { self.overloads .iter() .enumerate() - .find(|(_, overload)| !overload.has_binding_errors()) + .find(|(_, overload)| overload.as_result().is_ok()) } /// Returns the overload that matched for this call binding. Returns `None` if none of the /// overloads matched. - pub(crate) fn matching_overload_mut(&mut self) -> Option<(usize, &mut OverloadBinding<'db>)> { + pub(crate) fn matching_overload_mut(&mut self) -> Option<(usize, &mut Binding<'db>)> { self.overloads .iter_mut() .enumerate() - .find(|(_, overload)| !overload.has_binding_errors()) + .find(|(_, overload)| overload.as_result().is_ok()) } /// Returns the return type of this call. For a valid call, this is the return type of the @@ -235,53 +272,45 @@ impl<'db> CallBinding<'db> { if let Some((_, overload)) = self.matching_overload() { return overload.return_type(); } - if let [overload] = self.overloads.as_ref() { + if let [overload] = self.overloads.as_slice() { return overload.return_type(); } Type::unknown() } - fn callable_descriptor(&self, db: &'db dyn Db) -> Option { - match self.callable_ty { - Type::FunctionLiteral(function) => Some(CallableDescriptor { - kind: "function", - name: function.name(db), - }), - Type::ClassLiteral(class_type) => Some(CallableDescriptor { - kind: "class", - name: class_type.class().name(db), - }), - Type::Callable(CallableType::BoundMethod(bound_method)) => Some(CallableDescriptor { - kind: "bound method", - name: bound_method.function(db).name(db), - }), - Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { - Some(CallableDescriptor { - kind: "method wrapper `__get__` of function", - name: function.name(db), - }) - } - Type::Callable(CallableType::WrapperDescriptorDunderGet) => Some(CallableDescriptor { - kind: "wrapper descriptor", - name: "FunctionType.__get__", - }), - _ => None, + fn report_diagnostics(&self, context: &InferContext<'db>, node: ast::AnyNodeRef) { + if !self.is_callable() { + context.report_lint( + &CALL_NON_CALLABLE, + node, + format_args!( + "Object of type `{}` is not callable", + self.callable_type.display(context.db()), + ), + ); + return; } - } - /// Report diagnostics for all of the errors that occurred when trying to match actual - /// arguments to formal parameters. If the callable has multiple overloads, we report a single - /// diagnostic that we couldn't match any overload. - /// TODO: Update this to add subdiagnostics about how we failed to match each overload. - pub(crate) fn report_diagnostics(&self, context: &InferContext<'db>, node: ast::AnyNodeRef) { - let callable_descriptor = self.callable_descriptor(context.db()); + if self.dunder_call_is_possibly_unbound { + context.report_lint( + &CALL_NON_CALLABLE, + node, + format_args!( + "Object of type `{}` is not callable (possibly unbound `__call__` method)", + self.callable_type.display(context.db()), + ), + ); + return; + } + + let callable_description = CallableDescription::new(context.db(), self.callable_type); if self.overloads.len() > 1 { context.report_lint( &NO_MATCHING_OVERLOAD, node, format_args!( "No overload{} matches arguments", - if let Some(CallableDescriptor { kind, name }) = callable_descriptor { + if let Some(CallableDescription { kind, name }) = callable_description { format!(" of {kind} `{name}`") } else { String::new() @@ -291,12 +320,13 @@ impl<'db> CallBinding<'db> { return; } + let callable_description = CallableDescription::new(context.db(), self.signature_type); for overload in &self.overloads { overload.report_diagnostics( context, node, - self.callable_ty, - callable_descriptor.as_ref(), + self.signature_type, + callable_description.as_ref(), ); } } @@ -304,7 +334,7 @@ impl<'db> CallBinding<'db> { /// Binding information for one of the overloads of a callable. #[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct OverloadBinding<'db> { +pub(crate) struct Binding<'db> { /// Return type of the call. return_ty: Type<'db>, @@ -312,10 +342,133 @@ pub(crate) struct OverloadBinding<'db> { parameter_tys: Box<[Type<'db>]>, /// Call binding errors, if any. - errors: Vec>, + errors: Vec>, } -impl<'db> OverloadBinding<'db> { +impl<'db> Binding<'db> { + fn bind( + db: &'db dyn Db, + signature: &Signature<'db>, + arguments: &CallArguments<'_, 'db>, + ) -> Self { + let parameters = signature.parameters(); + // The type assigned to each parameter at this call site. + let mut parameter_tys = vec![None; parameters.len()]; + let mut errors = vec![]; + let mut next_positional = 0; + let mut first_excess_positional = None; + let mut num_synthetic_args = 0; + let get_argument_index = |argument_index: usize, num_synthetic_args: usize| { + if argument_index >= num_synthetic_args { + // Adjust the argument index to skip synthetic args, which don't appear at the call + // site and thus won't be in the Call node arguments list. + Some(argument_index - num_synthetic_args) + } else { + // we are erroring on a synthetic argument, we'll just emit the diagnostic on the + // entire Call node, since there's no argument node for this argument at the call site + None + } + }; + for (argument_index, argument) in arguments.iter().enumerate() { + let (index, parameter, argument_ty, positional) = match argument { + Argument::Positional(ty) | Argument::Synthetic(ty) => { + if matches!(argument, Argument::Synthetic(_)) { + num_synthetic_args += 1; + } + let Some((index, parameter)) = parameters + .get_positional(next_positional) + .map(|param| (next_positional, param)) + .or_else(|| parameters.variadic()) + else { + first_excess_positional.get_or_insert(argument_index); + next_positional += 1; + continue; + }; + next_positional += 1; + (index, parameter, ty, !parameter.is_variadic()) + } + Argument::Keyword { name, ty } => { + let Some((index, parameter)) = parameters + .keyword_by_name(name) + .or_else(|| parameters.keyword_variadic()) + else { + errors.push(BindingError::UnknownArgument { + argument_name: ast::name::Name::new(name), + argument_index: get_argument_index(argument_index, num_synthetic_args), + }); + continue; + }; + (index, parameter, ty, false) + } + + Argument::Variadic(_) | Argument::Keywords(_) => { + // TODO + continue; + } + }; + if let Some(expected_ty) = parameter.annotated_type() { + if !argument_ty.is_assignable_to(db, expected_ty) { + errors.push(BindingError::InvalidArgumentType { + parameter: ParameterContext::new(parameter, index, positional), + argument_index: get_argument_index(argument_index, num_synthetic_args), + expected_ty, + provided_ty: *argument_ty, + }); + } + } + if let Some(existing) = parameter_tys[index].replace(*argument_ty) { + if parameter.is_variadic() || parameter.is_keyword_variadic() { + let union = UnionType::from_elements(db, [existing, *argument_ty]); + parameter_tys[index].replace(union); + } else { + errors.push(BindingError::ParameterAlreadyAssigned { + argument_index: get_argument_index(argument_index, num_synthetic_args), + parameter: ParameterContext::new(parameter, index, positional), + }); + } + } + } + if let Some(first_excess_argument_index) = first_excess_positional { + errors.push(BindingError::TooManyPositionalArguments { + first_excess_argument_index: get_argument_index( + first_excess_argument_index, + num_synthetic_args, + ), + expected_positional_count: parameters.positional().count(), + provided_positional_count: next_positional, + }); + } + let mut missing = vec![]; + for (index, bound_ty) in parameter_tys.iter().enumerate() { + if bound_ty.is_none() { + let param = ¶meters[index]; + if param.is_variadic() + || param.is_keyword_variadic() + || param.default_type().is_some() + { + // variadic/keywords and defaulted arguments are not required + continue; + } + missing.push(ParameterContext::new(param, index, false)); + } + } + + if !missing.is_empty() { + errors.push(BindingError::MissingArguments { + parameters: ParameterContexts(missing), + }); + } + + Self { + return_ty: signature.return_ty.unwrap_or(Type::unknown()), + parameter_tys: parameter_tys + .into_iter() + .map(|opt_ty| opt_ty.unwrap_or(Type::unknown())) + .collect(), + errors, + } + } + pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) { self.return_ty = return_ty; } @@ -333,15 +486,55 @@ impl<'db> OverloadBinding<'db> { context: &InferContext<'db>, node: ast::AnyNodeRef, callable_ty: Type<'db>, - callable_descriptor: Option<&CallableDescriptor>, + callable_description: Option<&CallableDescription>, ) { for error in &self.errors { - error.report_diagnostic(context, node, callable_ty, callable_descriptor); + error.report_diagnostic(context, node, callable_ty, callable_description); } } - pub(crate) fn has_binding_errors(&self) -> bool { - !self.errors.is_empty() + fn as_result(&self) -> Result<(), CallErrorKind> { + if !self.errors.is_empty() { + return Err(CallErrorKind::BindingError); + } + Ok(()) + } +} + +/// Describes a callable for the purposes of diagnostics. +#[derive(Debug)] +pub(crate) struct CallableDescription<'a> { + name: &'a str, + kind: &'a str, +} + +impl<'db> CallableDescription<'db> { + fn new(db: &'db dyn Db, callable_type: Type<'db>) -> Option> { + match callable_type { + Type::FunctionLiteral(function) => Some(CallableDescription { + kind: "function", + name: function.name(db), + }), + Type::ClassLiteral(class_type) => Some(CallableDescription { + kind: "class", + name: class_type.class().name(db), + }), + Type::Callable(CallableType::BoundMethod(bound_method)) => Some(CallableDescription { + kind: "bound method", + name: bound_method.function(db).name(db), + }), + Type::Callable(CallableType::MethodWrapperDunderGet(function)) => { + Some(CallableDescription { + kind: "method wrapper `__get__` of function", + name: function.name(db), + }) + } + Type::Callable(CallableType::WrapperDescriptorDunderGet) => Some(CallableDescription { + kind: "wrapper descriptor", + name: "FunctionType.__get__", + }), + _ => None, + } } } @@ -399,7 +592,7 @@ impl std::fmt::Display for ParameterContexts { } #[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) enum CallBindingError<'db> { +pub(crate) enum BindingError<'db> { /// The type of an argument is not assignable to the annotated type of its corresponding /// parameter. InvalidArgumentType { @@ -428,7 +621,7 @@ pub(crate) enum CallBindingError<'db> { }, } -impl<'db> CallBindingError<'db> { +impl<'db> BindingError<'db> { fn parameter_span_from_index( db: &'db dyn Db, callable_ty: Type<'db>, @@ -468,7 +661,7 @@ impl<'db> CallBindingError<'db> { context: &InferContext<'db>, node: ast::AnyNodeRef, callable_ty: Type<'db>, - callable_descriptor: Option<&CallableDescriptor>, + callable_description: Option<&CallableDescription>, ) { match self { Self::InvalidArgumentType { @@ -495,7 +688,7 @@ impl<'db> CallBindingError<'db> { format_args!( "Object of type `{provided_ty_display}` cannot be assigned to \ parameter {parameter}{}; expected type `{expected_ty_display}`", - if let Some(CallableDescriptor { kind, name }) = callable_descriptor { + if let Some(CallableDescription { kind, name }) = callable_description { format!(" of {kind} `{name}`") } else { String::new() @@ -516,7 +709,7 @@ impl<'db> CallBindingError<'db> { format_args!( "Too many positional arguments{}: expected \ {expected_positional_count}, got {provided_positional_count}", - if let Some(CallableDescriptor { kind, name }) = callable_descriptor { + if let Some(CallableDescription { kind, name }) = callable_description { format!(" to {kind} `{name}`") } else { String::new() @@ -532,7 +725,7 @@ impl<'db> CallBindingError<'db> { node, format_args!( "No argument{s} provided for required parameter{s} {parameters}{}", - if let Some(CallableDescriptor { kind, name }) = callable_descriptor { + if let Some(CallableDescription { kind, name }) = callable_description { format!(" of {kind} `{name}`") } else { String::new() @@ -550,7 +743,7 @@ impl<'db> CallBindingError<'db> { Self::get_node(node, *argument_index), format_args!( "Argument `{argument_name}` does not match any known parameter{}", - if let Some(CallableDescriptor { kind, name }) = callable_descriptor { + if let Some(CallableDescription { kind, name }) = callable_description { format!(" of {kind} `{name}`") } else { String::new() @@ -568,7 +761,7 @@ impl<'db> CallBindingError<'db> { Self::get_node(node, *argument_index), format_args!( "Multiple values provided for parameter {parameter}{}", - if let Some(CallableDescriptor { kind, name }) = callable_descriptor { + if let Some(CallableDescription { kind, name }) = callable_description { format!(" of {kind} `{name}`") } else { String::new() diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index dd0e825c85e4f1..0de0a2b38c8211 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -11,8 +11,8 @@ use crate::{ Boundness, LookupError, LookupResult, Symbol, SymbolAndQualifiers, }, types::{ - definition_expression_type, CallArguments, CallError, DynamicType, MetaclassCandidate, - TupleType, UnionBuilder, UnionCallError, UnionType, + definition_expression_type, CallArguments, CallError, CallErrorKind, DynamicType, + MetaclassCandidate, TupleType, UnionBuilder, UnionType, }, Db, KnownModule, Program, }; @@ -282,57 +282,21 @@ impl<'db> Class<'db> { let arguments = CallArguments::positional([name, bases, namespace]); let return_ty_result = match metaclass.try_call(db, &arguments) { - Ok(outcome) => Ok(outcome.return_type(db)), + Ok(bindings) => Ok(bindings.return_type(db)), - Err(CallError::NotCallable { not_callable_type }) => Err(MetaclassError { - kind: MetaclassErrorKind::NotCallable(not_callable_type), + Err(CallError(CallErrorKind::NotCallable, bindings)) => Err(MetaclassError { + kind: MetaclassErrorKind::NotCallable(bindings.callable_type), }), - Err(CallError::Union(UnionCallError { - called_type, - errors, - bindings, - })) => { - let mut partly_not_callable = false; - - let return_ty = errors - .iter() - .fold(None, |acc, error| { - let ty = error.return_type(db); - - match (acc, ty) { - (acc, None) => { - partly_not_callable = true; - acc - } - (None, Some(ty)) => Some(UnionBuilder::new(db).add(ty)), - (Some(builder), Some(ty)) => Some(builder.add(ty)), - } - }) - .map(|mut builder| { - for binding in bindings { - builder = builder.add(binding.return_type()); - } - - builder.build() - }); - - if partly_not_callable { - Err(MetaclassError { - kind: MetaclassErrorKind::PartlyNotCallable(called_type), - }) - } else { - Ok(return_ty.unwrap_or(Type::unknown())) - } + // TODO we should also check for binding errors that would indicate the metaclass + // does not accept the right arguments + Err(CallError(CallErrorKind::BindingError, bindings)) => { + Ok(bindings.return_type(db)) } - Err(CallError::PossiblyUnboundDunderCall { .. }) => Err(MetaclassError { + Err(CallError(CallErrorKind::PossiblyNotCallable, _)) => Err(MetaclassError { kind: MetaclassErrorKind::PartlyNotCallable(metaclass), }), - - // TODO we should also check for binding errors that would indicate the metaclass - // does not accept the right arguments - Err(CallError::BindingError { binding }) => Ok(binding.return_type()), }; return return_ty_result.map(|ty| ty.to_meta_type(db)); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 6ffc04d7ca9cc7..6c1ae9fb92e415 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -61,7 +61,7 @@ use crate::symbol::{ module_type_implicit_global_symbol, symbol, symbol_from_bindings, symbol_from_declarations, typing_extensions_symbol, Boundness, LookupError, }; -use crate::types::call::{Argument, CallArguments, UnionCallError}; +use crate::types::call::{Argument, CallArguments, CallError}; use crate::types::diagnostic::{ report_implicit_return_type, report_invalid_arguments_to_annotated, report_invalid_arguments_to_callable, report_invalid_assignment, @@ -89,7 +89,6 @@ use crate::unpack::Unpack; use crate::util::subscript::{PyIndex, PySlice}; use crate::Db; -use super::call::CallError; use super::class_base::ClassBase; use super::context::{InNoTypeCheck, InferContext, WithDiagnostics}; use super::diagnostic::{ @@ -2789,9 +2788,9 @@ impl<'db> TypeInferenceBuilder<'db> { Err(CallDunderError::PossiblyUnbound(outcome)) => { UnionType::from_elements(db, [outcome.return_type(db), binary_return_ty()]) } - Err(CallDunderError::Call(call_error)) => { + Err(CallDunderError::CallError(_, bindings)) => { report_unsupported_augmented_op(&mut self.context); - call_error.fallback_return_type(db) + bindings.return_type(db) } } } @@ -3857,13 +3856,11 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_or_default(); let call_arguments = self.infer_arguments(arguments, parameter_expectations); - let call = function_type.try_call(self.db(), &call_arguments); - - match call { - Ok(outcome) => { - for binding in outcome.bindings() { + match function_type.try_call(self.db(), &call_arguments) { + Ok(bindings) => { + for binding in &bindings { let Some(known_function) = binding - .callable_type() + .callable_type .into_function_literal() .and_then(|function_type| function_type.known(self.db())) else { @@ -3967,61 +3964,12 @@ impl<'db> TypeInferenceBuilder<'db> { _ => {} } } - - outcome.return_type(self.db()) + bindings.return_type(self.db()) } - Err(err) => { - // TODO: We currently only report the first error. Ideally, we'd report - // an error saying that the union type can't be called, followed by a sub - // diagnostic explaining why. - fn report_call_error( - context: &InferContext, - err: CallError, - call_expression: &ast::ExprCall, - ) { - match err { - CallError::NotCallable { not_callable_type } => { - context.report_lint( - &CALL_NON_CALLABLE, - call_expression, - format_args!( - "Object of type `{}` is not callable", - not_callable_type.display(context.db()) - ), - ); - } - - CallError::Union(UnionCallError { errors, .. }) => { - if let Some(first) = IntoIterator::into_iter(errors).next() { - report_call_error(context, first, call_expression); - } else { - debug_assert!( - false, - "Expected `CalLError::Union` to at least have one error" - ); - } - } - - CallError::PossiblyUnboundDunderCall { called_type, .. } => { - context.report_lint( - &CALL_NON_CALLABLE, - call_expression, - format_args!( - "Object of type `{}` is not callable (possibly unbound `__call__` method)", - called_type.display(context.db()) - ), - ); - } - CallError::BindingError { binding, .. } => { - binding.report_diagnostics(context, call_expression.into()); - } - } - } - let return_type = err.fallback_return_type(self.db()); - report_call_error(&self.context, err, call_expression); - - return_type + Err(CallError(_, bindings)) => { + bindings.report_diagnostics(&self.context, call_expression.into()); + bindings.return_type(self.db()) } } } @@ -4669,7 +4617,7 @@ impl<'db> TypeInferenceBuilder<'db> { let reflected_dunder = op.reflected_dunder(); let rhs_reflected = right_class.member(self.db(), reflected_dunder).symbol; // TODO: if `rhs_reflected` is possibly unbound, we should union the two possible - // CallOutcomes together + // Bindings together if !rhs_reflected.is_unbound() && rhs_reflected != left_class.member(self.db(), reflected_dunder).symbol { @@ -5412,7 +5360,7 @@ impl<'db> TypeInferenceBuilder<'db> { db, &CallArguments::positional([Type::Instance(right), Type::Instance(left)]), ) - .map(|outcome| outcome.return_type(db)) + .map(|bindings| bindings.return_type(db)) .ok() } _ => { @@ -5696,18 +5644,18 @@ impl<'db> TypeInferenceBuilder<'db> { return err.fallback_return_type(self.db()); } - Err(CallDunderError::Call(err)) => { + Err(CallDunderError::CallError(_, bindings)) => { self.context.report_lint( - &CALL_NON_CALLABLE, - value_node, - format_args!( - "Method `__getitem__` of type `{}` is not callable on object of type `{}`", - err.called_type().display(self.db()), - value_ty.display(self.db()), - ), - ); + &CALL_NON_CALLABLE, + value_node, + format_args!( + "Method `__getitem__` of type `{}` is not callable on object of type `{}`", + bindings.callable_type.display(self.db()), + value_ty.display(self.db()), + ), + ); - return err.fallback_return_type(self.db()); + return bindings.return_type(self.db()); } Err(CallDunderError::MethodNotAvailable) => { // try `__class_getitem__` @@ -5741,21 +5689,24 @@ impl<'db> TypeInferenceBuilder<'db> { ); } - return ty - .try_call(self.db(), &CallArguments::positional([value_ty, slice_ty])) - .map(|outcome| outcome.return_type(self.db())) - .unwrap_or_else(|err| { + match ty.try_call( + self.db(), + &CallArguments::positional([value_ty, slice_ty]), + ) { + Ok(bindings) => return bindings.return_type(self.db()), + Err(CallError(_, bindings)) => { self.context.report_lint( &CALL_NON_CALLABLE, value_node, format_args!( "Method `__class_getitem__` of type `{}` is not callable on object of type `{}`", - err.called_type().display(self.db()), + bindings.callable_type.display(self.db()), value_ty.display(self.db()), ), ); - err.fallback_return_type(self.db()) - }); + return bindings.return_type(self.db()); + } + } } } @@ -6686,16 +6637,7 @@ impl<'db> TypeInferenceBuilder<'db> { ); return Type::unknown(); }; - function_type - .into_callable_type(self.db()) - .unwrap_or_else(|| { - self.context.report_lint( - &INVALID_TYPE_FORM, - arguments_slice, - format_args!("Overloaded function literal is not yet supported"), - ); - Type::unknown() - }) + function_type.into_callable_type(self.db()) } }, diff --git a/crates/red_knot_python_semantic/src/types/signatures.rs b/crates/red_knot_python_semantic/src/types/signatures.rs index 7681a6d55812fc..eeac874ab70d9d 100644 --- a/crates/red_knot_python_semantic/src/types/signatures.rs +++ b/crates/red_knot_python_semantic/src/types/signatures.rs @@ -10,75 +10,188 @@ //! argument types and return types. For each callable type in the union, the call expression's //! arguments must match _at least one_ overload. +use smallvec::{smallvec, SmallVec}; + use super::{definition_expression_type, DynamicType, Type}; +use crate::semantic_index::definition::Definition; +use crate::types::todo_type; use crate::Db; -use crate::{semantic_index::definition::Definition, types::todo_type}; use ruff_python_ast::{self as ast, name::Name}; +/// The signature of a possible union of callables. +#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] +pub(crate) struct Signatures<'db> { + /// The type that is (hopefully) callable. + pub(crate) callable_type: Type<'db>, + /// The type we'll use for error messages referring to details of the called signature. For calls to functions this + /// will be the same as `callable_type`; for other callable instances it may be a `__call__` method. + pub(crate) signature_type: Type<'db>, + /// By using `SmallVec`, we avoid an extra heap allocation for the common case of a non-union + /// type. + elements: SmallVec<[CallableSignature<'db>; 1]>, +} + +impl<'db> Signatures<'db> { + pub(crate) fn not_callable(signature_type: Type<'db>) -> Self { + Self { + callable_type: signature_type, + signature_type, + elements: smallvec![CallableSignature::not_callable(signature_type)], + } + } + + pub(crate) fn single(signature: CallableSignature<'db>) -> Self { + Self { + callable_type: signature.callable_type, + signature_type: signature.signature_type, + elements: smallvec![signature], + } + } + + /// Creates a new `Signatures` from an iterator of [`Signature`]s. Panics if the iterator is + /// empty. + pub(crate) fn from_union(signature_type: Type<'db>, elements: I) -> Self + where + I: IntoIterator>, + { + let elements: SmallVec<_> = elements + .into_iter() + .flat_map(|s| s.elements.into_iter()) + .collect(); + assert!(!elements.is_empty()); + Self { + callable_type: signature_type, + signature_type, + elements, + } + } + + pub(crate) fn replace_callable_type(&mut self, before: Type<'db>, after: Type<'db>) { + if self.callable_type == before { + self.callable_type = after; + } + for signature in &mut self.elements { + signature.replace_callable_type(before, after); + } + } + + pub(crate) fn set_dunder_call_is_possibly_unbound(&mut self) { + for signature in &mut self.elements { + signature.dunder_call_is_possibly_unbound = true; + } + } +} + +impl<'a, 'db> IntoIterator for &'a Signatures<'db> { + type Item = &'a CallableSignature<'db>; + type IntoIter = std::slice::Iter<'a, CallableSignature<'db>>; + + fn into_iter(self) -> Self::IntoIter { + self.elements.iter() + } +} + /// The signature of a single callable. If the callable is overloaded, there is a separate /// [`Signature`] for each overload. #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] -pub enum CallableSignature<'db> { - Single(Signature<'db>), - Overloaded(Box<[Signature<'db>]>), +pub(crate) struct CallableSignature<'db> { + /// The type that is (hopefully) callable. + pub(crate) callable_type: Type<'db>, + + /// The type we'll use for error messages referring to details of the called signature. For + /// calls to functions this will be the same as `callable_type`; for other callable instances + /// it may be a `__call__` method. + pub(crate) signature_type: Type<'db>, + + /// If this is a callable object (i.e. called via a `__call__` method), the boundness of + /// that call method. + pub(crate) dunder_call_is_possibly_unbound: bool, + + /// The type of the bound `self` or `cls` parameter if this signature is for a bound method. + pub(crate) bound_type: Option>, + + /// The signatures of each overload of this callable. Will be empty if the type is not + /// callable. + /// + /// By using `SmallVec`, we avoid an extra heap allocation for the common case of a + /// non-overloaded callable. + overloads: SmallVec<[Signature<'db>; 1]>, } impl<'db> CallableSignature<'db> { - /// Creates a new `CallableSignature` from an non-empty iterator of [`Signature`]s. - /// Panics if the iterator is empty. - pub(crate) fn from_overloads(overloads: I) -> Self - where - I: IntoIterator, - I::IntoIter: Iterator>, - { - let mut iter = overloads.into_iter(); - let first_overload = iter.next().expect("overloads should not be empty"); - let Some(second_overload) = iter.next() else { - return CallableSignature::Single(first_overload); - }; - let mut overloads = vec![first_overload, second_overload]; - overloads.extend(iter); - CallableSignature::Overloaded(overloads.into()) + pub(crate) fn not_callable(signature_type: Type<'db>) -> Self { + Self { + callable_type: signature_type, + signature_type, + dunder_call_is_possibly_unbound: false, + bound_type: None, + overloads: smallvec![], + } } - /// Returns the [`Signature`] if this is a non-overloaded callable, [None] otherwise. - pub(crate) fn as_single(&self) -> Option<&Signature<'db>> { - match self { - CallableSignature::Single(signature) => Some(signature), - CallableSignature::Overloaded(_) => None, + pub(crate) fn single(signature_type: Type<'db>, signature: Signature<'db>) -> Self { + Self { + callable_type: signature_type, + signature_type, + dunder_call_is_possibly_unbound: false, + bound_type: None, + overloads: smallvec![signature], } } - pub(crate) fn iter(&self) -> std::slice::Iter> { - match self { - CallableSignature::Single(signature) => std::slice::from_ref(signature).iter(), - CallableSignature::Overloaded(signatures) => signatures.iter(), + /// Creates a new `CallableSignature` from an iterator of [`Signature`]s. Returns a + /// non-callable signature if the iterator is empty. + pub(crate) fn from_overloads(signature_type: Type<'db>, overloads: I) -> Self + where + I: IntoIterator>, + { + Self { + callable_type: signature_type, + signature_type, + dunder_call_is_possibly_unbound: false, + bound_type: None, + overloads: overloads.into_iter().collect(), } } /// Return a signature for a dynamic callable - pub(crate) fn dynamic(ty: Type<'db>) -> Self { + pub(crate) fn dynamic(signature_type: Type<'db>) -> Self { let signature = Signature { parameters: Parameters::gradual_form(), - return_ty: Some(ty), + return_ty: Some(signature_type), }; - signature.into() + Self::single(signature_type, signature) } /// Return a todo signature: (*args: Todo, **kwargs: Todo) -> Todo #[allow(unused_variables)] // 'reason' only unused in debug builds pub(crate) fn todo(reason: &'static str) -> Self { + let signature_type = todo_type!(reason); let signature = Signature { parameters: Parameters::todo(), - return_ty: Some(todo_type!(reason)), + return_ty: Some(signature_type), }; - signature.into() + Self::single(signature_type, signature) + } + + pub(crate) fn with_bound_type(mut self, bound_type: Type<'db>) -> Self { + self.bound_type = Some(bound_type); + self + } + + fn replace_callable_type(&mut self, before: Type<'db>, after: Type<'db>) { + if self.callable_type == before { + self.callable_type = after; + } } } -impl<'db> From> for CallableSignature<'db> { - fn from(signature: Signature<'db>) -> Self { - CallableSignature::Single(signature) +impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { + type Item = &'a Signature<'db>; + type IntoIter = std::slice::Iter<'a, Signature<'db>>; + + fn into_iter(self) -> Self::IntoIter { + self.overloads.iter() } } @@ -107,11 +220,20 @@ impl<'db> Signature<'db> { } } + /// Return a todo signature: (*args: Todo, **kwargs: Todo) -> Todo + #[allow(unused_variables)] // 'reason' only unused in debug builds + pub(crate) fn todo(reason: &'static str) -> Self { + Signature { + parameters: Parameters::todo(), + return_ty: Some(todo_type!(reason)), + } + } + /// Return a typed signature from a function definition. pub(super) fn from_function( db: &'db dyn Db, definition: Definition<'db>, - function_node: &'db ast::StmtFunctionDef, + function_node: &ast::StmtFunctionDef, ) -> Self { let return_ty = function_node.returns.as_ref().map(|returns| { if function_node.is_async { @@ -249,7 +371,7 @@ impl<'db> Parameters<'db> { fn from_parameters( db: &'db dyn Db, definition: Definition<'db>, - parameters: &'db ast::Parameters, + parameters: &ast::Parameters, ) -> Self { let ast::Parameters { posonlyargs, @@ -413,7 +535,7 @@ impl<'db> Parameter<'db> { fn from_node_and_kind( db: &'db dyn Db, definition: Definition<'db>, - parameter: &'db ast::Parameter, + parameter: &ast::Parameter, kind: ParameterKind<'db>, ) -> Self { Self { @@ -792,7 +914,7 @@ mod tests { .unwrap(); let func = get_function_f(&db, "/src/a.py"); - let expected_sig = func.internal_signature(&db).into(); + let expected_sig = func.internal_signature(&db); // With no decorators, internal and external signature are the same assert_eq!(func.signature(&db), &expected_sig); @@ -813,7 +935,7 @@ mod tests { .unwrap(); let func = get_function_f(&db, "/src/a.py"); - let expected_sig = CallableSignature::todo("return type of decorated function"); + let expected_sig = Signature::todo("return type of decorated function"); // With no decorators, internal and external signature are the same assert_eq!(func.signature(&db), &expected_sig);