From 84fdb04049d7d605b2b1c6d8f5d50f417ed1a605 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 25 Jan 2025 13:07:19 +0000 Subject: [PATCH 1/5] [red-knot] Decompose `bool` to `Literal[True, False]` in unions and intersections --- .../resources/mdtest/comparison/tuples.md | 17 ++- .../mdtest/exception/control_flow.md | 6 +- .../resources/mdtest/narrow/truthiness.md | 8 +- .../type_properties/is_equivalent_to.md | 44 ++++++++ crates/red_knot_python_semantic/src/types.rs | 71 ++++++++++-- .../src/types/builder.rs | 106 ++---------------- .../src/types/display.rs | 26 ++++- .../src/types/property_tests.rs | 59 +++++----- 8 files changed, 192 insertions(+), 145 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md index 8fe7f29541a9a..250ded168f8a7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -58,7 +58,22 @@ reveal_type(c >= d) # revealed: Literal[True] #### Results with Ambiguity ```py -def _(x: bool, y: int): +class P: + def __lt__(self, other: "P") -> bool: + return True + + def __le__(self, other: "P") -> bool: + return True + + def __gt__(self, other: "P") -> bool: + return True + + def __ge__(self, other: "P") -> bool: + return True + +class Q(P): ... + +def _(x: P, y: Q): a = (x,) b = (y,) diff --git a/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md b/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md index 284b0f24d5620..a6e703dff5b6e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md +++ b/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md @@ -455,9 +455,9 @@ else: reveal_type(x) # revealed: slice finally: # TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` - reveal_type(x) # revealed: bool | float | slice + reveal_type(x) # revealed: bool | slice | float -reveal_type(x) # revealed: bool | float | slice +reveal_type(x) # revealed: bool | slice | float ``` ## Nested `try`/`except` blocks @@ -534,7 +534,7 @@ try: reveal_type(x) # revealed: slice finally: # TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` - reveal_type(x) # revealed: bool | float | slice + reveal_type(x) # revealed: bool | slice | float x = 2 reveal_type(x) # revealed: Literal[2] reveal_type(x) # revealed: Literal[2] diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md index b3975c1a813b7..b1129498ce176 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md @@ -21,22 +21,22 @@ else: if x and not x: reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None if not (x and not x): - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None else: reveal_type(x) # revealed: Never if x or not x: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None else: reveal_type(x) # revealed: Never if not (x or not x): reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None if (isinstance(x, int) or isinstance(x, str)) and x: reveal_type(x) # revealed: Literal[-1, True, "foo"] diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md index 047a45fc2fd0b..b607b05f1b663 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_equivalent_to.md @@ -118,4 +118,48 @@ class R: ... static_assert(is_equivalent_to(Intersection[tuple[P | Q], R], Intersection[tuple[Q | P], R])) ``` +## Unions containing tuples containing `bool` + +```py +from knot_extensions import is_equivalent_to, static_assert +from typing_extensions import Literal + +class P: ... + +static_assert(is_equivalent_to(tuple[Literal[True, False]] | P, tuple[bool] | P)) +static_assert(is_equivalent_to(P | tuple[bool], P | tuple[Literal[True, False]])) +``` + +## Unions and intersections involving `AlwaysTruthy`, `bool` and `AlwaysFalsy` + +```py +from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not +from typing_extensions import Literal + +static_assert(is_equivalent_to(AlwaysTruthy | bool, Literal[False] | AlwaysTruthy)) +static_assert(is_equivalent_to(AlwaysFalsy | bool, Literal[True] | AlwaysFalsy)) +static_assert(is_equivalent_to(Not[AlwaysTruthy] | bool, Not[AlwaysTruthy] | Literal[True])) +static_assert(is_equivalent_to(Not[AlwaysFalsy] | bool, Literal[False] | Not[AlwaysFalsy])) +``` + +## Unions and intersections involving `AlwaysTruthy`, `LiteralString` and `AlwaysFalsy` + +```py +from knot_extensions import AlwaysTruthy, AlwaysFalsy, static_assert, is_equivalent_to, Not, Intersection +from typing_extensions import Literal, LiteralString + +# TODO: these should all pass! + +# error: [static-assert-error] +static_assert(is_equivalent_to(AlwaysTruthy | LiteralString, Literal[""] | AlwaysTruthy)) +# error: [static-assert-error] +static_assert(is_equivalent_to(AlwaysFalsy | LiteralString, Intersection[LiteralString, Not[Literal[""]]] | AlwaysFalsy)) +# error: [static-assert-error] +static_assert(is_equivalent_to(Not[AlwaysFalsy] | LiteralString, Literal[""] | Not[AlwaysFalsy])) +# error: [static-assert-error] +static_assert( + is_equivalent_to(Not[AlwaysTruthy] | LiteralString, Not[AlwaysTruthy] | Intersection[LiteralString, Not[Literal[""]]]) +) +``` + [the equivalence relation]: https://typing.readthedocs.io/en/latest/spec/glossary.html#term-equivalent diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index a056c176c7cd0..893ad9f9f738c 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -811,6 +811,31 @@ impl<'db> Type<'db> { } } + /// Normalize the type `bool` -> `Literal[True, False]`. + /// + /// Using this method in various type-relational methods + /// ensures that the following invariants hold true: + /// + /// - bool ≡ Literal[True, False] + /// - bool | T ≡ Literal[True, False] | T + /// - bool <: Literal[True, False] + /// - bool | T <: Literal[True, False] | T + /// - Literal[True, False] <: bool + /// - Literal[True, False] | T <: bool | T + #[must_use] + pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self { + const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; + + match self { + Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => { + Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) + } + // TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? + // We'd need to rename this method... --Alex + _ => self, + } + } + /// Return a normalized version of `self` in which all unions and intersections are sorted /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. #[must_use] @@ -859,6 +884,12 @@ impl<'db> Type<'db> { return true; } + let normalized_self = self.with_normalized_bools(db); + let normalized_target = target.with_normalized_bools(db); + if normalized_self != self || normalized_target != target { + return normalized_self.is_subtype_of(db, normalized_target); + } + // Non-fully-static types do not participate in subtyping. // // Type `A` can only be a subtype of type `B` if the set of possible runtime objects @@ -961,7 +992,7 @@ impl<'db> Type<'db> { KnownClass::Str.to_instance(db).is_subtype_of(db, target) } (Type::BooleanLiteral(_), _) => { - KnownClass::Bool.to_instance(db).is_subtype_of(db, target) + KnownClass::Int.to_instance(db).is_subtype_of(db, target) } (Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target), (Type::BytesLiteral(_), _) => { @@ -1077,6 +1108,11 @@ impl<'db> Type<'db> { if self.is_gradual_equivalent_to(db, target) { return true; } + let normalized_self = self.with_normalized_bools(db); + let normalized_target = target.with_normalized_bools(db); + if normalized_self != self || normalized_target != target { + return normalized_self.is_assignable_to(db, normalized_target); + } match (self, target) { // Never can be assigned to any type. (Type::Never, _) => true, @@ -1177,6 +1213,13 @@ impl<'db> Type<'db> { pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { // TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc. + let normalized_self = self.with_normalized_bools(db); + let normalized_other = other.with_normalized_bools(db); + + if normalized_self != self || normalized_other != other { + return normalized_self.is_equivalent_to(db, normalized_other); + } + match (self, other) { (Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right), (Type::Intersection(left), Type::Intersection(right)) => { @@ -1218,6 +1261,13 @@ impl<'db> Type<'db> { /// /// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { + let normalized_self = self.with_normalized_bools(db); + let normalized_other = other.with_normalized_bools(db); + + if normalized_self != self || normalized_other != other { + return normalized_self.is_gradual_equivalent_to(db, normalized_other); + } + if self == other { return true; } @@ -1250,6 +1300,12 @@ impl<'db> Type<'db> { /// Note: This function aims to have no false positives, but might return /// wrong `false` answers in some cases. pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool { + let normalized_self = self.with_normalized_bools(db); + let normalized_other = other.with_normalized_bools(db); + if normalized_self != self || normalized_other != other { + return normalized_self.is_disjoint_from(db, normalized_other); + } + match (self, other) { (Type::Never, _) | (_, Type::Never) => true, @@ -4642,18 +4698,19 @@ pub struct TupleType<'db> { } impl<'db> TupleType<'db> { - pub fn from_elements>>( - db: &'db dyn Db, - types: impl IntoIterator, - ) -> Type<'db> { + pub fn from_elements(db: &'db dyn Db, types: I) -> Type<'db> + where + I: IntoIterator, + T: Into>, + { let mut elements = vec![]; for ty in types { - let ty = ty.into(); + let ty: Type<'db> = ty.into(); if ty.is_never() { return Type::Never; } - elements.push(ty); + elements.push(ty.with_normalized_bools(db)); } Type::Tuple(Self::new(db, elements.into_boxed_slice())) diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index c19a4f06cefa4..51cefb7db3323 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -26,7 +26,7 @@ //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. -use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType}; +use crate::types::{IntersectionType, KnownClass, Type, UnionType}; use crate::{Db, FxOrderSet}; use smallvec::SmallVec; @@ -45,6 +45,7 @@ impl<'db> UnionBuilder<'db> { /// Adds a type to this union. pub(crate) fn add(mut self, ty: Type<'db>) -> Self { + let ty = ty.with_normalized_bools(self.db); match ty { Type::Union(union) => { let new_elements = union.elements(self.db); @@ -55,27 +56,10 @@ impl<'db> UnionBuilder<'db> { } Type::Never => {} _ => { - let bool_pair = if let Type::BooleanLiteral(b) = ty { - Some(Type::BooleanLiteral(!b)) - } else { - None - }; - - let mut to_add = ty; let mut to_remove = SmallVec::<[usize; 2]>::new(); let ty_negated = ty.negate(self.db); for (index, element) in self.elements.iter().enumerate() { - if Some(*element) == bool_pair { - to_add = KnownClass::Bool.to_instance(self.db); - to_remove.push(index); - // The type we are adding is a BooleanLiteral, which doesn't have any - // subtypes. And we just found that the union already contained our - // mirror-image BooleanLiteral, so it can't also contain bool or any - // supertype of bool. Therefore, we are done. - break; - } - if ty.is_same_gradual_form(*element) || ty.is_subtype_of(self.db, *element) { return self; } else if element.is_subtype_of(self.db, ty) { @@ -94,8 +78,8 @@ impl<'db> UnionBuilder<'db> { } } match to_remove[..] { - [] => self.elements.push(to_add), - [index] => self.elements[index] = to_add, + [] => self.elements.push(ty), + [index] => self.elements[index] = ty, _ => { let mut current_index = 0; let mut to_remove = to_remove.into_iter(); @@ -110,7 +94,7 @@ impl<'db> UnionBuilder<'db> { current_index += 1; retain }); - self.elements.push(to_add); + self.elements.push(ty); } } } @@ -154,6 +138,7 @@ impl<'db> IntersectionBuilder<'db> { } pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self { + let ty = ty.with_normalized_bools(self.db); if let Type::Union(union) = ty { // Distribute ourself over this union: for each union element, clone ourself and // intersect with that union element, then create a new union-of-intersections with all @@ -183,6 +168,9 @@ impl<'db> IntersectionBuilder<'db> { pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { // See comments above in `add_positive`; this is just the negated version. + + let ty = ty.with_normalized_bools(self.db); + if let Type::Union(union) = ty { for elem in union.elements(self.db) { self = self.add_negative(*elem); @@ -246,7 +234,7 @@ struct InnerIntersectionBuilder<'db> { impl<'db> InnerIntersectionBuilder<'db> { /// Adds a positive type to this intersection. - fn add_positive(&mut self, db: &'db dyn Db, mut new_positive: Type<'db>) { + fn add_positive(&mut self, db: &'db dyn Db, new_positive: Type<'db>) { match new_positive { // `LiteralString & AlwaysTruthy` -> `LiteralString & ~Literal[""]` Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => { @@ -293,62 +281,6 @@ impl<'db> InnerIntersectionBuilder<'db> { return; } - let addition_is_bool_instance = known_instance == Some(KnownClass::Bool); - - for (index, existing_positive) in self.positive.iter().enumerate() { - match existing_positive { - // `AlwaysTruthy & bool` -> `Literal[True]` - Type::AlwaysTruthy if addition_is_bool_instance => { - new_positive = Type::BooleanLiteral(true); - } - // `AlwaysFalsy & bool` -> `Literal[False]` - Type::AlwaysFalsy if addition_is_bool_instance => { - new_positive = Type::BooleanLiteral(false); - } - Type::Instance(InstanceType { class }) - if class.is_known(db, KnownClass::Bool) => - { - match new_positive { - // `bool & AlwaysTruthy` -> `Literal[True]` - Type::AlwaysTruthy => { - new_positive = Type::BooleanLiteral(true); - } - // `bool & AlwaysFalsy` -> `Literal[False]` - Type::AlwaysFalsy => { - new_positive = Type::BooleanLiteral(false); - } - _ => continue, - } - } - _ => continue, - } - self.positive.swap_remove_index(index); - break; - } - - if addition_is_bool_instance { - for (index, existing_negative) in self.negative.iter().enumerate() { - match existing_negative { - // `bool & ~Literal[False]` -> `Literal[True]` - // `bool & ~Literal[True]` -> `Literal[False]` - Type::BooleanLiteral(bool_value) => { - new_positive = Type::BooleanLiteral(!bool_value); - } - // `bool & ~AlwaysTruthy` -> `Literal[False]` - Type::AlwaysTruthy => { - new_positive = Type::BooleanLiteral(false); - } - // `bool & ~AlwaysFalsy` -> `Literal[True]` - Type::AlwaysFalsy => { - new_positive = Type::BooleanLiteral(true); - } - _ => continue, - } - self.negative.swap_remove_index(index); - break; - } - } - let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_positive) in self.positive.iter().enumerate() { // S & T = S if S <: T @@ -396,14 +328,6 @@ impl<'db> InnerIntersectionBuilder<'db> { /// Adds a negative type to this intersection. fn add_negative(&mut self, db: &'db dyn Db, new_negative: Type<'db>) { - let contains_bool = || { - self.positive - .iter() - .filter_map(|ty| ty.into_instance()) - .filter_map(|instance| instance.class.known(db)) - .any(KnownClass::is_bool) - }; - match new_negative { Type::Intersection(inter) => { for pos in inter.positive(db) { @@ -427,20 +351,10 @@ impl<'db> InnerIntersectionBuilder<'db> { // simplify the representation. self.add_positive(db, ty); } - // `bool & ~AlwaysTruthy` -> `bool & Literal[False]` - // `bool & ~Literal[True]` -> `bool & Literal[False]` - Type::AlwaysTruthy | Type::BooleanLiteral(true) if contains_bool() => { - self.add_positive(db, Type::BooleanLiteral(false)); - } // `LiteralString & ~AlwaysTruthy` -> `LiteralString & Literal[""]` Type::AlwaysTruthy if self.positive.contains(&Type::LiteralString) => { self.add_positive(db, Type::string_literal(db, "")); } - // `bool & ~AlwaysFalsy` -> `bool & Literal[True]` - // `bool & ~Literal[False]` -> `bool & Literal[True]` - Type::AlwaysFalsy | Type::BooleanLiteral(false) if contains_bool() => { - self.add_positive(db, Type::BooleanLiteral(true)); - } // `LiteralString & ~AlwaysFalsy` -> `LiteralString & ~Literal[""]` Type::AlwaysFalsy if self.positive.contains(&Type::LiteralString) => { self.add_negative(db, Type::string_literal(db, "")); diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 50c3ae1fff8cb..8769c8e021242 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -1,5 +1,6 @@ //! Display implementations for types. +use std::borrow::Cow; use std::fmt::{self, Display, Formatter, Write}; use ruff_db::display::FormatterJoinExtension; @@ -151,12 +152,31 @@ struct DisplayUnionType<'db> { impl Display for DisplayUnionType<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let elements = self.ty.elements(self.db); + let mut elements = Cow::Borrowed(self.ty.elements(self.db)); + + if let Some(literal_false_pos) = elements + .iter() + .position(|ty| matches!(ty, Type::BooleanLiteral(false))) + { + if let Some(literal_true_pos) = elements + .iter() + .position(|ty| matches!(ty, Type::BooleanLiteral(true))) + { + let (min, max) = if literal_false_pos < literal_true_pos { + (literal_false_pos, literal_true_pos) + } else { + (literal_true_pos, literal_false_pos) + }; + let mutable_elements = elements.to_mut(); + mutable_elements.swap_remove(max); + mutable_elements[min] = KnownClass::Bool.to_instance(self.db); + } + } // Group condensed-display types by kind. let mut grouped_condensed_kinds = FxHashMap::default(); - for element in elements { + for element in &*elements { if let Ok(kind) = CondensedDisplayTypeKind::try_from(*element) { grouped_condensed_kinds .entry(kind) @@ -167,7 +187,7 @@ impl Display for DisplayUnionType<'_> { let mut join = f.join(" | "); - for element in elements { + for element in &*elements { if let Ok(kind) = CondensedDisplayTypeKind::try_from(*element) { let Some(condensed_kind) = grouped_condensed_kinds.remove(&kind) else { continue; diff --git a/crates/red_knot_python_semantic/src/types/property_tests.rs b/crates/red_knot_python_semantic/src/types/property_tests.rs index 5f6e8edf06ce3..f4471826cc4d4 100644 --- a/crates/red_knot_python_semantic/src/types/property_tests.rs +++ b/crates/red_knot_python_semantic/src/types/property_tests.rs @@ -328,8 +328,9 @@ fn union<'db>(db: &'db TestDb, tys: impl IntoIterator>) -> Type } mod stable { - use super::union; + use super::{intersection, union}; use crate::types::{KnownClass, Type}; + use itertools::Itertools; // Reflexivity: `T` is equivalent to itself. type_property_test!( @@ -474,6 +475,32 @@ mod stable { all_type_pairs_are_assignable_to_their_union, db, forall types s, t. s.is_assignable_to(db, union(db, [s, t])) && t.is_assignable_to(db, union(db, [s, t])) ); + + // Equal element sets of intersections implies equivalence + type_property_test!( + intersection_equivalence_not_order_dependent, db, + forall types s, t, u. + s.is_fully_static(db) && t.is_fully_static(db) && u.is_fully_static(db) + => [s, t, u] + .into_iter() + .permutations(3) + .map(|trio_of_types| intersection(db, trio_of_types)) + .permutations(2) + .all(|vec_of_intersections| vec_of_intersections[0].is_equivalent_to(db, vec_of_intersections[1])) + ); + + // Equal element sets of unions implies equivalence + type_property_test!( + union_equivalence_not_order_dependent, db, + forall types s, t, u. + s.is_fully_static(db) && t.is_fully_static(db) && u.is_fully_static(db) + => [s, t, u] + .into_iter() + .permutations(3) + .map(|trio_of_types| union(db, trio_of_types)) + .permutations(2) + .all(|vec_of_unions| vec_of_unions[0].is_equivalent_to(db, vec_of_unions[1])) + ); } /// This module contains property tests that currently lead to many false positives. @@ -484,8 +511,6 @@ mod stable { /// tests to the `stable` section. In the meantime, it can still be useful to run these /// tests (using [`types::property_tests::flaky`]), to see if there are any new obvious bugs. mod flaky { - use itertools::Itertools; - use super::{intersection, union}; // Negating `T` twice is equivalent to `T`. @@ -522,34 +547,6 @@ mod flaky { forall types s, t. intersection(db, [s, t]).is_assignable_to(db, s) && intersection(db, [s, t]).is_assignable_to(db, t) ); - // Equal element sets of intersections implies equivalence - // flaky at least in part because of https://github.com/astral-sh/ruff/issues/15513 - type_property_test!( - intersection_equivalence_not_order_dependent, db, - forall types s, t, u. - s.is_fully_static(db) && t.is_fully_static(db) && u.is_fully_static(db) - => [s, t, u] - .into_iter() - .permutations(3) - .map(|trio_of_types| intersection(db, trio_of_types)) - .permutations(2) - .all(|vec_of_intersections| vec_of_intersections[0].is_equivalent_to(db, vec_of_intersections[1])) - ); - - // Equal element sets of unions implies equivalence - // flaky at laest in part because of https://github.com/astral-sh/ruff/issues/15513 - type_property_test!( - union_equivalence_not_order_dependent, db, - forall types s, t, u. - s.is_fully_static(db) && t.is_fully_static(db) && u.is_fully_static(db) - => [s, t, u] - .into_iter() - .permutations(3) - .map(|trio_of_types| union(db, trio_of_types)) - .permutations(2) - .all(|vec_of_unions| vec_of_unions[0].is_equivalent_to(db, vec_of_unions[1])) - ); - // `S | T` is always a supertype of `S`. // Thus, `S` is never disjoint from `S | T`. type_property_test!( From 2851098c36b9f9ff31f632e9a1dba006ce5d1100 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Mon, 27 Jan 2025 11:40:28 +0000 Subject: [PATCH 2/5] Add missing test coverage for the changes to `is_subtype_of` and `is_assignable_to` --- .../type_properties/is_assignable_to.md | 12 +++++++++ .../mdtest/type_properties/is_subtype_of.md | 26 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index 06ee56aaf0d5f..e51cc4dcaae8e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -346,4 +346,16 @@ static_assert(is_assignable_to(Never, type[str])) static_assert(is_assignable_to(Never, type[Any])) ``` +### `bool` is assignable to unions that include `bool` + +Since we decompose `bool` to `Literal[True, False]` in unions, it would be surprisingly easy to get +this wrong if we forgot to normalize `bool` to `Literal[True, False]` when it appeared on the +left-hand side in `Type::is_assignable_to()`. + +```py +from knot_extensions import is_assignable_to, static_assert + +static_assert(is_assignable_to(bool, str | bool)) +``` + [typing documentation]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index 4e081338d38bc..c120023f819a3 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -449,5 +449,31 @@ static_assert(not is_subtype_of(Intersection[Unknown, int], int)) static_assert(not is_subtype_of(tuple[int, int], tuple[int, Unknown])) ``` +## `bool` is a subtype of `AlwaysTruthy | AlwaysFalsy` + +`bool` is equivalent to `Literal[True] | Literal[False]`. `Literal[True]` is a subtype of +`AlwaysTruthy` and `Literal[False]` is a subtype of `AlwaysFalsy`; it therefore stands to reason +that `bool` is a subtype of `AlwaysTruthy | AlwaysFalsy`. + +```py +from knot_extensions import AlwaysTruthy, AlwaysFalsy, is_subtype_of, static_assert, Not, is_disjoint_from +from typing_extensions import Literal + +static_assert(is_subtype_of(bool, AlwaysTruthy | AlwaysFalsy)) + +# the inverse also applies -- TODO: this should pass! +# See the TODO comments in the `Type::Intersection` branch of `Type::is_disjoint_from()`. +static_assert(is_disjoint_from(bool, Not[AlwaysTruthy | AlwaysFalsy])) # error: [static-assert-error] + +# `Type::is_subtype_of` delegates many questions of `bool` subtyping to `int`, +# but set-theoretic types like intersections and unions are still handled differently to `int` +static_assert(is_subtype_of(Literal[True], Not[Literal[2]])) +static_assert(is_subtype_of(bool, Not[Literal[2]])) +static_assert(is_subtype_of(Literal[True], bool | None)) +static_assert(is_subtype_of(bool, bool | None)) + +static_assert(not is_subtype_of(int, Not[Literal[2]])) +``` + [special case for float and complex]: https://typing.readthedocs.io/en/latest/spec/special-types.html#special-cases-for-float-and-complex [typing documentation]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence From 80136d81f680ec82e331fa419c850e178b91d1ea Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 28 Jan 2025 12:15:34 +0000 Subject: [PATCH 3/5] take 7? --- .../resources/mdtest/comparison/tuples.md | 17 +---- .../mdtest/exception/control_flow.md | 6 +- .../resources/mdtest/narrow/truthiness.md | 8 +-- .../type_properties/is_assignable_to.md | 8 +++ crates/red_knot_python_semantic/src/types.rs | 70 ++++++++----------- .../src/types/builder.rs | 31 +++++++- .../src/types/display.rs | 26 +------ 7 files changed, 76 insertions(+), 90 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md index 250ded168f8a7..8fe7f29541a9a 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/tuples.md @@ -58,22 +58,7 @@ reveal_type(c >= d) # revealed: Literal[True] #### Results with Ambiguity ```py -class P: - def __lt__(self, other: "P") -> bool: - return True - - def __le__(self, other: "P") -> bool: - return True - - def __gt__(self, other: "P") -> bool: - return True - - def __ge__(self, other: "P") -> bool: - return True - -class Q(P): ... - -def _(x: P, y: Q): +def _(x: bool, y: int): a = (x,) b = (y,) diff --git a/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md b/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md index a6e703dff5b6e..284b0f24d5620 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md +++ b/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md @@ -455,9 +455,9 @@ else: reveal_type(x) # revealed: slice finally: # TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` - reveal_type(x) # revealed: bool | slice | float + reveal_type(x) # revealed: bool | float | slice -reveal_type(x) # revealed: bool | slice | float +reveal_type(x) # revealed: bool | float | slice ``` ## Nested `try`/`except` blocks @@ -534,7 +534,7 @@ try: reveal_type(x) # revealed: slice finally: # TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` - reveal_type(x) # revealed: bool | slice | float + reveal_type(x) # revealed: bool | float | slice x = 2 reveal_type(x) # revealed: Literal[2] reveal_type(x) # revealed: Literal[2] diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md index b1129498ce176..b3975c1a813b7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md @@ -21,22 +21,22 @@ else: if x and not x: reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] if not (x and not x): - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] else: reveal_type(x) # revealed: Never if x or not x: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] else: reveal_type(x) # revealed: Never if not (x or not x): reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | tuple[()] | None + reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] if (isinstance(x, int) or isinstance(x, str)) and x: reveal_type(x) # revealed: Literal[-1, True, "foo"] diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md index e51cc4dcaae8e..318af242476d4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_assignable_to.md @@ -358,4 +358,12 @@ from knot_extensions import is_assignable_to, static_assert static_assert(is_assignable_to(bool, str | bool)) ``` +### `bool` is assignable to `AlwaysTruthy | AlwaysFalsy` + +```py +from knot_extensions import static_assert, is_assignable_to, AlwaysTruthy, AlwaysFalsy + +static_assert(is_assignable_to(bool, AlwaysTruthy | AlwaysFalsy)) +``` + [typing documentation]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 893ad9f9f738c..2b87e14b8cb1c 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -824,11 +824,9 @@ impl<'db> Type<'db> { /// - Literal[True, False] | T <: bool | T #[must_use] pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self { - const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; - match self { Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => { - Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) + Type::normalized_bool(db) } // TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? // We'd need to rename this method... --Alex @@ -884,12 +882,6 @@ impl<'db> Type<'db> { return true; } - let normalized_self = self.with_normalized_bools(db); - let normalized_target = target.with_normalized_bools(db); - if normalized_self != self || normalized_target != target { - return normalized_self.is_subtype_of(db, normalized_target); - } - // Non-fully-static types do not participate in subtyping. // // Type `A` can only be a subtype of type `B` if the set of possible runtime objects @@ -912,6 +904,13 @@ impl<'db> Type<'db> { (Type::Never, _) => true, (_, Type::Never) => false, + (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { + Type::normalized_bool(db).is_subtype_of(db, target) + } + (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { + self.is_boolean_literal() + } + (Type::Union(union), _) => union .elements(db) .iter() @@ -1108,11 +1107,7 @@ impl<'db> Type<'db> { if self.is_gradual_equivalent_to(db, target) { return true; } - let normalized_self = self.with_normalized_bools(db); - let normalized_target = target.with_normalized_bools(db); - if normalized_self != self || normalized_target != target { - return normalized_self.is_assignable_to(db, normalized_target); - } + match (self, target) { // Never can be assigned to any type. (Type::Never, _) => true, @@ -1129,6 +1124,13 @@ impl<'db> Type<'db> { true } + (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { + Type::normalized_bool(db).is_assignable_to(db, target) + } + (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { + self.is_assignable_to(db, Type::normalized_bool(db)) + } + // A union is assignable to a type T iff every element of the union is assignable to T. (Type::Union(union), ty) => union .elements(db) @@ -1213,13 +1215,6 @@ impl<'db> Type<'db> { pub(crate) fn is_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { // TODO equivalent but not identical types: TypedDicts, Protocols, type aliases, etc. - let normalized_self = self.with_normalized_bools(db); - let normalized_other = other.with_normalized_bools(db); - - if normalized_self != self || normalized_other != other { - return normalized_self.is_equivalent_to(db, normalized_other); - } - match (self, other) { (Type::Union(left), Type::Union(right)) => left.is_equivalent_to(db, right), (Type::Intersection(left), Type::Intersection(right)) => { @@ -1261,13 +1256,6 @@ impl<'db> Type<'db> { /// /// [Summary of type relations]: https://typing.readthedocs.io/en/latest/spec/concepts.html#summary-of-type-relations pub(crate) fn is_gradual_equivalent_to(self, db: &'db dyn Db, other: Type<'db>) -> bool { - let normalized_self = self.with_normalized_bools(db); - let normalized_other = other.with_normalized_bools(db); - - if normalized_self != self || normalized_other != other { - return normalized_self.is_gradual_equivalent_to(db, normalized_other); - } - if self == other { return true; } @@ -1300,12 +1288,6 @@ impl<'db> Type<'db> { /// Note: This function aims to have no false positives, but might return /// wrong `false` answers in some cases. pub(crate) fn is_disjoint_from(self, db: &'db dyn Db, other: Type<'db>) -> bool { - let normalized_self = self.with_normalized_bools(db); - let normalized_other = other.with_normalized_bools(db); - if normalized_self != self || normalized_other != other { - return normalized_self.is_disjoint_from(db, normalized_other); - } - match (self, other) { (Type::Never, _) | (_, Type::Never) => true, @@ -2427,6 +2409,13 @@ impl<'db> Type<'db> { KnownClass::NoneType.to_instance(db) } + /// The type `Literal[True, False]`, which is exactly equivalent to `bool` + /// (and which `bool` is eagerly normalized to in several situations) + pub fn normalized_bool(db: &'db dyn Db) -> Type<'db> { + const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; + Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) + } + /// Return the type of `tuple(sys.version_info)`. /// /// This is not exactly the type that `sys.version_info` has at runtime, @@ -4698,19 +4687,18 @@ pub struct TupleType<'db> { } impl<'db> TupleType<'db> { - pub fn from_elements(db: &'db dyn Db, types: I) -> Type<'db> - where - I: IntoIterator, - T: Into>, - { + pub fn from_elements>>( + db: &'db dyn Db, + types: impl IntoIterator, + ) -> Type<'db> { let mut elements = vec![]; for ty in types { - let ty: Type<'db> = ty.into(); + let ty = ty.into(); if ty.is_never() { return Type::Never; } - elements.push(ty.with_normalized_bools(db)); + elements.push(ty); } Type::Tuple(Self::new(db, elements.into_boxed_slice())) diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 51cefb7db3323..193a63a1a3a7b 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -103,10 +103,35 @@ impl<'db> UnionBuilder<'db> { } pub(crate) fn build(self) -> Type<'db> { - match self.elements.len() { + let UnionBuilder { elements, db } = self; + + match elements.len() { 0 => Type::Never, - 1 => self.elements[0], - _ => Type::Union(UnionType::new(self.db, self.elements.into_boxed_slice())), + 1 => elements[0], + _ => { + let mut normalized_elements = Vec::with_capacity(elements.len()); + let mut first_bool_literal_pos = None; + let mut seen_two_bool_literals = false; + for (i, element) in elements.into_iter().enumerate() { + if element.is_boolean_literal() { + if first_bool_literal_pos.is_none() { + first_bool_literal_pos = Some(i); + } else { + seen_two_bool_literals = true; + continue; + } + } + normalized_elements.push(element); + } + if let (Some(pos), true) = (first_bool_literal_pos, seen_two_bool_literals) { + // If we have two boolean literals, we can merge them to `bool`. + if normalized_elements.len() == 1 { + return KnownClass::Bool.to_instance(db); + } + normalized_elements[pos] = KnownClass::Bool.to_instance(db); + } + Type::Union(UnionType::new(db, normalized_elements.into_boxed_slice())) + } } } } diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 8769c8e021242..50c3ae1fff8cb 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -1,6 +1,5 @@ //! Display implementations for types. -use std::borrow::Cow; use std::fmt::{self, Display, Formatter, Write}; use ruff_db::display::FormatterJoinExtension; @@ -152,31 +151,12 @@ struct DisplayUnionType<'db> { impl Display for DisplayUnionType<'_> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let mut elements = Cow::Borrowed(self.ty.elements(self.db)); - - if let Some(literal_false_pos) = elements - .iter() - .position(|ty| matches!(ty, Type::BooleanLiteral(false))) - { - if let Some(literal_true_pos) = elements - .iter() - .position(|ty| matches!(ty, Type::BooleanLiteral(true))) - { - let (min, max) = if literal_false_pos < literal_true_pos { - (literal_false_pos, literal_true_pos) - } else { - (literal_true_pos, literal_false_pos) - }; - let mutable_elements = elements.to_mut(); - mutable_elements.swap_remove(max); - mutable_elements[min] = KnownClass::Bool.to_instance(self.db); - } - } + let elements = self.ty.elements(self.db); // Group condensed-display types by kind. let mut grouped_condensed_kinds = FxHashMap::default(); - for element in &*elements { + for element in elements { if let Ok(kind) = CondensedDisplayTypeKind::try_from(*element) { grouped_condensed_kinds .entry(kind) @@ -187,7 +167,7 @@ impl Display for DisplayUnionType<'_> { let mut join = f.join(" | "); - for element in &*elements { + for element in elements { if let Ok(kind) = CondensedDisplayTypeKind::try_from(*element) { let Some(condensed_kind) = grouped_condensed_kinds.remove(&kind) else { continue; From 91784bbe93bee8fd4ba4e53f9a3337a5ca11da48 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 28 Jan 2025 13:39:48 +0000 Subject: [PATCH 4/5] is this faster --- .../mdtest/exception/control_flow.md | 6 +-- .../resources/mdtest/narrow/truthiness.md | 14 +++---- .../src/types/builder.rs | 39 +++++++++---------- 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md b/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md index 284b0f24d5620..a6e703dff5b6e 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md +++ b/crates/red_knot_python_semantic/resources/mdtest/exception/control_flow.md @@ -455,9 +455,9 @@ else: reveal_type(x) # revealed: slice finally: # TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` - reveal_type(x) # revealed: bool | float | slice + reveal_type(x) # revealed: bool | slice | float -reveal_type(x) # revealed: bool | float | slice +reveal_type(x) # revealed: bool | slice | float ``` ## Nested `try`/`except` blocks @@ -534,7 +534,7 @@ try: reveal_type(x) # revealed: slice finally: # TODO: should be `Literal[1] | str | bytes | bool | memoryview | float | range | slice` - reveal_type(x) # revealed: bool | float | slice + reveal_type(x) # revealed: bool | slice | float x = 2 reveal_type(x) # revealed: Literal[2] reveal_type(x) # revealed: Literal[2] diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md index b3975c1a813b7..203ffe9827aa4 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md @@ -11,37 +11,37 @@ x = foo() if x: reveal_type(x) # revealed: Literal[-1, True, "foo", b"bar"] else: - reveal_type(x) # revealed: Literal[0, False, "", b""] | None | tuple[()] + reveal_type(x) # revealed: Literal[0, False, "", b""] | tuple[()] | None if not x: - reveal_type(x) # revealed: Literal[0, False, "", b""] | None | tuple[()] + reveal_type(x) # revealed: Literal[0, False, "", b""] | tuple[()] | None else: reveal_type(x) # revealed: Literal[-1, True, "foo", b"bar"] if x and not x: reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, b"bar", "", "foo", b""] | bool | tuple[()] | None if not (x and not x): - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, b"bar", "", "foo", b""] | bool | tuple[()] | None else: reveal_type(x) # revealed: Never if x or not x: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, b"bar", "", "foo", b""] | bool | tuple[()] | None else: reveal_type(x) # revealed: Never if not (x or not x): reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[0, -1, b"bar", "", "foo", b""] | bool | tuple[()] | None if (isinstance(x, int) or isinstance(x, str)) and x: reveal_type(x) # revealed: Literal[-1, True, "foo"] else: - reveal_type(x) # revealed: Literal[b"", b"bar", 0, False, ""] | None | tuple[()] + reveal_type(x) # revealed: tuple[()] | None | Literal[b"", b"bar", 0, False, ""] ``` ## Function Literals diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 193a63a1a3a7b..49ed137a96181 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -33,6 +33,7 @@ use smallvec::SmallVec; pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, + contains_bool_literals: bool, } impl<'db> UnionBuilder<'db> { @@ -40,6 +41,7 @@ impl<'db> UnionBuilder<'db> { Self { db, elements: vec![], + contains_bool_literals: false, } } @@ -77,6 +79,7 @@ impl<'db> UnionBuilder<'db> { return self; } } + self.contains_bool_literals |= ty.is_boolean_literal(); match to_remove[..] { [] => self.elements.push(ty), [index] => self.elements[index] = ty, @@ -103,34 +106,30 @@ impl<'db> UnionBuilder<'db> { } pub(crate) fn build(self) -> Type<'db> { - let UnionBuilder { elements, db } = self; + let UnionBuilder { + mut elements, + db, + contains_bool_literals, + } = self; match elements.len() { 0 => Type::Never, 1 => elements[0], _ => { - let mut normalized_elements = Vec::with_capacity(elements.len()); - let mut first_bool_literal_pos = None; - let mut seen_two_bool_literals = false; - for (i, element) in elements.into_iter().enumerate() { - if element.is_boolean_literal() { - if first_bool_literal_pos.is_none() { - first_bool_literal_pos = Some(i); - } else { - seen_two_bool_literals = true; - continue; + if contains_bool_literals { + let mut element_iter = elements.iter(); + if let Some(first_pos) = element_iter.position(Type::is_boolean_literal) { + if let Some(second_pos) = element_iter.position(Type::is_boolean_literal) { + let bool_instance = KnownClass::Bool.to_instance(db); + if elements.len() == 2 { + return bool_instance; + } + elements.swap_remove(first_pos + second_pos + 1); + elements[first_pos] = bool_instance; } } - normalized_elements.push(element); - } - if let (Some(pos), true) = (first_bool_literal_pos, seen_two_bool_literals) { - // If we have two boolean literals, we can merge them to `bool`. - if normalized_elements.len() == 1 { - return KnownClass::Bool.to_instance(db); - } - normalized_elements[pos] = KnownClass::Bool.to_instance(db); } - Type::Union(UnionType::new(db, normalized_elements.into_boxed_slice())) + Type::Union(UnionType::new(db, elements.into_boxed_slice())) } } } From 1ef2f73820271c04b60f9f764e1a7ce0facc0501 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 28 Jan 2025 18:32:53 +0000 Subject: [PATCH 5/5] any better --- crates/red_knot_python_semantic/src/types.rs | 42 +----- .../src/types/builder.rs | 140 ++++++++++++------ 2 files changed, 100 insertions(+), 82 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 2b87e14b8cb1c..e6e2f9eea7699 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -811,29 +811,6 @@ impl<'db> Type<'db> { } } - /// Normalize the type `bool` -> `Literal[True, False]`. - /// - /// Using this method in various type-relational methods - /// ensures that the following invariants hold true: - /// - /// - bool ≡ Literal[True, False] - /// - bool | T ≡ Literal[True, False] | T - /// - bool <: Literal[True, False] - /// - bool | T <: Literal[True, False] | T - /// - Literal[True, False] <: bool - /// - Literal[True, False] | T <: bool | T - #[must_use] - pub fn with_normalized_bools(self, db: &'db dyn Db) -> Self { - match self { - Type::Instance(InstanceType { class }) if class.is_known(db, KnownClass::Bool) => { - Type::normalized_bool(db) - } - // TODO: decompose `LiteralString` into `Literal[""] | TruthyLiteralString`? - // We'd need to rename this method... --Alex - _ => self, - } - } - /// Return a normalized version of `self` in which all unions and intersections are sorted /// according to a canonical order, no matter how "deeply" a union/intersection may be nested. #[must_use] @@ -905,10 +882,12 @@ impl<'db> Type<'db> { (_, Type::Never) => false, (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { - Type::normalized_bool(db).is_subtype_of(db, target) + Type::BooleanLiteral(true).is_subtype_of(db, target) + && Type::BooleanLiteral(false).is_subtype_of(db, target) } (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { - self.is_boolean_literal() + self.is_subtype_of(db, Type::BooleanLiteral(true)) + || self.is_subtype_of(db, Type::BooleanLiteral(false)) } (Type::Union(union), _) => union @@ -1125,10 +1104,12 @@ impl<'db> Type<'db> { } (Type::Instance(InstanceType { class }), _) if class.is_known(db, KnownClass::Bool) => { - Type::normalized_bool(db).is_assignable_to(db, target) + Type::BooleanLiteral(true).is_assignable_to(db, target) + && Type::BooleanLiteral(false).is_assignable_to(db, target) } (_, Type::Instance(InstanceType { class })) if class.is_known(db, KnownClass::Bool) => { - self.is_assignable_to(db, Type::normalized_bool(db)) + self.is_assignable_to(db, Type::BooleanLiteral(false)) + || self.is_assignable_to(db, Type::BooleanLiteral(true)) } // A union is assignable to a type T iff every element of the union is assignable to T. @@ -2409,13 +2390,6 @@ impl<'db> Type<'db> { KnownClass::NoneType.to_instance(db) } - /// The type `Literal[True, False]`, which is exactly equivalent to `bool` - /// (and which `bool` is eagerly normalized to in several situations) - pub fn normalized_bool(db: &'db dyn Db) -> Type<'db> { - const LITERAL_BOOLS: [Type; 2] = [Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; - Type::Union(UnionType::new(db, Box::from(LITERAL_BOOLS))) - } - /// Return the type of `tuple(sys.version_info)`. /// /// This is not exactly the type that `sys.version_info` has at runtime, diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 49ed137a96181..a9430f37a1fa7 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -26,14 +26,14 @@ //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. -use crate::types::{IntersectionType, KnownClass, Type, UnionType}; +use crate::types::{InstanceType, IntersectionType, KnownClass, Type, UnionType}; use crate::{Db, FxOrderSet}; use smallvec::SmallVec; pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, - contains_bool_literals: bool, + bool_literals_present: BoolLiteralsPresent, } impl<'db> UnionBuilder<'db> { @@ -41,13 +41,12 @@ impl<'db> UnionBuilder<'db> { Self { db, elements: vec![], - contains_bool_literals: false, + bool_literals_present: BoolLiteralsPresent::Zero, } } /// Adds a type to this union. pub(crate) fn add(mut self, ty: Type<'db>) -> Self { - let ty = ty.with_normalized_bools(self.db); match ty { Type::Union(union) => { let new_elements = union.elements(self.db); @@ -57,6 +56,11 @@ impl<'db> UnionBuilder<'db> { } } Type::Never => {} + Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => { + self = self + .add(Type::BooleanLiteral(false)) + .add(Type::BooleanLiteral(true)); + } _ => { let mut to_remove = SmallVec::<[usize; 2]>::new(); let ty_negated = ty.negate(self.db); @@ -79,7 +83,11 @@ impl<'db> UnionBuilder<'db> { return self; } } - self.contains_bool_literals |= ty.is_boolean_literal(); + + if ty.is_boolean_literal() { + self.bool_literals_present.increment(); + } + match to_remove[..] { [] => self.elements.push(ty), [index] => self.elements[index] = ty, @@ -109,14 +117,14 @@ impl<'db> UnionBuilder<'db> { let UnionBuilder { mut elements, db, - contains_bool_literals, + bool_literals_present, } = self; match elements.len() { 0 => Type::Never, 1 => elements[0], _ => { - if contains_bool_literals { + if bool_literals_present.is_two() { let mut element_iter = elements.iter(); if let Some(first_pos) = element_iter.position(Type::is_boolean_literal) { if let Some(second_pos) = element_iter.position(Type::is_boolean_literal) { @@ -135,6 +143,27 @@ impl<'db> UnionBuilder<'db> { } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum BoolLiteralsPresent { + Zero, + One, + Two, +} + +impl BoolLiteralsPresent { + fn increment(&mut self) { + *self = match self { + BoolLiteralsPresent::Zero => BoolLiteralsPresent::One, + BoolLiteralsPresent::One => BoolLiteralsPresent::Two, + BoolLiteralsPresent::Two => BoolLiteralsPresent::Two, + }; + } + + const fn is_two(self) -> bool { + matches!(self, BoolLiteralsPresent::Two) + } +} + #[derive(Clone)] pub(crate) struct IntersectionBuilder<'db> { // Really this builds a union-of-intersections, because we always keep our set-theoretic types @@ -162,8 +191,18 @@ impl<'db> IntersectionBuilder<'db> { } pub(crate) fn add_positive(mut self, ty: Type<'db>) -> Self { - let ty = ty.with_normalized_bools(self.db); - if let Type::Union(union) = ty { + const BOOL_LITERALS: &[Type] = &[Type::BooleanLiteral(false), Type::BooleanLiteral(true)]; + + // Treat `bool` as `Literal[True] | Literal[False]` + let union_elements = match ty { + Type::Union(union) => Some(union.elements(self.db)), + Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => { + Some(BOOL_LITERALS) + } + _ => None, + }; + + if let Some(elements) = union_elements { // Distribute ourself over this union: for each union element, clone ourself and // intersect with that union element, then create a new union-of-intersections with all // of those sub-intersections in it. E.g. if `self` is a simple intersection `T1 & T2` @@ -172,8 +211,7 @@ impl<'db> IntersectionBuilder<'db> { // (T2 & T4)`. If `self` is already a union-of-intersections `(T1 & T2) | (T3 & T4)` // and we add `T5 | T6` to it, that flattens all the way out to `(T1 & T2 & T5) | (T1 & // T2 & T6) | (T3 & T4 & T5) ...` -- you get the idea. - union - .elements(self.db) + elements .iter() .map(|elem| self.clone().add_positive(*elem)) .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { @@ -193,45 +231,51 @@ impl<'db> IntersectionBuilder<'db> { pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self { // See comments above in `add_positive`; this is just the negated version. - let ty = ty.with_normalized_bools(self.db); - - if let Type::Union(union) = ty { - for elem in union.elements(self.db) { - self = self.add_negative(*elem); + match ty { + Type::Union(union) => { + for elem in union.elements(self.db) { + self = self.add_negative(*elem); + } + self } - self - } else if let Type::Intersection(intersection) = ty { - // (A | B) & ~(C & ~D) - // -> (A | B) & (~C | D) - // -> ((A | B) & ~C) | ((A | B) & D) - // i.e. if we have an intersection of positive constraints C - // and negative constraints D, then our new intersection - // is (existing & ~C) | (existing & D) - - let positive_side = intersection - .positive(self.db) - .iter() - // we negate all the positive constraints while distributing - .map(|elem| self.clone().add_negative(*elem)); - - let negative_side = intersection - .negative(self.db) - .iter() - // all negative constraints end up becoming positive constraints - .map(|elem| self.clone().add_positive(*elem)); - - positive_side.chain(negative_side).fold( - IntersectionBuilder::empty(self.db), - |mut builder, sub| { - builder.intersections.extend(sub.intersections); - builder - }, - ) - } else { - for inner in &mut self.intersections { - inner.add_negative(self.db, ty); + Type::Instance(InstanceType { class }) if class.is_known(self.db, KnownClass::Bool) => { + self.add_negative(Type::BooleanLiteral(false)) + .add_negative(Type::BooleanLiteral(true)) + } + Type::Intersection(intersection) => { + // (A | B) & ~(C & ~D) + // -> (A | B) & (~C | D) + // -> ((A | B) & ~C) | ((A | B) & D) + // i.e. if we have an intersection of positive constraints C + // and negative constraints D, then our new intersection + // is (existing & ~C) | (existing & D) + + let positive_side = intersection + .positive(self.db) + .iter() + // we negate all the positive constraints while distributing + .map(|elem| self.clone().add_negative(*elem)); + + let negative_side = intersection + .negative(self.db) + .iter() + // all negative constraints end up becoming positive constraints + .map(|elem| self.clone().add_positive(*elem)); + + positive_side.chain(negative_side).fold( + IntersectionBuilder::empty(self.db), + |mut builder, sub| { + builder.intersections.extend(sub.intersections); + builder + }, + ) + } + _ => { + for inner in &mut self.intersections { + inner.add_negative(self.db, ty); + } + self } - self } }