Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ static_assert(is_assignable_to(Intersection[int, Parent], Intersection[int, Not[
static_assert(not is_assignable_to(int, Not[int]))
static_assert(not is_assignable_to(int, Not[Literal[1]]))

static_assert(is_assignable_to(Not[Parent], Not[Child1]))
static_assert(not is_assignable_to(Not[Parent], Parent))
static_assert(not is_assignable_to(Intersection[Unrelated, Not[Parent]], Parent))

# Intersection with `Any` dominates the left hand side of intersections
static_assert(is_assignable_to(Intersection[Any, Parent], Parent))
static_assert(is_assignable_to(Intersection[Any, Child1], Parent))
Expand All @@ -277,6 +281,7 @@ static_assert(is_assignable_to(Intersection[Any, Parent, Unrelated], Intersectio

# Even Any & Not[Parent] is assignable to Parent, since it could be Never
static_assert(is_assignable_to(Intersection[Any, Not[Parent]], Parent))
static_assert(is_assignable_to(Intersection[Any, Not[Parent]], Not[Parent]))

# Intersection with `Any` is effectively ignored on the right hand side for the sake of assignment
static_assert(is_assignable_to(Parent, Intersection[Any, Parent]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ static_assert(not is_disjoint_from(bool, object))

static_assert(not is_disjoint_from(Any, bool))
static_assert(not is_disjoint_from(Any, Any))
static_assert(not is_disjoint_from(Any, Not[Any]))

static_assert(not is_disjoint_from(LiteralString, LiteralString))
static_assert(not is_disjoint_from(str, LiteralString))
Expand Down Expand Up @@ -95,8 +96,8 @@ static_assert(not is_disjoint_from(Literal[1, 2], Literal[2, 3]))
## Intersections

```py
from typing_extensions import Literal, final
from knot_extensions import Intersection, is_disjoint_from, static_assert
from typing_extensions import Literal, final, Any
from knot_extensions import Intersection, is_disjoint_from, static_assert, Not

@final
class P: ...
Expand Down Expand Up @@ -130,6 +131,27 @@ static_assert(not is_disjoint_from(Y, Z))
static_assert(not is_disjoint_from(Intersection[X, Y], Z))
static_assert(not is_disjoint_from(Intersection[X, Z], Y))
static_assert(not is_disjoint_from(Intersection[Y, Z], X))

# If one side has a positive fully-static element and the other side has a negative of that element, they are disjoint
static_assert(is_disjoint_from(int, Not[int]))
static_assert(is_disjoint_from(Intersection[X, Y, Not[Z]], Intersection[X, Z]))
static_assert(is_disjoint_from(Intersection[X, Not[Literal[1]]], Literal[1]))

class Parent: ...
class Child(Parent): ...

static_assert(not is_disjoint_from(Parent, Child))
static_assert(not is_disjoint_from(Parent, Not[Child]))
static_assert(not is_disjoint_from(Not[Parent], Not[Child]))
static_assert(is_disjoint_from(Not[Parent], Child))
static_assert(is_disjoint_from(Intersection[X, Not[Parent]], Child))
static_assert(is_disjoint_from(Intersection[X, Not[Parent]], Intersection[X, Child]))

static_assert(not is_disjoint_from(Intersection[Any, X], Intersection[Any, Not[Y]]))
static_assert(not is_disjoint_from(Intersection[Any, Not[Y]], Intersection[Any, X]))

static_assert(is_disjoint_from(Intersection[int, Any], Not[int]))
static_assert(is_disjoint_from(Not[int], Intersection[int, Any]))
```

## Special types
Expand All @@ -152,7 +174,7 @@ static_assert(is_disjoint_from(Never, object))

```py
from typing_extensions import Literal, LiteralString
from knot_extensions import is_disjoint_from, static_assert
from knot_extensions import is_disjoint_from, static_assert, Intersection, Not

static_assert(is_disjoint_from(None, Literal[True]))
static_assert(is_disjoint_from(None, Literal[1]))
Expand All @@ -165,6 +187,9 @@ static_assert(is_disjoint_from(None, type[object]))
static_assert(not is_disjoint_from(None, None))
static_assert(not is_disjoint_from(None, int | None))
static_assert(not is_disjoint_from(None, object))

static_assert(is_disjoint_from(Intersection[int, Not[str]], None))
static_assert(is_disjoint_from(None, Intersection[int, Not[str]]))
```

### Literals
Expand Down
72 changes: 32 additions & 40 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,38 +580,9 @@ impl<'db> Type<'db> {
true
}

(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
// Check that all target positive values are covered in self positive values
target_intersection
.positive(db)
.iter()
.all(|&target_pos_elem| {
self_intersection
.positive(db)
.iter()
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
})
// Check that all target negative values are excluded in self, either by being
// subtypes of a self negative value or being disjoint from a self positive value.
&& target_intersection
.negative(db)
.iter()
.all(|&target_neg_elem| {
// Is target negative value is subtype of a self negative value
self_intersection.negative(db).iter().any(|&self_neg_elem| {
target_neg_elem.is_subtype_of(db, self_neg_elem)
// Is target negative value is disjoint from a self positive value?
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
self_pos_elem.is_disjoint_from(db, target_neg_elem)
})
})
}
Comment on lines -583 to -608
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty cool that this whole case can be eliminated by improving is_disjoint_from! Really nice find.


(Type::Intersection(intersection), _) => intersection
.positive(db)
.iter()
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),

// If both sides are intersections we need to handle the right side first
// (A & B & C) is a subtype of (A & B) because the left is a subtype of both A and B,
// but none of A, B, or C is a subtype of (A & B).
(_, Type::Intersection(intersection)) => {
intersection
.positive(db)
Expand All @@ -623,6 +594,11 @@ impl<'db> Type<'db> {
.all(|&neg_ty| self.is_disjoint_from(db, neg_ty))
}

(Type::Intersection(intersection), _) => intersection
.positive(db)
.iter()
.any(|&elem_ty| elem_ty.is_subtype_of(db, target)),

// Note that the definition of `Type::AlwaysFalsy` depends on the return value of `__bool__`.
// If `__bool__` always returns True or False, it can be treated as a subtype of `AlwaysTruthy` or `AlwaysFalsy`, respectively.
(left, Type::AlwaysFalsy) => left.bool(db).is_always_false(),
Expand Down Expand Up @@ -802,6 +778,10 @@ impl<'db> Type<'db> {
.iter()
.any(|&elem_ty| ty.is_assignable_to(db, elem_ty)),

// If both sides are intersections we need to handle the right side first
// (A & B & C) is assignable to (A & B) because the left is assignable to both A and B,
// but none of A, B, or C is assignable to (A & B).
//
// A type S is assignable to an intersection type T if
// S is assignable to all positive elements of T (e.g. `str & int` is assignable to `str & Any`), and
// S is disjoint from all negative elements of T (e.g. `int` is not assignable to Intersection[int, Not[Literal[1]]]).
Expand Down Expand Up @@ -998,19 +978,31 @@ impl<'db> Type<'db> {
.iter()
.all(|e| e.is_disjoint_from(db, other)),

// If we have two intersections, we test the positive elements of each one against the other intersection
// Negative elements need a positive element on the other side in order to be disjoint.
// This is similar to what would happen if we tried to build a new intersection that combines the two
(Type::Intersection(self_intersection), Type::Intersection(other_intersection)) => {
self_intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
|| other_intersection
.positive(db)
.iter()
.any(|p: &Type<'_>| p.is_disjoint_from(db, self))
}

(Type::Intersection(intersection), other)
| (other, Type::Intersection(intersection)) => {
if intersection
intersection
.positive(db)
.iter()
.any(|p| p.is_disjoint_from(db, other))
{
true
} else {
// TODO we can do better here. For example:
// X & ~Literal[1] is disjoint from Literal[1]
false
}
// A & B & Not[C] is disjoint from C
|| intersection
.negative(db)
.iter()
.any(|&neg_ty| other.is_subtype_of(db, neg_ty))
}

// any single-valued type is disjoint from another single-valued type
Expand Down
Loading