diff --git a/crates/ruff_benchmark/benches/ty.rs b/crates/ruff_benchmark/benches/ty.rs index e6990de32ef2a0..6f5a122cea19d7 100644 --- a/crates/ruff_benchmark/benches/ty.rs +++ b/crates/ruff_benchmark/benches/ty.rs @@ -658,6 +658,70 @@ class E(Enum): }); } +/// Benchmark for narrowing a large union type through multiple match statements. +/// +/// This is extracted from egglog-python's `pretty.py`, where a ~30-class union type +/// (`AllDecls`) is narrowed by exhaustive match statements. +/// +/// Sample code structure: +/// ```python +/// from __future__ import annotations +/// from dataclasses import dataclass +/// +/// @dataclass +/// class C0: +/// value: int +/// ... +/// +/// AllDecls = C0 | C1 | ... +/// +/// def process(decl: AllDecls) -> None: +/// match decl: +/// case C0(): pass +/// ... +/// case _: pass +/// ``` +fn benchmark_large_union_narrowing(criterion: &mut Criterion) { + const NUM_CLASSES: usize = 30; + const NUM_MATCH_BRANCHES: usize = 29; + + setup_rayon(); + + let mut code = + "from __future__ import annotations\nfrom dataclasses import dataclass\n\n".to_string(); + + for i in 0..NUM_CLASSES { + writeln!(&mut code, "@dataclass\nclass C{i}:\n value: int\n").ok(); + } + + code.push_str("AllDecls = "); + for i in 0..NUM_CLASSES { + if i > 0 { + code.push_str(" | "); + } + write!(&mut code, "C{i}").ok(); + } + code.push_str("\n\n"); + + code.push_str("def process(decl: AllDecls) -> None:\n match decl:\n"); + for i in 0..NUM_MATCH_BRANCHES { + writeln!(&mut code, " case C{i}():\n pass",).ok(); + } + code.push_str(" case _:\n pass\n\n"); + + criterion.bench_function("ty_micro[large_union_narrowing]", |b| { + b.iter_batched_ref( + || setup_micro_case(&code), + |case| { + let Case { db, .. 
} = case; + let result = db.check(); + assert_eq!(result.len(), 0); + }, + BatchSize::SmallInput, + ); + }); +} + struct ProjectBenchmark<'a> { project: InstalledProject<'a>, fs: MemoryFileSystem, @@ -820,6 +884,7 @@ criterion_group!( benchmark_many_enum_members, benchmark_many_enum_members_2, benchmark_very_large_tuple, + benchmark_large_union_narrowing, ); criterion_group!(project, anyio, attrs, hydra, datetype); criterion_main!(check_file, micro, project); diff --git a/crates/ty_python_semantic/resources/mdtest/binary/integers.md b/crates/ty_python_semantic/resources/mdtest/binary/integers.md index 30561981810b32..3834a90e0ba20e 100644 --- a/crates/ty_python_semantic/resources/mdtest/binary/integers.md +++ b/crates/ty_python_semantic/resources/mdtest/binary/integers.md @@ -1,5 +1,11 @@ # Binary operations on integers +> Developer's note: This is mainly a test for the behavior of the type inferer. The constant +> evaluator (`resolve_to_literal`) of `SemanticIndexBuilder` is implemented separately from the type +> inferer, so if you modify the contents of this file or the type inferer, please also modify the +> implementation of `resolve_to_literal` and the unit tests (semantic_index/tests/const_eval\_\*) at +> the same time. 
+ ## Basic Arithmetic ```py diff --git a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md index e917582c5804ab..8e88a68e63f9c4 100644 --- a/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md +++ b/crates/ty_python_semantic/resources/mdtest/loops/while_loop.md @@ -247,21 +247,45 @@ Here the loop condition forces `x` to be `False` at loop exit, because there is def random() -> bool: return True -x = random() -reveal_type(x) # revealed: bool -while x: - pass -reveal_type(x) # revealed: Literal[False] +def _(x: bool): + while x: + pass + reveal_type(x) # revealed: Literal[False] ``` However, we can't narrow `x` like this when there's a `break` in the loop: ```py -x = random() -while x: - if random(): +def _(x: bool): + while x: + if random(): + break + reveal_type(x) # revealed: bool + +def _(x: bool): + while x: + pass + reveal_type(x) # revealed: Literal[False] + + x = random() + while x: + if random(): + break + reveal_type(x) # revealed: bool + +def _(y: int | None): + x = 1 + while True: + if x == 0: + break + + if y is None: + y = 0 + continue + break -reveal_type(x) # revealed: bool + + reveal_type(y) # revealed: int ``` ### Non-static loop conditions diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/match.md b/crates/ty_python_semantic/resources/mdtest/narrow/match.md index d18d48b2b745cc..34468027be3eb5 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/match.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/match.md @@ -254,7 +254,7 @@ def _(x: Literal["foo", b"bar"] | int): pass case b"bar" if reveal_type(x): # revealed: Literal[b"bar"] | int pass - case _ if reveal_type(x): # revealed: int | Literal["foo", b"bar"] + case _ if reveal_type(x): # revealed: Literal["foo", b"bar"] | int pass ``` diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md 
b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md index 76d96d746baf10..be84d381e9223c 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/post_if_statement.md @@ -180,27 +180,56 @@ def _(x: int | None): ``` ```py +from typing import Final + def _(x: int | None): if 1 + 1 == 2: if x is None: return reveal_type(x) # revealed: int - # TODO: should be `int` (the else-branch of `1 + 1 == 2` is unreachable) - reveal_type(x) # revealed: int | None + reveal_type(x) # revealed: int + +# non-constant but always-true condition +needs_inference: Final = True + +def _(x: int | None): + if needs_inference: + if x is None: + return + reveal_type(x) # revealed: int + + reveal_type(x) # revealed: int ``` This also works when the always-true condition is nested inside a narrowing branch: ```py +from typing import Literal + def _(x: int | None): if x is None: if 1 + 1 == 2: return - # TODO: should be `int` (the inner always-true branch makes the outer - # if-branch terminal) - reveal_type(x) # revealed: int | None + reveal_type(x) # revealed: int + +def _(x: int | None): + if x is None: + if needs_inference: + return + + reveal_type(x) # revealed: int + +def always_true(val: object) -> Literal[True]: + return True + +def _(x: int | None): + if x is None: + if always_true(x): + return + + reveal_type(x) # revealed: int ``` ## Narrowing from `assert` should not affect reassigned variables diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md index ff0c06e55ff008..5d52cb47daae46 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/truthiness.md @@ -31,14 +31,14 @@ else: reveal_type(x) # revealed: Never if x or not x: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + 
reveal_type(x) # revealed: Literal[-1, 0, "foo", "", b"bar", b""] | bool | None | tuple[()] else: reveal_type(x) # revealed: Never if not (x or not x): reveal_type(x) # revealed: Never else: - reveal_type(x) # revealed: Literal[0, -1, "", "foo", b"", b"bar"] | bool | None | tuple[()] + reveal_type(x) # revealed: Literal[-1, 0, "foo", "", b"bar", b""] | bool | None | tuple[()] if (isinstance(x, int) or isinstance(x, str)) and x: reveal_type(x) # revealed: Literal[-1, True, "foo"] diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index e49d3188104ec8..d64bde0cef810f 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -925,6 +925,50 @@ mod tests { .collect() } + /// A function to test how the constant evaluator of `SemanticIndexBuilder` evaluates an expression + /// (the evaluation should match that of `TypeInferenceBuilder`). + /// For example, for the input `x = 1\nif cond: x = 2\nx`, if `cond` evaluates to `AlwaysTrue`, it returns `vec![2]`, + /// if it evaluates to `AlwaysFalse`, it returns `vec![1]`, ​​if it evaluates to `Ambiguous`, it returns `vec![1, 2]`. 
+ fn reachable_bindings_for_terminal_use(content: &str) -> Vec { + let TestCase { db, file } = test_case(content); + let scope = global_scope(&db, file); + let module = parsed_module(&db, file).load(&db); + let ast = module.syntax(); + + let terminal_expr = ast + .body + .last() + .and_then(ast::Stmt::as_expr_stmt) + .map(|stmt| stmt.value.as_ref()) + .expect("expected terminal expression statement"); + let terminal_name = terminal_expr + .as_name_expr() + .expect("terminal expression should be a name"); + + let use_id = terminal_name.scoped_use_id(&db, scope); + let use_def = use_def_map(&db, scope); + + use_def + .bindings_at_use(use_id) + .filter_map(|binding_with_constraints| { + let definition = binding_with_constraints.binding.definition()?; + let DefinitionKind::Assignment(assignment) = definition.kind(&db) else { + return None; + }; + + let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(value), + .. + }) = assignment.value(&module) + else { + return None; + }; + + value.as_i64() + }) + .collect::>() + } + #[test] fn empty() { let TestCase { db, file } = test_case(""); @@ -1590,6 +1634,71 @@ class C[T]: assert_eq!(*num, 1); } + #[test] + fn const_eval_lshift_overflow_is_ambiguous() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 1 << 63: + x = 2 +x +", + ); + assert_eq!(values, vec![1, 2]); + } + + #[test] + fn const_eval_lshift_zero_short_circuit() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 0 << 4000000000000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![1]); + } + + #[test] + fn const_eval_rshift_large_positive() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 1 >> 5000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![1]); + } + + #[test] + fn const_eval_rshift_large_negative_operand() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if (-1) >> 5000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![2]); + } + + #[test] + fn 
const_eval_negative_lshift_is_ambiguous() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 42 << -3: + x = 2 +x +", + ); + assert_eq!(values, vec![1, 2]); + } + #[test] fn expression_scope() { let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4"); diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 76cb3c3df1ef36..2c081293182837 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -908,27 +908,253 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> { + /// Returns if the expression is a `TYPE_CHECKING` expression. + fn is_if_type_checking(expr: &ast::Expr) -> bool { + fn is_dotted_name(expr: &ast::Expr) -> bool { + match expr { + ast::Expr::Name(_) => true, + ast::Expr::Attribute(ast::ExprAttribute { value, .. }) => is_dotted_name(value), + _ => false, + } + } + + match expr { + ast::Expr::Name(ast::ExprName { id, .. }) => id == "TYPE_CHECKING", + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { + attr == "TYPE_CHECKING" && is_dotted_name(value) + } + _ => false, + } + } + // Some commonly used test expressions are eagerly evaluated as `true` // or `false` here for performance reasons. This list does not need to // be exhaustive. More complex expressions will still evaluate to the // correct value during type-checking. fn resolve_to_literal(node: &ast::Expr) -> Option { - match node { - ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => Some(*value), - ast::Expr::Name(ast::ExprName { id, .. }) if id == "TYPE_CHECKING" => Some(true), - ast::Expr::NumberLiteral(ast::ExprNumberLiteral { - value: ast::Number::Int(n), - .. 
- }) => Some(*n != 0), - ast::Expr::EllipsisLiteral(_) => Some(true), - ast::Expr::NoneLiteral(_) => Some(false), - ast::Expr::UnaryOp(ast::ExprUnaryOp { - op: ast::UnaryOp::Not, - operand, - .. - }) => Some(!resolve_to_literal(operand)?), - _ => None, + #[derive(Copy, Clone)] + enum ConstExpr { + Bool(bool), + Int(i64), + None, + Ellipsis, + } + + impl ConstExpr { + fn truthiness(self) -> bool { + match self { + ConstExpr::Bool(value) => value, + ConstExpr::Int(value) => value != 0, + ConstExpr::None => false, + ConstExpr::Ellipsis => true, + } + } + + fn as_int(self) -> Option { + match self { + ConstExpr::Int(value) => Some(value), + ConstExpr::Bool(value) => Some(i64::from(value)), + _ => None, + } + } + } + + fn resolve_const_expr(node: &ast::Expr) -> Option { + match node { + ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => { + Some(ConstExpr::Bool(*value)) + } + ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(n), + .. + }) => n.as_i64().map(ConstExpr::Int), + ast::Expr::EllipsisLiteral(_) => Some(ConstExpr::Ellipsis), + ast::Expr::NoneLiteral(_) => Some(ConstExpr::None), + // See also: `TypeInferenceBuilder::infer_unary_expression_type` + ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. }) => { + let operand = resolve_const_expr(operand)?; + match op { + ast::UnaryOp::Not => Some(ConstExpr::Bool(!operand.truthiness())), + ast::UnaryOp::UAdd => Some(ConstExpr::Int(operand.as_int()?)), + ast::UnaryOp::USub => { + Some(ConstExpr::Int(operand.as_int()?.checked_neg()?)) + } + ast::UnaryOp::Invert => Some(ConstExpr::Int(!operand.as_int()?)), + } + } + // See also: `TypeInferenceBuilder::infer_binary_expression_type` + ast::Expr::BinOp(ast::ExprBinOp { + left, op, right, .. 
+ }) => { + let left = resolve_const_expr(left)?.as_int()?; + let right = resolve_const_expr(right)?.as_int()?; + let value = match op { + ast::Operator::Add => left.checked_add(right)?, + ast::Operator::Sub => left.checked_sub(right)?, + ast::Operator::Mult => left.checked_mul(right)?, + ast::Operator::FloorDiv => { + let mut q = left.checked_div(right); + let r = left.checked_rem(right); + // Division works differently in Python than in Rust. If the + // result is negative and there is a remainder, floor division + // rounds down (instead of toward zero). + if left.is_negative() != right.is_negative() && r.unwrap_or(0) != 0 + { + q = q.map(|q| q - 1); + } + q? + } + ast::Operator::Mod => { + let mut r = left.checked_rem(right); + // Python's modulo keeps the sign of the divisor. Adjust the Rust + // remainder accordingly so that `q * right + r == left`. + if left.is_negative() != right.is_negative() && r.unwrap_or(0) != 0 + { + r = r.map(|x| x + right); + } + r? + } + ast::Operator::BitAnd => left & right, + ast::Operator::BitOr => left | right, + ast::Operator::BitXor => left ^ right, + ast::Operator::LShift => { + if left == 0 && right >= 0 { + 0 + } else { + // An additional overflow check beyond `checked_shl` is + // necessary here, because `checked_shl` only rejects shift + // amounts >= 64; it does not detect when significant bits + // are shifted into (or past) the sign bit. + // + // We compute the "headroom": the number of redundant + // sign-extension bits minus one (for the sign bit itself). + // A shift is safe iff `shift <= headroom`. + let headroom = if left >= 0 { + left.leading_zeros().saturating_sub(1) + } else { + left.leading_ones().saturating_sub(1) + }; + u32::try_from(right) + .ok() + .filter(|&shift| shift <= headroom) + .and_then(|shift| left.checked_shl(shift))? 
+ } + } + ast::Operator::RShift => match u32::try_from(right) { + Ok(shift) => left >> shift.clamp(0, 63), + Err(_) if right > 0 => { + if left >= 0 { + 0 + } else { + -1 + } + } + Err(_) => return None, + }, + ast::Operator::Pow => { + let exp = u32::try_from(right).ok()?; + left.checked_pow(exp)? + } + ast::Operator::Div | ast::Operator::MatMult => return None, + }; + Some(ConstExpr::Int(value)) + } + ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) => { + let value = match op { + ast::BoolOp::And => { + let mut all_true = true; + for expr in values { + if !resolve_const_expr(expr)?.truthiness() { + all_true = false; + break; + } + } + all_true + } + ast::BoolOp::Or => { + let mut any_true = false; + for expr in values { + if resolve_const_expr(expr)?.truthiness() { + any_true = true; + break; + } + } + any_true + } + }; + Some(ConstExpr::Bool(value)) + } + ast::Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + .. + }) => { + let mut left_value = resolve_const_expr(left)?; + for (op, comparator) in ops.iter().zip(comparators.iter()) { + let right_value = resolve_const_expr(comparator)?; + let eq = |left: ConstExpr, right: ConstExpr| match (left, right) { + (ConstExpr::Int(left), ConstExpr::Int(right)) => { + Some(left == right) + } + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) => Some(true), + (ConstExpr::None | ConstExpr::Ellipsis, _) + | (_, ConstExpr::None | ConstExpr::Ellipsis) => Some(false), + _ => None, + }; + let result = match op { + ast::CmpOp::Eq => eq(left_value, right_value)?, + ast::CmpOp::NotEq => !eq(left_value, right_value)?, + ast::CmpOp::Lt => left_value.as_int()? < right_value.as_int()?, + ast::CmpOp::LtE => left_value.as_int()? <= right_value.as_int()?, + ast::CmpOp::Gt => left_value.as_int()? > right_value.as_int()?, + ast::CmpOp::GtE => left_value.as_int()? 
>= right_value.as_int()?, + ast::CmpOp::Is => match (left_value, right_value) { + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) + | (ConstExpr::Bool(true), ConstExpr::Bool(true)) + | (ConstExpr::Bool(false), ConstExpr::Bool(false)) => true, + ( + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + _, + ) + | ( + _, + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + ) => false, + _ => return None, + }, + ast::CmpOp::IsNot => match (left_value, right_value) { + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) + | (ConstExpr::Bool(true), ConstExpr::Bool(true)) + | (ConstExpr::Bool(false), ConstExpr::Bool(false)) => false, + ( + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + _, + ) + | ( + _, + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + ) => true, + _ => return None, + }, + ast::CmpOp::In | ast::CmpOp::NotIn => return None, + }; + if !result { + return Some(ConstExpr::Bool(false)); + } + left_value = right_value; + } + Some(ConstExpr::Bool(true)) + } + _ if is_if_type_checking(node) => Some(ConstExpr::Bool(true)), + _ => None, + } } + + Some(resolve_const_expr(node)?.truthiness()) } let expression = self.add_standalone_expression(predicate_node); @@ -1955,14 +2181,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { if let Some(msg) = msg { let post_test = self.flow_snapshot(); let negated_predicate = predicate.negated(); - self.record_narrowing_constraint(negated_predicate); - self.record_reachability_constraint(negated_predicate); + let predicate_id = self.record_narrowing_constraint(negated_predicate); + self.record_reachability_constraint_id(predicate_id); self.visit_expr(msg); self.flow_restore(post_test); } - self.record_narrowing_constraint(predicate); - self.record_reachability_constraint(predicate); + let predicate_id = self.record_narrowing_constraint(predicate); + self.record_reachability_constraint_id(predicate_id); } 
ast::Stmt::Assign(node) => { @@ -2080,7 +2306,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { let (mut last_predicate, mut last_narrowing_id) = self.record_expression_narrowing_constraint(&node.test); let mut last_reachability_constraint = - self.record_reachability_constraint(last_predicate); + self.record_reachability_constraint_id(last_narrowing_id); let is_outer_block_in_type_checking = self.in_type_checking_block; @@ -2131,7 +2357,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.record_expression_narrowing_constraint(elif_test); last_reachability_constraint = - self.record_reachability_constraint(last_predicate); + self.record_reachability_constraint_id(last_narrowing_id); } // Determine if this clause is in type checking context @@ -2195,7 +2421,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { // after the loop. let pre_loop = self.flow_snapshot(); let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); - self.record_reachability_constraint(predicate); + self.record_reachability_constraint_id(predicate_id); let outer_loop = self.push_loop(); self.visit_body(body); @@ -2375,36 +2601,25 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { ); previous_pattern = Some(match_pattern_predicate); let reachability_constraint = - self.record_reachability_constraint(match_predicate); + self.record_reachability_constraint_id(match_narrowing_id); let match_success_guard_failure = case.guard.as_ref().map(|guard| { - let guard_expr = self.add_standalone_expression(guard); - // We could also add the guard expression as a reachability constraint, but - // it seems unlikely that both the case predicate as well as the guard are - // statically known conditions, so we currently don't model that. 
- self.record_ambiguous_reachability(); self.visit_expr(guard); let post_guard_eval = self.flow_snapshot(); - let predicate = PredicateOrLiteral::Predicate(Predicate { - node: PredicateNode::Expression(guard_expr), - is_positive: true, - }); - // Add the predicate once, then use TDD-level negation for the failure - // path. This ensures the positive and negative atoms share the same ID. - let guard_predicate_id = self.add_predicate(predicate); - let possibly_narrowed = self.compute_possibly_narrowed_places(&predicate); - self.current_use_def_map_mut() - .record_negated_narrowing_constraint_for_places( - guard_predicate_id, - &possibly_narrowed, - ); - let match_success_guard_failure = self.flow_snapshot(); + let (guard_predicate, guard_predicate_id) = + self.record_expression_narrowing_constraint(guard); + let guard_reachability_constraint = + self.record_reachability_constraint_id(guard_predicate_id); + + let guard_success_state = self.flow_snapshot(); self.flow_restore(post_guard_eval); - self.current_use_def_map_mut() - .record_narrowing_constraint_for_places( - guard_predicate_id, - &possibly_narrowed, - ); + self.record_negated_narrowing_constraint( + guard_predicate, + guard_predicate_id, + ); + self.record_negated_reachability_constraint(guard_reachability_constraint); + let match_success_guard_failure = self.flow_snapshot(); + self.flow_restore(guard_success_state); match_success_guard_failure }); @@ -2963,7 +3178,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.visit_expr(test); let pre_if = self.flow_snapshot(); let (predicate, predicate_id) = self.record_expression_narrowing_constraint(test); - let reachability_constraint = self.record_reachability_constraint(predicate); + let reachability_constraint = self.record_reachability_constraint_id(predicate_id); self.visit_expr(body); let post_body = self.flow_snapshot(); self.flow_restore(pre_if); diff --git a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs 
b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs index 647bb088d15826..26ee3a291e6acb 100644 --- a/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs +++ b/crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs @@ -202,12 +202,16 @@ use crate::Db; use crate::dunder_all::dunder_all_names; use crate::place::{RequiresExplicitReExport, imported_symbol}; use crate::rank::RankBitBox; +use crate::semantic_index::narrowing_constraints::ScopedNarrowingConstraint; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::place_table; use crate::semantic_index::predicate::{ CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId, }; +use crate::semantic_index::use_def::{ + PlaceVersion, PredicatePlaceVersionInfo, PredicatePlaceVersions, +}; use crate::types::{ CallableTypes, IntersectionBuilder, NarrowingConstraint, Truthiness, Type, TypeContext, UnionBuilder, UnionType, infer_expression_type, infer_narrowing_constraint, @@ -761,7 +765,7 @@ fn accumulate_constraint<'db>( new: Option>, ) -> Option> { match (accumulated, new) { - (Some(acc), Some(new_c)) => Some(new_c.merge_constraint_and(acc, db)), + (Some(acc), Some(new_c)) => Some(new_c.merge_constraint_and(db, acc)), (None, Some(new_c)) => Some(new_c), (Some(acc), None) => Some(acc), (None, None) => None, @@ -802,34 +806,73 @@ impl ReachabilityConstraints { /// - `ALWAYS_FALSE`: this path is impossible → Never /// /// The final result is the union of all path results. 
+ #[expect(clippy::too_many_arguments)] pub(crate) fn narrow_by_constraint<'db>( &self, db: &'db dyn Db, predicates: &Predicates<'db>, - id: ScopedReachabilityConstraintId, + predicate_place_versions: &PredicatePlaceVersions, + id: ScopedNarrowingConstraint, base_ty: Type<'db>, place: ScopedPlaceId, + binding_place_version: Option, ) -> Type<'db> { - self.narrow_by_constraint_inner(db, predicates, id, base_ty, place, None) + let mut memo = FxHashMap::default(); + let mut truthiness_memo = FxHashMap::default(); + let redundant_union = self.narrow_by_constraint_inner( + db, + predicates, + predicate_place_versions, + id, + base_ty, + place, + binding_place_version, + None, + &mut memo, + &mut truthiness_memo, + ); + UnionBuilder::new(db) + .unpack_aliases(false) + .add(redundant_union) + .build() } /// Inner recursive helper that accumulates narrowing constraints along each TDD path. + #[allow(clippy::too_many_arguments)] fn narrow_by_constraint_inner<'db>( &self, db: &'db dyn Db, predicates: &Predicates<'db>, - id: ScopedReachabilityConstraintId, + predicate_place_versions: &PredicatePlaceVersions, + id: ScopedNarrowingConstraint, base_ty: Type<'db>, place: ScopedPlaceId, + binding_place_version: Option, accumulated: Option>, + memo: &mut FxHashMap< + (ScopedNarrowingConstraint, Option>), + Type<'db>, + >, + truthiness_memo: &mut FxHashMap, Truthiness>, ) -> Type<'db> { - match id { + // `ALWAYS_TRUE` and `AMBIGUOUS` are equivalent for narrowing purposes. + // Canonicalize to improve memo hits across terminal leaves. 
+ let memo_id = match id { + ALWAYS_TRUE | AMBIGUOUS => ALWAYS_TRUE, + _ => id, + }; + let key = (memo_id, accumulated); + if let Some(cached) = memo.get(&key) { + return *cached; + } + + let narrowed = match id { ALWAYS_TRUE | AMBIGUOUS => { // Apply all accumulated narrowing constraints to the base type match accumulated { - Some(constraint) => NarrowingConstraint::intersection(base_ty) - .merge_constraint_and(constraint, db) - .evaluate_constraint_type(db), + Some(constraint) => NarrowingConstraint::intersection(db, base_ty) + .merge_constraint_and(db, constraint) + .evaluate_constraint_type(db, false), None => base_ty, } } @@ -837,101 +880,116 @@ impl ReachabilityConstraints { _ => { let node = self.get_interior_node(id); let predicate = predicates[node.atom]; - - // `ReturnsNever` predicates don't narrow any variable; they only - // affect reachability. Evaluate the predicate to determine which - // path(s) are reachable, rather than walking both branches. - // `ReturnsNever` always evaluates to `AlwaysTrue` or `AlwaysFalse`, - // never `Ambiguous`. - if matches!(predicate.node, PredicateNode::ReturnsNever(_)) { - return match Self::analyze_single(db, &predicate) { - Truthiness::AlwaysTrue => self.narrow_by_constraint_inner( + macro_rules! narrow { + ($next_id:expr, $next_accumulated:expr) => { + self.narrow_by_constraint_inner( db, predicates, - node.if_true, + predicate_place_versions, + $next_id, base_ty, place, - accumulated, - ), - Truthiness::AlwaysFalse => self.narrow_by_constraint_inner( - db, - predicates, - node.if_false, - base_ty, - place, - accumulated, - ), - Truthiness::Ambiguous => { - unreachable!("ReturnsNever predicates should never be Ambiguous") - } + binding_place_version, + $next_accumulated, + memo, + truthiness_memo, + ) }; } // Check if this predicate narrows the variable we're interested in. 
- let pos_constraint = infer_narrowing_constraint(db, predicate, place); + let neg_predicate = Predicate { + node: predicate.node, + is_positive: !predicate.is_positive, + }; + let place_version_info = predicate_place_versions.get(&(node.atom, place)); + let can_apply_narrowing = place_version_info.is_some() + && Self::predicate_applies_to_place_version( + place_version_info, + binding_place_version, + ); + let (pos_constraint, neg_constraint) = if can_apply_narrowing { + ( + infer_narrowing_constraint(db, predicate, place), + infer_narrowing_constraint(db, neg_predicate, place), + ) + } else { + // No recorded place-version metadata means this predicate cannot narrow + // this place, or the narrowing belongs to a different place version. + // In either case, skip the expensive narrowing-inference queries. + (None, None) + }; + + // If this predicate does not narrow the current place and we can statically + // determine its truthiness, follow only the reachable branch. + if pos_constraint.is_none() && neg_constraint.is_none() { + match Self::analyze_single_cached(db, predicate, truthiness_memo) { + Truthiness::AlwaysTrue => { + let narrowed = narrow!(node.if_true, accumulated); + memo.insert(key, narrowed); + return narrowed; + } + Truthiness::AlwaysFalse => { + let narrowed = narrow!(node.if_false, accumulated); + memo.insert(key, narrowed); + return narrowed; + } + Truthiness::Ambiguous => {} + } + } // If the true branch is statically unreachable, skip it entirely. 
if node.if_true == ALWAYS_FALSE { - let neg_predicate = Predicate { - node: predicate.node, - is_positive: !predicate.is_positive, - }; - let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place); let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint); - return self.narrow_by_constraint_inner( - db, - predicates, - node.if_false, - base_ty, - place, - false_accumulated, - ); + let narrowed = narrow!(node.if_false, false_accumulated); + memo.insert(key, narrowed); + return narrowed; } // If the false branch is statically unreachable, skip it entirely. if node.if_false == ALWAYS_FALSE { let true_accumulated = accumulate_constraint(db, accumulated, pos_constraint); - return self.narrow_by_constraint_inner( - db, - predicates, - node.if_true, - base_ty, - place, - true_accumulated, - ); + let narrowed = narrow!(node.if_true, true_accumulated); + memo.insert(key, narrowed); + return narrowed; } // True branch: predicate holds → accumulate positive narrowing - let true_accumulated = - accumulate_constraint(db, accumulated.clone(), pos_constraint); - let true_ty = self.narrow_by_constraint_inner( - db, - predicates, - node.if_true, - base_ty, - place, - true_accumulated, - ); + let true_accumulated = accumulate_constraint(db, accumulated, pos_constraint); + let true_ty = narrow!(node.if_true, true_accumulated); + + // Narrowing can only produce subtypes of `base_ty`, so + // if one branch already returns `base_ty`, skip the other. 
+ if true_ty == base_ty { + memo.insert(key, base_ty); + return base_ty; + } // False branch: predicate doesn't hold → accumulate negative narrowing - let neg_predicate = Predicate { - node: predicate.node, - is_positive: !predicate.is_positive, - }; - let neg_constraint = infer_narrowing_constraint(db, neg_predicate, place); let false_accumulated = accumulate_constraint(db, accumulated, neg_constraint); - let false_ty = self.narrow_by_constraint_inner( - db, - predicates, - node.if_false, - base_ty, - place, - false_accumulated, - ); - - UnionType::from_elements(db, [true_ty, false_ty]) + let false_ty = narrow!(node.if_false, false_accumulated); + + if false_ty == base_ty { + memo.insert(key, base_ty); + return base_ty; + } + + // We won't do a union type redundancy check here, as it only needs to be performed once for the final result. + UnionType::from_elements_no_redundancy_check(db, [true_ty, false_ty]) } - } + }; + + memo.insert(key, narrowed); + narrowed + } + + fn predicate_applies_to_place_version( + place_version_info: Option<&PredicatePlaceVersionInfo>, + binding_place_version: Option, + ) -> bool { + binding_place_version.is_none_or(|binding_place_version| { + place_version_info.is_some_and(|info| info.versions.contains(&binding_place_version)) + }) } /// Analyze the statically known reachability for a given constraint. 
@@ -1168,4 +1226,18 @@ impl ReachabilityConstraints { } } } + + fn analyze_single_cached<'db>( + db: &'db dyn Db, + predicate: Predicate<'db>, + memo: &mut FxHashMap, Truthiness>, + ) -> Truthiness { + if let Some(cached) = memo.get(&predicate) { + return *cached; + } + + let analyzed = Self::analyze_single(db, &predicate); + memo.insert(predicate, analyzed); + analyzed + } } diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index c0597e87a8e51b..36ecebfbe3de74 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -242,6 +242,7 @@ use ruff_index::{IndexVec, newtype_index}; use rustc_hash::FxHashMap; +use smallvec::SmallVec; use crate::node_key::NodeKey; use crate::place::BoundnessAnalysis; @@ -268,7 +269,15 @@ use crate::types::{PossiblyNarrowedPlaces, Truthiness, Type}; mod place_state; pub(super) use place_state::PreviousDefinitions; -pub(crate) use place_state::{LiveBinding, ScopedDefinitionId}; +pub(crate) use place_state::{LiveBinding, PlaceVersion, ScopedDefinitionId}; + +#[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] +pub(crate) struct PredicatePlaceVersionInfo { + pub(crate) versions: SmallVec<[PlaceVersion; 2]>, +} + +pub(crate) type PredicatePlaceVersions = + FxHashMap<(ScopedPredicateId, ScopedPlaceId), PredicatePlaceVersionInfo>; /// Applicable definitions and constraints for every use of a name. #[derive(Debug, PartialEq, Eq, salsa::Update, get_size2::GetSize)] @@ -280,6 +289,15 @@ pub(crate) struct UseDefMap<'db> { /// Array of predicates in this scope. predicates: Predicates<'db>, + /// Place version associated with each definition ID. + /// + /// This stores the version once per definition instead of duplicating it in every `LiveBinding` + /// clone across `bindings_by_use` / snapshots. 
+ definition_place_versions: IndexVec, + + /// Place versions to which a given predicate occurrence can apply for narrowing. + predicate_place_versions: PredicatePlaceVersions, + /// Array of reachability constraints in this scope. reachability_constraints: ReachabilityConstraints, @@ -373,7 +391,9 @@ impl<'db> UseDefMap<'db> { ApplicableConstraints::UnboundBinding(NarrowingEvaluator { constraint, predicates: &self.predicates, + predicate_place_versions: &self.predicate_place_versions, reachability_constraints: &self.reachability_constraints, + binding_place_version: None, }) } ConstraintKey::NestedScope(nested_scope) => { @@ -413,7 +433,9 @@ impl<'db> UseDefMap<'db> { NarrowingEvaluator { constraint, predicates: &self.predicates, + predicate_place_versions: &self.predicate_place_versions, reachability_constraints: &self.reachability_constraints, + binding_place_version: None, } } @@ -654,7 +676,9 @@ impl<'db> UseDefMap<'db> { ) -> BindingWithConstraintsIterator<'map, 'db> { BindingWithConstraintsIterator { all_definitions: &self.all_definitions, + definition_place_versions: &self.definition_place_versions, predicates: &self.predicates, + predicate_place_versions: &self.predicate_place_versions, reachability_constraints: &self.reachability_constraints, boundness_analysis, inner: bindings.iter(), @@ -709,7 +733,9 @@ type EnclosingSnapshots = IndexVec #[derive(Clone, Debug)] pub(crate) struct BindingWithConstraintsIterator<'map, 'db> { pub(crate) all_definitions: &'map IndexVec>, + definition_place_versions: &'map IndexVec, pub(crate) predicates: &'map Predicates<'db>, + pub(crate) predicate_place_versions: &'map PredicatePlaceVersions, pub(crate) reachability_constraints: &'map ReachabilityConstraints, pub(crate) boundness_analysis: BoundnessAnalysis, inner: LiveBindingsIterator<'map>, @@ -720,6 +746,7 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { fn next(&mut self) -> Option { let predicates = self.predicates; + let 
predicate_place_versions = self.predicate_place_versions; let reachability_constraints = self.reachability_constraints; self.inner @@ -729,7 +756,11 @@ impl<'map, 'db> Iterator for BindingWithConstraintsIterator<'map, 'db> { narrowing_constraint: NarrowingEvaluator { constraint: live_binding.narrowing_constraint, predicates, + predicate_place_versions, reachability_constraints, + binding_place_version: Some( + self.definition_place_versions[live_binding.binding], + ), }, reachability_constraint: live_binding.reachability_constraint, }) @@ -747,7 +778,9 @@ pub(crate) struct BindingWithConstraints<'map, 'db> { pub(crate) struct NarrowingEvaluator<'map, 'db> { pub(crate) constraint: ScopedNarrowingConstraint, predicates: &'map Predicates<'db>, + predicate_place_versions: &'map PredicatePlaceVersions, reachability_constraints: &'map ReachabilityConstraints, + binding_place_version: Option, } impl<'db> NarrowingEvaluator<'_, 'db> { @@ -760,9 +793,11 @@ impl<'db> NarrowingEvaluator<'_, 'db> { self.reachability_constraints.narrow_by_constraint( db, self.predicates, + self.predicate_place_versions, self.constraint, base_ty, place, + self.binding_place_version, ) } } @@ -828,9 +863,15 @@ pub(super) struct UseDefMapBuilder<'db> { /// Append-only array of [`DefinitionState`]. all_definitions: IndexVec>, + /// Place version associated with each definition ID. + definition_place_versions: IndexVec, + /// Builder of predicates. pub(super) predicates: PredicatesBuilder<'db>, + /// Place versions to which a given predicate occurrence can apply for narrowing. + predicate_place_versions: PredicatePlaceVersions, + /// Builder of reachability constraints. 
pub(super) reachability_constraints: ReachabilityConstraintsBuilder, @@ -872,7 +913,9 @@ impl<'db> UseDefMapBuilder<'db> { pub(super) fn new(is_class_scope: bool) -> Self { Self { all_definitions: IndexVec::from_iter([DefinitionState::Undefined]), + definition_place_versions: IndexVec::from_iter([PlaceVersion::default()]), predicates: PredicatesBuilder::default(), + predicate_place_versions: PredicatePlaceVersions::default(), reachability_constraints: ReachabilityConstraintsBuilder::default(), bindings_by_use: IndexVec::new(), reachability: ScopedReachabilityConstraintId::ALWAYS_TRUE, @@ -959,13 +1002,15 @@ impl<'db> UseDefMapBuilder<'db> { self.declarations_by_binding .insert(binding, place_state.declarations().clone()); - place_state.record_binding( + let place_version = place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), previous_definitions, ); + let version_id = self.definition_place_versions.push(place_version); + debug_assert_eq!(def_id, version_id); let bindings = match place { ScopedPlaceId::Symbol(symbol) => { @@ -1009,6 +1054,8 @@ impl<'db> UseDefMapBuilder<'db> { return; } + self.record_predicate_place_versions(predicate, places); + let atom = self.reachability_constraints.add_atom(predicate); self.record_narrowing_constraint_node_for_places(atom, places); } @@ -1030,11 +1077,46 @@ impl<'db> UseDefMapBuilder<'db> { return; } + self.record_predicate_place_versions(predicate, places); + let atom = self.reachability_constraints.add_atom(predicate); let negated = self.reachability_constraints.add_not_constraint(atom); self.record_narrowing_constraint_node_for_places(negated, places); } + fn record_predicate_place_versions( + &mut self, + predicate: ScopedPredicateId, + places: &PossiblyNarrowedPlaces, + ) { + for place in places { + let bindings = match place { + ScopedPlaceId::Symbol(symbol_id) => { + self.symbol_states.get(*symbol_id).map(PlaceState::bindings) + } + ScopedPlaceId::Member(member_id) => { + 
self.member_states.get(*member_id).map(PlaceState::bindings) + } + }; + let Some(bindings) = bindings else { + continue; + }; + + let versions = bindings + .iter() + .map(|binding| self.definition_place_versions[binding.binding]); + let entry = self + .predicate_place_versions + .entry((predicate, *place)) + .or_default(); + for version in versions { + if !entry.versions.contains(&version) { + entry.versions.push(version); + } + } + } + } + /// Records a TDD narrowing constraint node for the specified places. fn record_narrowing_constraint_node_for_places( &mut self, @@ -1202,6 +1284,8 @@ impl<'db> UseDefMapBuilder<'db> { let def_id = self .all_definitions .push(DefinitionState::Defined(declaration)); + let version_id = self.definition_place_versions.push(PlaceVersion::default()); + debug_assert_eq!(def_id, version_id); let place_state = match place { ScopedPlaceId::Symbol(symbol) => &mut self.symbol_states[symbol], @@ -1239,13 +1323,15 @@ impl<'db> UseDefMapBuilder<'db> { ScopedPlaceId::Member(member) => &mut self.member_states[member], }; place_state.record_declaration(def_id, self.reachability); - place_state.record_binding( + let place_version = place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), PreviousDefinitions::AreShadowed, ); + let version_id = self.definition_place_versions.push(place_version); + debug_assert_eq!(def_id, version_id); let reachable_definitions = match place { ScopedPlaceId::Symbol(symbol) => &mut self.reachable_symbol_definitions[symbol], @@ -1272,14 +1358,15 @@ impl<'db> UseDefMapBuilder<'db> { ScopedPlaceId::Symbol(symbol) => &mut self.symbol_states[symbol], ScopedPlaceId::Member(member) => &mut self.member_states[member], }; - - place_state.record_binding( + let place_version = place_state.record_binding( def_id, self.reachability, self.is_class_scope, place.is_symbol(), PreviousDefinitions::AreShadowed, ); + let version_id = self.definition_place_versions.push(place_version); + 
debug_assert_eq!(def_id, version_id); } pub(super) fn record_use( @@ -1504,11 +1591,13 @@ impl<'db> UseDefMapBuilder<'db> { self.mark_reachability_constraints(); self.all_definitions.shrink_to_fit(); + self.definition_place_versions.shrink_to_fit(); self.symbol_states.shrink_to_fit(); self.member_states.shrink_to_fit(); self.reachable_symbol_definitions.shrink_to_fit(); self.reachable_member_definitions.shrink_to_fit(); self.bindings_by_use.shrink_to_fit(); + self.predicate_place_versions.shrink_to_fit(); self.node_reachability.shrink_to_fit(); self.declarations_by_binding.shrink_to_fit(); self.bindings_by_definition.shrink_to_fit(); @@ -1517,6 +1606,8 @@ impl<'db> UseDefMapBuilder<'db> { UseDefMap { all_definitions: self.all_definitions, predicates: self.predicates.build(), + definition_place_versions: self.definition_place_versions, + predicate_place_versions: self.predicate_place_versions, reachability_constraints: self.reachability_constraints.build(), bindings_by_use: self.bindings_by_use, node_reachability: self.node_reachability, diff --git a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs index 71833f34063979..e34ba4a5ce2350 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def/place_state.rs @@ -69,6 +69,29 @@ impl ScopedDefinitionId { } } +/// A monotonically increasing place generation. +/// +/// The generation increments whenever bindings for a place are shadowed by reassignment. 
+#[newtype_index] +#[derive(Ord, PartialOrd, salsa::Update, get_size2::GetSize)] +pub(crate) struct PlaceVersion; + +impl Default for PlaceVersion { + fn default() -> Self { + PlaceVersion::from_u32(0) + } +} + +impl PlaceVersion { + pub(crate) fn next(self) -> PlaceVersion { + let next = self + .as_u32() + .checked_add(1) + .expect("PlaceVersion overflowed"); + PlaceVersion::from_u32(next) + } +} + /// Live declarations for a single place at some point in control flow, with their /// corresponding reachability constraints. #[derive(Clone, Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] @@ -213,7 +236,10 @@ pub(super) struct Bindings { /// "unbound" binding. unbound_narrowing_constraint: Option, /// A list of live bindings for this place, sorted by their `ScopedDefinitionId` + #[allow(clippy::struct_field_names)] live_bindings: SmallVec<[LiveBinding; 2]>, + /// Latest place version seen for this place. + latest_place_version: PlaceVersion, } impl Bindings { @@ -251,6 +277,7 @@ impl Bindings { Self { unbound_narrowing_constraint: None, live_bindings: smallvec![initial_binding], + latest_place_version: PlaceVersion::default(), } } @@ -262,7 +289,7 @@ impl Bindings { is_class_scope: bool, is_place_name: bool, previous_definitions: PreviousDefinitions, - ) { + ) -> PlaceVersion { // If we are in a class scope, and the unbound name binding was previously visible, but we will // now replace it, record the narrowing constraints on it: if is_class_scope && is_place_name && self.live_bindings[0].binding.is_unbound() { @@ -272,12 +299,14 @@ impl Bindings { // constraints. if previous_definitions.are_shadowed() { self.live_bindings.clear(); + self.latest_place_version = self.latest_place_version.next(); } self.live_bindings.push(LiveBinding { binding, narrowing_constraint: ScopedNarrowingConstraint::ALWAYS_TRUE, reachability_constraint, }); + self.latest_place_version } /// Add given constraint to all live bindings. 
@@ -315,6 +344,7 @@ impl Bindings { reachability_constraints: &mut ReachabilityConstraintsBuilder, ) { let a = std::mem::take(self); + self.latest_place_version = a.latest_place_version.max(b.latest_place_version); if let Some((a, b)) = a .unbound_narrowing_constraint @@ -334,15 +364,29 @@ impl Bindings { for zipped in a.merge_join_by(b, |a, b| a.binding.cmp(&b.binding)) { match zipped { EitherOrBoth::Both(a, b) => { - // If the same definition is visible through both paths, we OR the narrowing - // constraints: the type should be narrowed by whichever path was taken. - let narrowing_constraint = reachability_constraints - .add_or_constraint(a.narrowing_constraint, b.narrowing_constraint); - // For reachability constraints, we also merge using a ternary OR operation: let reachability_constraint = reachability_constraints .add_or_constraint(a.reachability_constraint, b.reachability_constraint); + let narrowing_constraint = if a.narrowing_constraint == b.narrowing_constraint { + // short-circuit: if both sides have the same constraint, we can use that constraint without needing to create a new TDD node. + a.narrowing_constraint + } else if a.reachability_constraint == b.reachability_constraint { + reachability_constraints + .add_or_constraint(a.narrowing_constraint, b.narrowing_constraint) + } else { + // A branch contributes narrowing only when it is reachable. + // Without this gating, `OR(a_narrowing, b_narrowing)` allows an unreachable + // branch with `ALWAYS_TRUE` narrowing to cancel useful narrowing from the + // reachable branch. 
+ let a_narrowing_gated = reachability_constraints + .add_and_constraint(a.narrowing_constraint, a.reachability_constraint); + let b_narrowing_gated = reachability_constraints + .add_and_constraint(b.narrowing_constraint, b.reachability_constraint); + reachability_constraints + .add_or_constraint(a_narrowing_gated, b_narrowing_gated) + }; + self.live_bindings.push(LiveBinding { binding: a.binding, narrowing_constraint, @@ -381,7 +425,7 @@ impl PlaceState { is_class_scope: bool, is_place_name: bool, previous_definitions: PreviousDefinitions, - ) { + ) -> PlaceVersion { debug_assert_ne!(binding_id, ScopedDefinitionId::UNBOUND); self.bindings.record_binding( binding_id, @@ -389,7 +433,7 @@ impl PlaceState { is_class_scope, is_place_name, previous_definitions, - ); + ) } /// Add given constraint to all live bindings. @@ -636,6 +680,31 @@ mod tests { assert_eq!(bindings[1].1, atom0); assert_eq!(bindings[2].0, 3); assert_eq!(bindings[2].1, atom3); + + // An unreachable branch should not dilute narrowing from the reachable branch. 
+ let mut sym4a = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); + sym4a.record_binding( + ScopedDefinitionId::from_u32(4), + ScopedReachabilityConstraintId::ALWAYS_FALSE, + false, + true, + PreviousDefinitions::AreShadowed, + ); + + let mut sym4b = PlaceState::undefined(ScopedReachabilityConstraintId::ALWAYS_TRUE); + sym4b.record_binding( + ScopedDefinitionId::from_u32(4), + ScopedReachabilityConstraintId::ALWAYS_TRUE, + false, + true, + PreviousDefinitions::AreShadowed, + ); + let atom4 = reachability_constraints.add_atom(ScopedPredicateId::new(4)); + sym4b.record_narrowing_constraint(&mut reachability_constraints, atom4); + + sym4a.merge(sym4b, &mut reachability_constraints); + let merged_constraint = sym4a.bindings().iter().next().unwrap().narrowing_constraint; + assert_eq!(merged_constraint, atom4); } #[test] diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index a4d17e190a1633..e7a71731059ec5 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4399,7 +4399,7 @@ impl<'db> Type<'db> { .with_annotated_type(typevar_meta)]; // Intersect with `Any` for the return type to reflect the fact that the `dataclass()` // decorator adds methods to the class - let returns = IntersectionType::from_elements(db, [typevar_meta, Type::any()]); + let returns = IntersectionType::from_two_elements(db, typevar_meta, Type::any()); let signature = Signature::new_generic(Some(context), Parameters::new(db, parameters), returns); Binding::single(self, signature).into() @@ -6023,12 +6023,10 @@ impl<'db> Type<'db> { // but it appears to be what users often expect, and it improves compatibility with // other type checkers such as mypy. // See conversation in https://github.com/astral-sh/ruff/pull/19915. 
- SpecialFormType::NamedTuple => Ok(IntersectionType::from_elements( + SpecialFormType::NamedTuple => Ok(IntersectionType::from_two_elements( db, - [ - Type::homogeneous_tuple(db, Type::object()), - KnownClass::NamedTupleLike.to_instance(db), - ], + Type::homogeneous_tuple(db, Type::object()), + KnownClass::NamedTupleLike.to_instance(db), )), SpecialFormType::TypingSelf => { let index = semantic_index(db, scope_id.file(db)); @@ -12309,6 +12307,20 @@ impl<'db> UnionType<'db> { .build() } + pub(crate) fn from_elements_no_redundancy_check(db: &'db dyn Db, elements: I) -> Type<'db> + where + I: IntoIterator, + T: Into>, + { + elements + .into_iter() + .fold( + UnionBuilder::new(db).check_redundancy(false), + |builder, element| builder.add(element.into()), + ) + .build() + } + /// Create a union from a list of elements without unpacking type aliases. pub(crate) fn from_elements_leave_aliases(db: &'db dyn Db, elements: I) -> Type<'db> where @@ -12934,6 +12946,7 @@ pub(super) fn walk_intersection_type<'db, V: visitor::TypeVisitor<'db> + ?Sized> } } +#[salsa::tracked] impl<'db> IntersectionType<'db> { pub(crate) fn from_elements(db: &'db dyn Db, elements: I) -> Type<'db> where @@ -12945,6 +12958,19 @@ impl<'db> IntersectionType<'db> { .build() } + #[salsa::tracked( + cycle_initial=|_, id, _, _| Type::divergent(id), + cycle_fn=|db, cycle, previous: &Type<'db>, result: Type<'db>, _, _| { + result.cycle_normalized(db, *previous, cycle) + }, + heap_size=ruff_memory_usage::heap_size + )] + fn from_two_elements(db: &'db dyn Db, a: Type<'db>, b: Type<'db>) -> Type<'db> { + IntersectionBuilder::new(db) + .positive_elements([a, b]) + .build() + } + /// Return a new `IntersectionType` instance with the positive and negative types sorted /// according to a canonical ordering, and other normalizations applied to each element as applicable. 
/// diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 233b75277f4094..9cbbf92aa6ca1a 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -54,6 +54,81 @@ enum LiteralKind<'db> { Enum { enum_class: ClassLiteral<'db> }, } +/// Extract `(core, guard)` from truthiness-guarded intersections. +/// +/// e.g. +/// - `A & ~AlwaysTruthy` -> `Some((A, ~AlwaysTruthy))` +/// - `A & ~AlwaysFalsy` -> `Some((A, ~AlwaysFalsy))` +/// - `A` -> `None` +/// - `A & ~AlwaysTruthy & ~AlwaysFalsy` -> `None` (not a single-guard shape) +/// +/// This only recognizes the "single truthiness guard" forms used by truthiness narrowing. +fn split_truthiness_guarded_intersection<'db>( + db: &'db dyn Db, + ty: Type<'db>, +) -> Option<(Type<'db>, Type<'db>)> { + let Type::Intersection(intersection) = ty else { + return None; + }; + + let has_not_truthy = intersection.negative(db).contains(&Type::AlwaysTruthy); + let has_not_falsy = intersection.negative(db).contains(&Type::AlwaysFalsy); + let guard = match (has_not_truthy, has_not_falsy) { + (true, false) => Type::AlwaysTruthy.negate(db), + (false, true) => Type::AlwaysFalsy.negate(db), + _ => return None, + }; + + let mut core = IntersectionBuilder::new(db); + for positive in intersection.positive(db) { + core = core.add_positive(*positive); + } + for negative in intersection.negative(db) { + if (guard == Type::AlwaysTruthy.negate(db) && *negative == Type::AlwaysTruthy) + || (guard == Type::AlwaysFalsy.negate(db) && *negative == Type::AlwaysFalsy) + { + continue; + } + core = core.add_negative(*negative); + } + Some((core.build(), guard)) +} + +/// Try to merge a complementary guarded pair into an unguarded core. +/// +/// e.g. 
+/// - `(A & ~AlwaysTruthy, A & ~AlwaysFalsy)` -> `Some(A)` +/// - `(A & ~AlwaysTruthy, B & ~AlwaysFalsy)` -> `Some(A | B)` if reconstruction is exact +/// - `(A & ~AlwaysTruthy, C)` -> `None` +/// +/// Safety rule: +/// The candidate merge is accepted only if adding each original guard back reconstructs +/// exactly the original operands (`left` and `right`). +fn merge_truthiness_guarded_pair<'db>( + db: &'db dyn Db, + left: Type<'db>, + right: Type<'db>, +) -> Option> { + let (left_core, left_guard) = split_truthiness_guarded_intersection(db, left)?; + let (right_core, right_guard) = split_truthiness_guarded_intersection(db, right)?; + if left_guard == right_guard { + return None; + } + + if left_core.is_equivalent_to(db, right_core) { + return Some(left_core); + } + + let candidate = UnionType::from_elements(db, [left_core, right_core]); + let left_reconstructed = IntersectionType::from_two_elements(db, candidate, left_guard); + let right_reconstructed = IntersectionType::from_two_elements(db, candidate, right_guard); + if left_reconstructed == left && right_reconstructed == right { + Some(candidate) + } else { + None + } +} + impl<'db> Type<'db> { /// Return `true` if this type can be a supertype of some literals of `kind` and not others. fn splits_literals(self, db: &'db dyn Db, kind: LiteralKind) -> bool { @@ -266,11 +341,13 @@ const MAX_NON_RECURSIVE_UNION_LITERALS: usize = 256; /// if reachability analysis etc. fails when analysing these enums. const MAX_NON_RECURSIVE_UNION_ENUM_LITERALS: usize = 8192; +#[allow(clippy::struct_excessive_bools)] pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, unpack_aliases: bool, order_elements: bool, + check_redundancy: bool, /// This is enabled when joining types in a `cycle_recovery` function. /// Since a cycle cannot be created within a `cycle_recovery` function, /// execution of `is_redundant_with` is skipped. 
@@ -285,6 +362,7 @@ impl<'db> UnionBuilder<'db> { elements: vec![], unpack_aliases: true, order_elements: false, + check_redundancy: true, cycle_recovery: false, recursively_defined: RecursivelyDefined::No, } @@ -300,9 +378,15 @@ impl<'db> UnionBuilder<'db> { self } + pub(crate) fn check_redundancy(mut self, val: bool) -> Self { + self.check_redundancy = val; + self + } + pub(crate) fn cycle_recovery(mut self, val: bool) -> Self { self.cycle_recovery = val; if self.cycle_recovery { + self.check_redundancy = false; self.unpack_aliases = false; } self @@ -658,16 +742,16 @@ impl<'db> UnionBuilder<'db> { } fn push_type(&mut self, ty: Type<'db>, seen_aliases: &mut Vec>) { - let bool_pair = if let Some(LiteralValueTypeKind::Bool(b)) = ty.as_literal_value_kind() { - Some(LiteralValueTypeKind::Bool(!b)) - } else { - None + let mut ty = ty; + let bool_pair = |lit: LiteralValueTypeKind| match lit { + LiteralValueTypeKind::Bool(b) => Some(LiteralValueTypeKind::Bool(!b)), + _ => None, }; // If an alias gets here, it means we aren't unpacking aliases, and we also // shouldn't try to simplify aliases out of the union, because that will require // unpacking them. - let should_simplify_full = !matches!(ty, Type::TypeAlias(_)) && !self.cycle_recovery; + let should_simplify_full = !matches!(ty, Type::TypeAlias(_)) && self.check_redundancy; let mut ty_negated: Option = None; let mut to_remove = SmallVec::<[usize; 2]>::new(); @@ -694,9 +778,16 @@ impl<'db> UnionBuilder<'db> { return; } + // Fold `(T & ~AlwaysTruthy) | (T & ~AlwaysFalsy)` to `T`. 
+ if let Some(merged_type) = merge_truthiness_guarded_pair(self.db, ty, element_type) { + to_remove.push(i); + ty = merged_type; + continue; + } + if element_type .as_literal_value_kind() - .zip(bool_pair) + .zip(ty.as_literal_value_kind().and_then(bool_pair)) .is_some_and(|(element, pair)| element == pair) { self.add_in_place_impl(KnownClass::Bool.to_instance(self.db), seen_aliases); @@ -721,19 +812,27 @@ impl<'db> UnionBuilder<'db> { continue; } - let negated = ty_negated.get_or_insert_with(|| ty.negate(self.db)); - if negated.is_subtype_of(self.db, element_type) { - // We add `ty` to the union. We just checked that `~ty` is a subtype of an - // existing `element`. This also means that `~ty | ty` is a subtype of - // `element | ty`, because both elements in the first union are subtypes of - // the corresponding elements in the second union. But `~ty | ty` is just - // `object`. Since `object` is a subtype of `element | ty`, we can only - // conclude that `element | ty` must be `object` (object has no other - // supertypes). This means we can simplify the whole union to just - // `object`, since all other potential elements would also be subtypes of - // `object`. - self.collapse_to_object(); - return; + // Skip the negate/subtype check for intersection-to-intersection pairs. + // For intersections, ~(A & B & ...) = ~A | ~B | ..., which is a broad union + // of complements. Such a union cannot be a subtype of another intersection + // of class types in practice, making this check always false but expensive. + if !(ty.is_nontrivial_intersection(self.db) + && element_type.is_nontrivial_intersection(self.db)) + { + let negated = ty_negated.get_or_insert_with(|| ty.negate(self.db)); + if negated.is_subtype_of(self.db, element_type) { + // We add `ty` to the union. We just checked that `~ty` is a subtype of an + // existing `element`. 
This also means that `~ty | ty` is a subtype of + // `element | ty`, because both elements in the first union are subtypes of + // the corresponding elements in the second union. But `~ty | ty` is just + // `object`. Since `object` is a subtype of `element | ty`, we can only + // conclude that `element | ty` must be `object` (object has no other + // supertypes). This means we can simplify the whole union to just + // `object`, since all other potential elements would also be subtypes of + // `object`. + self.collapse_to_object(); + return; + } } } } diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index 97f570a50e3cac..162040260d10f1 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -788,7 +788,7 @@ impl<'db> ConstrainedTypeVar<'db> { // (s₁ ≤ α ≤ t₁) ∧ (s₂ ≤ α ≤ t₂) = (s₁ ∪ s₂) ≤ α ≤ (t₁ ∩ t₂)) let lower = UnionType::from_elements(db, [self.lower(db), other.lower(db)]); - let upper = IntersectionType::from_elements(db, [self_upper, other_upper]); + let upper = IntersectionType::from_two_elements(db, self_upper, other_upper); // If `lower ≰ upper`, then the intersection is empty, since there is no type that is both // greater than `lower`, and less than `upper`. diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 62c9b69c1d2dd0..ff98cbeae9acc0 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -2158,7 +2158,7 @@ impl<'db> SpecializationBuilder<'db> { // check here. 
self.add_type_mapping( bound_typevar, - IntersectionType::from_elements(self.db, [bound, ty]), + IntersectionType::from_two_elements(self.db, bound, ty), polarity, f, ); diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index c86636f546eaed..46e1ab92ac58e9 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -10840,7 +10840,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // different overloads provide different type context; unioning may be more // correct in those cases. *argument_type = argument_type - .map(|current| IntersectionType::from_elements(db, [inferred_ty, current])) + .map(|current| { + IntersectionType::from_two_elements(db, inferred_ty, current) + }) .or(Some(inferred_ty)); } @@ -11054,7 +11056,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .annotation .is_none_or(|tcx| ty.is_assignable_to(db, tcx)) { - *current = IntersectionType::from_elements(db, [*current, ty]); + *current = IntersectionType::from_two_elements(db, *current, ty); } }) .or_insert(ty); diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 31f8b7fdffdfd1..39ff43799b78dd 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -28,9 +28,10 @@ use super::UnionType; use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::{FxHashMap, FxHashSet, FxHasher}; use smallvec::{SmallVec, smallvec, smallvec_inline}; use std::collections::hash_map::Entry; +use std::hash::{Hash, Hasher}; /// A set of places that could possibly be narrowed by a predicate. 
/// @@ -62,70 +63,206 @@ pub(crate) fn infer_narrowing_constraint<'db>( ) -> Option> { let constraints = match predicate.node { PredicateNode::Expression(expression) => { - if predicate.is_positive { - all_narrowing_constraints_for_expression(db, expression) - } else { - all_negative_narrowing_constraints_for_expression(db, expression) - } - } - PredicateNode::Pattern(pattern) => { - if predicate.is_positive { - all_narrowing_constraints_for_pattern(db, pattern) - } else { - all_negative_narrowing_constraints_for_pattern(db, pattern) - } + all_narrowing_constraints_for_expression(db, expression) } + PredicateNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern), PredicateNode::ReturnsNever(_) => return None, PredicateNode::StarImportPlaceholder(_) => return None, }; - constraints.and_then(|constraints| constraints.get(&place).cloned()) + constraints.and_then(|constraints| constraints.get(db, place, predicate.is_positive)) } -#[salsa::tracked(returns(as_ref), heap_size=ruff_memory_usage::heap_size)] -fn all_narrowing_constraints_for_pattern<'db>( - db: &'db dyn Db, - pattern: PatternPredicate<'db>, -) -> Option> { - let module = parsed_module(db, pattern.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish() +#[derive(Default, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct PerPlaceDualNarrowingConstraintBuilder<'db> { + positive: Option>, + negative: Option>, } -#[salsa::tracked( - returns(as_ref), - cycle_initial=|_, _, _| None, - heap_size=ruff_memory_usage::heap_size, -)] -fn all_narrowing_constraints_for_expression<'db>( - db: &'db dyn Db, - expression: Expression<'db>, -) -> Option> { - let module = parsed_module(db, expression.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true) - .finish() +type DualNarrowingConstraintsBuilderMap<'db> = + FxHashMap>; + +#[derive(Default, PartialEq, Debug, Eq, 
Clone, salsa::Update, get_size2::GetSize)] +struct DualNarrowingConstraintsBuilder<'db> { + by_place: DualNarrowingConstraintsBuilderMap<'db>, + has_positive: bool, + has_negative: bool, +} + +#[derive(Default, PartialEq, Debug, Eq, Clone, Hash, salsa::Update, get_size2::GetSize)] +struct PerPlaceDualNarrowingConstraint<'db> { + positive: Option>, + negative: Option>, +} + +type DualNarrowingConstraintsMap<'db> = + FxHashMap>; + +#[derive(Default, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct DualNarrowingConstraintsPayload<'db> { + by_place: DualNarrowingConstraintsMap<'db>, + has_positive: bool, + has_negative: bool, +} + +impl Hash for DualNarrowingConstraintsPayload<'_> { + fn hash(&self, state: &mut H) { + self.has_positive.hash(state); + self.has_negative.hash(state); + self.by_place.len().hash(state); + + // HashMap iteration order is unstable, so compute an order-independent aggregate hash. + let mut entries_hash = 0_u64; + for (place, constraints) in &self.by_place { + let mut hasher = FxHasher::default(); + place.hash(&mut hasher); + constraints.hash(&mut hasher); + entries_hash ^= hasher.finish(); + } + entries_hash.hash(state); + } +} + +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +struct DualNarrowingConstraints<'db> { + #[returns(ref)] + data: DualNarrowingConstraintsPayload<'db>, +} + +// The Salsa heap is tracked separately. 
+impl get_size2::GetSize for DualNarrowingConstraints<'_> {} + +impl<'db> DualNarrowingConstraintsBuilder<'db> { + fn from_sides( + positive: Option>, + negative: Option>, + ) -> Self { + let mut by_place = DualNarrowingConstraintsBuilderMap::default(); + let has_positive = positive.is_some(); + let has_negative = negative.is_some(); + + if let Some(positive) = positive { + for (place, constraint) in positive { + by_place.entry(place).or_default().positive = Some(constraint); + } + } + + if let Some(negative) = negative { + for (place, constraint) in negative { + by_place.entry(place).or_default().negative = Some(constraint); + } + } + + Self { + by_place, + has_positive, + has_negative, + } + } + + fn into_sides( + self, + ) -> ( + Option>, + Option>, + ) { + let mut positive = self.has_positive.then(FxHashMap::default); + let mut negative = self.has_negative.then(FxHashMap::default); + + for (place, constraints) in self.by_place { + if let (Some(positive), Some(constraint)) = (&mut positive, constraints.positive) { + positive.insert(place, constraint); + } + if let (Some(negative), Some(constraint)) = (&mut negative, constraints.negative) { + negative.insert(place, constraint); + } + } + + (positive, negative) + } + + fn shrink_to_fit(&mut self) { + self.by_place.shrink_to_fit(); + } + + fn swap_polarity(mut self) -> Self { + std::mem::swap(&mut self.has_positive, &mut self.has_negative); + for constraints in self.by_place.values_mut() { + std::mem::swap(&mut constraints.positive, &mut constraints.negative); + } + self + } + + fn finish(self, db: &'db dyn Db) -> DualNarrowingConstraints<'db> { + let mut by_place = DualNarrowingConstraintsMap::default(); + for (place, constraints) in self.by_place { + by_place.insert( + place, + PerPlaceDualNarrowingConstraint { + positive: constraints.positive.map(|constraint| constraint.finish(db)), + negative: constraints.negative.map(|constraint| constraint.finish(db)), + }, + ); + } + + DualNarrowingConstraints::new( + db, + 
DualNarrowingConstraintsPayload { + by_place, + has_positive: self.has_positive, + has_negative: self.has_negative, + }, + ) + } } +impl<'db> DualNarrowingConstraints<'db> { + fn get( + self, + db: &'db dyn Db, + place: ScopedPlaceId, + is_positive: bool, + ) -> Option> { + let data = self.data(db); + if is_positive && !data.has_positive || !is_positive && !data.has_negative { + return None; + } + + data.by_place.get(&place).and_then(|constraints| { + if is_positive { + constraints.positive + } else { + constraints.negative + } + }) + } +} + +#[allow(clippy::unnecessary_wraps)] #[salsa::tracked( returns(as_ref), cycle_initial=|_, _, _| None, heap_size=ruff_memory_usage::heap_size, )] -fn all_negative_narrowing_constraints_for_expression<'db>( +fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, expression.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false) - .finish() + Some( + NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression)) + .finish(db), + ) } +#[allow(clippy::unnecessary_wraps)] #[salsa::tracked(returns(as_ref), heap_size=ruff_memory_usage::heap_size)] -fn all_negative_narrowing_constraints_for_pattern<'db>( +fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternPredicate<'db>, -) -> Option> { +) -> Option> { let module = parsed_module(db, pattern.file(db)).load(db); - NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish() + Some(NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern)).finish(db)) } /// Functions that can be used to narrow the type of a first argument using a "classinfo" second argument. 
@@ -264,6 +401,48 @@ impl ClassInfoConstraintFunction { } } +#[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] +struct Conjunctions<'db> { + conjuncts: SmallVec<[Type<'db>; 2]>, +} + +impl<'db> Conjunctions<'db> { + fn singleton(ty: Type<'db>) -> Self { + let mut conjuncts = SmallVec::new(); + conjuncts.push(ty); + Self { conjuncts } + } + + fn and_with(&self, other: &Self) -> Self { + if self.conjuncts.iter().any(Type::is_never) || other.conjuncts.iter().any(Type::is_never) { + return Self::singleton(Type::Never); + } + + let mut conjuncts = self.conjuncts.clone(); + for conjunct in other.conjuncts.iter().copied() { + if !conjuncts.contains(&conjunct) { + conjuncts.push(conjunct); + } + } + Self { conjuncts } + } + + fn evaluate_constraint_type(&self, db: &'db dyn Db) -> Type<'db> { + let mut iter = self.conjuncts.iter().copied(); + let Some(first) = iter.next() else { + return Type::Never; + }; + // Fold conjuncts pairwise using `IntersectionType::from_two_elements`. + // When TDD paths share a common prefix of constraints + // (e.g., match cases accumulating ~P1 & ~P2 & ... & ~Pn), + // intermediate results like `base_ty & ~P1` and `(base_ty & ~P1) & ~P2` + // are cached and reused across paths, avoiding redundant recomputation. + iter.fold(first, |result, conjunct| { + IntersectionType::from_two_elements(db, result, conjunct) + }) + } +} + /// Represents narrowing constraints in Disjunctive Normal Form (DNF). /// /// This is a disjunction (OR) of conjunctions (AND) of constraints. 
@@ -273,38 +452,51 @@ impl ClassInfoConstraintFunction { /// For example: /// - `f(x) and g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => and -/// ===> `NarrowingConstraint { intersection_disjunct: Some(A), replacement_disjuncts: [] }` -/// ===> `NarrowingConstraint { intersection_disjunct: None, replacement_disjuncts: [B] }` -/// => `NarrowingConstraint { intersection_disjunct: None, replacement_disjuncts: [B] }` +/// ===> `NarrowingConstraint { intersection_disjuncts: [A], replacement_disjuncts: [] }` +/// ===> `NarrowingConstraint { intersection_disjuncts: [], replacement_disjuncts: [B] }` +/// => `NarrowingConstraint { intersection_disjuncts: [], replacement_disjuncts: [B] }` /// => evaluates to `B` (`TypeGuard` clobbers any previous type information) /// /// - `f(x) or g(x)` where f returns `TypeIs[A]` and g returns `TypeGuard[B]` /// => or -/// ===> `NarrowingConstraint { intersection_disjunct: Some(A), replacement_disjuncts: [] }` -/// ===> `NarrowingConstraint { intersection_disjunct: None, replacement_disjuncts: [B] }` -/// => `NarrowingConstraint { intersection_disjunct: Some(A), replacement_disjuncts: [B] }` +/// ===> `NarrowingConstraint { intersection_disjuncts: [A], replacement_disjuncts: [] }` +/// ===> `NarrowingConstraint { intersection_disjuncts: [], replacement_disjuncts: [B] }` +/// => `NarrowingConstraint { intersection_disjuncts: [A], replacement_disjuncts: [B] }` /// => evaluates to `(P & A) | B`, where `P` is our previously-known type #[derive(Hash, PartialEq, Debug, Eq, Clone, salsa::Update, get_size2::GetSize)] -pub(crate) struct NarrowingConstraint<'db> { +pub(crate) struct NarrowingConstraintBuilder<'db> { /// Intersection constraint (from `isinstance()` narrowing comparisons, `TypeIs`, and - /// similar). We can use a single type here because we can eagerly union disjunctions - /// and eagerly intersect conjunctions. - intersection_disjunct: Option>, + /// similar). 
We keep these as a disjunction of conjunctions to avoid constructing + /// union/intersection types while merging constraints. + intersection_disjuncts: SmallVec<[Conjunctions<'db>; 1]>, /// "Replacement" constraints: instead of intersecting the previous type with a new type, /// the previous type is simply replaced wholesale with the new type. A common use case for /// these constraints is `typing.TypeGuard`. We can't eagerly union disjunctions because /// `TypeGuard` clobbers the previously-known type; within each replacement disjunct, however, /// we may eagerly intersect conjunctions with a later intersection narrowing. - replacement_disjuncts: SmallVec<[Type<'db>; 1]>, + replacement_disjuncts: SmallVec<[Conjunctions<'db>; 1]>, } -impl<'db> NarrowingConstraint<'db> { +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +pub(crate) struct NarrowingConstraint<'db> { + #[returns(ref)] + inner: NarrowingConstraintBuilder<'db>, +} + +// The Salsa heap is tracked separately. 
+impl get_size2::GetSize for NarrowingConstraint<'_> {} + +impl<'db> NarrowingConstraintBuilder<'db> { + fn finish(self, db: &'db dyn Db) -> NarrowingConstraint<'db> { + NarrowingConstraint::new(db, self) + } + /// Create an "intersection" constraint: the previous type will be /// intersected with this constraint - pub(crate) fn intersection(constraint: Type<'db>) -> Self { + fn intersection(constraint: Type<'db>) -> Self { Self { - intersection_disjunct: Some(constraint), + intersection_disjuncts: smallvec_inline![Conjunctions::singleton(constraint)], replacement_disjuncts: smallvec![], } } @@ -313,75 +505,129 @@ impl<'db> NarrowingConstraint<'db> { /// replaced wholesale with this constraint fn replacement(constraint: Type<'db>) -> Self { Self { - intersection_disjunct: None, - replacement_disjuncts: smallvec_inline![constraint], + intersection_disjuncts: smallvec![], + replacement_disjuncts: smallvec_inline![Conjunctions::singleton(constraint)], } } /// Merge two constraints, taking their intersection but respecting "replacement" semantics (with /// `other` winning) - pub(crate) fn merge_constraint_and(&self, other: Self, db: &'db dyn Db) -> Self { + fn merge_constraint_and(&self, other: &Self) -> Self { // Distribute AND over OR: (A1 | A2 | ...) AND (B1 | B2 | ...) // becomes (A1 & B1) | (A1 & B2) | ... | (A2 & B1) | ... // // In our representation, the RHS `replacement_disjuncts` will all clobber the LHS disjuncts // when they are `and`ed, so they'll just stay as is. // - // The thing we actually need to deal with is the RHS `intersection_disjunct`. It gets - // intersected with the LHS `intersection_disjunct` to form the new `intersection_disjunct`, + // The thing we actually need to deal with is the RHS `intersection_disjuncts`. Each RHS + // disjunct gets intersected with each LHS disjunct, producing the cartesian product. + // This is still deferred as conjunction lists. 
+ // + // We also intersect each LHS `replacement_disjunct` with every RHS intersection disjunct to form new additional // and intersected with each LHS `replacement_disjunct` to form new additional // `replacement_disjuncts`. - let Some(other_intersection_disjunct) = other.intersection_disjunct else { - return other; - }; + if other.intersection_disjuncts.is_empty() { + return other.clone(); + } - let new_intersection_disjunct = self.intersection_disjunct.map(|intersection_disjunct| { - IntersectionType::from_elements( - db, - [intersection_disjunct, other_intersection_disjunct], - ) - }); + let mut new_intersection_disjuncts: SmallVec<[Conjunctions<'db>; 1]> = SmallVec::new(); + for intersection_disjunct in &self.intersection_disjuncts { + for other_intersection_disjunct in &other.intersection_disjuncts { + let merged = intersection_disjunct.and_with(other_intersection_disjunct); + if !new_intersection_disjuncts.contains(&merged) { + new_intersection_disjuncts.push(merged); + } + } + } - let additional_replacement_disjuncts = - self.replacement_disjuncts - .iter() - .map(|replacement_disjunct| { - IntersectionType::from_elements( - db, - [*replacement_disjunct, other_intersection_disjunct], - ) - }); + let mut additional_replacement_disjuncts: SmallVec<[Conjunctions<'db>; 1]> = + SmallVec::new(); + for replacement_disjunct in &self.replacement_disjuncts { + for other_intersection_disjunct in &other.intersection_disjuncts { + let merged = replacement_disjunct.and_with(other_intersection_disjunct); + if !additional_replacement_disjuncts.contains(&merged) { + additional_replacement_disjuncts.push(merged); + } + } + } - let mut new_replacement_disjuncts = other.replacement_disjuncts; + let mut new_replacement_disjuncts = other.replacement_disjuncts.clone(); new_replacement_disjuncts.extend(additional_replacement_disjuncts); - NarrowingConstraint { - intersection_disjunct: new_intersection_disjunct, + NarrowingConstraintBuilder { + intersection_disjuncts: 
new_intersection_disjuncts, replacement_disjuncts: new_replacement_disjuncts, } } + /// Merge two constraints with OR semantics (union/disjunction). + fn merge_constraint_or(&self, other: &Self) -> Self { + let mut intersection_disjuncts = self.intersection_disjuncts.clone(); + intersection_disjuncts.extend(other.intersection_disjuncts.iter().cloned()); + let mut replacement_disjuncts = self.replacement_disjuncts.clone(); + replacement_disjuncts.extend(other.replacement_disjuncts.iter().cloned()); + Self { + intersection_disjuncts, + replacement_disjuncts, + } + } + /// Evaluate the type this effectively constrains to /// /// Forgets whether each constraint originated from a `replacement` disjunct or not - pub(crate) fn evaluate_constraint_type(self, db: &'db dyn Db) -> Type<'db> { - UnionType::from_elements( - db, - self.replacement_disjuncts - .into_iter() - .chain(self.intersection_disjunct), - ) + pub(crate) fn evaluate_constraint_type( + &self, + db: &'db dyn Db, + check_redundancy: bool, + ) -> Type<'db> { + let mut union = UnionBuilder::new(db).check_redundancy(check_redundancy); + for conjunctions in self + .replacement_disjuncts + .iter() + .chain(self.intersection_disjuncts.iter()) + { + union = union.add(conjunctions.evaluate_constraint_type(db)); + } + union.build() } } -impl<'db> From> for NarrowingConstraint<'db> { - fn from(constraint: Type<'db>) -> Self { - Self::intersection(constraint) +impl<'db> NarrowingConstraint<'db> { + pub(crate) fn intersection(db: &'db dyn Db, constraint: Type<'db>) -> Self { + NarrowingConstraintBuilder::intersection(constraint).finish(db) + } + + pub(crate) fn merge_constraint_and(self, db: &'db dyn Db, other: Self) -> Self { + if self == other { + return self; + } + self.inner(db) + .merge_constraint_and(other.inner(db)) + .finish(db) + } + + #[allow(unused)] + pub(crate) fn merge_constraint_or(self, db: &'db dyn Db, other: Self) -> Self { + if self == other { + return self; + } + self.inner(db) + 
.merge_constraint_or(other.inner(db)) + .finish(db) + } + + pub(crate) fn evaluate_constraint_type( + self, + db: &'db dyn Db, + check_redundancy: bool, + ) -> Type<'db> { + self.inner(db) + .evaluate_constraint_type(db, check_redundancy) } } -type NarrowingConstraints<'db> = FxHashMap>; +type NarrowingConstraintBuilders<'db> = FxHashMap>; /// Merge constraints with AND semantics (intersection/conjunction). /// @@ -393,16 +639,15 @@ type NarrowingConstraints<'db> = FxHashMap( - into: &mut NarrowingConstraints<'db>, - from: NarrowingConstraints<'db>, - db: &'db dyn Db, + into: &mut NarrowingConstraintBuilders<'db>, + from: NarrowingConstraintBuilders<'db>, ) { for (key, from_constraint) in from { match into.entry(key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get(); - entry.insert(into_constraint.merge_constraint_and(from_constraint, db)); + entry.insert(into_constraint.merge_constraint_and(&from_constraint)); } Entry::Vacant(entry) => { entry.insert(from_constraint); @@ -419,9 +664,8 @@ fn merge_constraints_and<'db>( /// However, if a place appears in only one branch of the OR, we need to widen it /// to `object` in the overall result (because the other branch doesn't constrain it). 
fn merge_constraints_or<'db>( - into: &mut NarrowingConstraints<'db>, - from: NarrowingConstraints<'db>, - db: &'db dyn Db, + into: &mut NarrowingConstraintBuilders<'db>, + from: NarrowingConstraintBuilders<'db>, ) { // For places that appear in `into` but not in `from`, widen to object into.retain(|key, _| from.contains_key(key)); @@ -430,16 +674,10 @@ fn merge_constraints_or<'db>( match into.entry(key) { Entry::Occupied(mut entry) => { let into_constraint = entry.get_mut(); - // Union the intersection constraints - into_constraint.intersection_disjunct = match ( - into_constraint.intersection_disjunct, - from_constraint.intersection_disjunct, - ) { - (Some(a), Some(b)) => Some(UnionType::from_elements(db, [a, b])), - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (None, None) => None, - }; + // Union the intersection constraints by concatenating disjunct lists. + into_constraint + .intersection_disjuncts + .extend(from_constraint.intersection_disjuncts); // Concatenate replacement disjuncts into_constraint @@ -497,74 +735,95 @@ struct NarrowingConstraintsBuilder<'db, 'ast> { db: &'db dyn Db, module: &'ast ParsedModuleRef, predicate: PredicateNode<'db>, - is_positive: bool, } impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { - fn new( - db: &'db dyn Db, - module: &'ast ParsedModuleRef, - predicate: PredicateNode<'db>, - is_positive: bool, - ) -> Self { + fn new(db: &'db dyn Db, module: &'ast ParsedModuleRef, predicate: PredicateNode<'db>) -> Self { Self { db, module, predicate, - is_positive, } } - fn finish(mut self) -> Option> { - let mut constraints: Option> = match self.predicate { - PredicateNode::Expression(expression) => { - self.evaluate_expression_predicate(expression, self.is_positive) - } - PredicateNode::Pattern(pattern) => { - self.evaluate_pattern_predicate(pattern, self.is_positive) + fn finish(mut self, db: &'db dyn Db) -> DualNarrowingConstraints<'db> { + let mut constraints = match self.predicate { + 
PredicateNode::Expression(expression) => self.evaluate_expression_predicate(expression), + PredicateNode::Pattern(pattern) => self.evaluate_pattern_predicate(pattern), + PredicateNode::ReturnsNever(_) | PredicateNode::StarImportPlaceholder(_) => { + return DualNarrowingConstraints::new( + db, + DualNarrowingConstraintsPayload::default(), + ); } - PredicateNode::ReturnsNever(_) => return None, - PredicateNode::StarImportPlaceholder(_) => return None, }; - if let Some(ref mut constraints) = constraints { - constraints.shrink_to_fit(); + constraints.shrink_to_fit(); + + constraints.finish(db) + } + + fn merge_constraints_and_sequence( + sub_constraints: Vec>>, + ) -> Option> { + let mut aggregation: Option> = None; + for sub_constraint in sub_constraints.into_iter().flatten() { + if let Some(ref mut some_aggregation) = aggregation { + merge_constraints_and(some_aggregation, sub_constraint); + } else { + aggregation = Some(sub_constraint); + } } + aggregation + } - constraints + fn merge_constraints_or_sequence( + sub_constraints: Vec>>, + ) -> Option> { + let (mut first, rest) = { + let mut it = sub_constraints.into_iter(); + (it.next()?, it) + }; + + if let Some(ref mut first) = first { + for rest_constraint in rest { + if let Some(rest_constraint) = rest_constraint { + merge_constraints_or(first, rest_constraint); + } else { + return None; + } + } + } + first } fn evaluate_expression_predicate( &mut self, expression: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraintsBuilder<'db> { let expression_node = expression.node_ref(self.db, self.module); - self.evaluate_expression_node_predicate(expression_node, expression, is_positive) + self.evaluate_expression_node_predicate(expression_node, expression) } fn evaluate_expression_node_predicate( &mut self, expression_node: &ruff_python_ast::Expr, expression: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraintsBuilder<'db> { match expression_node { 
ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { - self.evaluate_simple_expr(expression_node, is_positive) + self.evaluate_simple_expr(expression_node) } ast::Expr::Compare(expr_compare) => { - self.evaluate_expr_compare(expr_compare, expression, is_positive) - } - ast::Expr::Call(expr_call) => { - self.evaluate_expr_call(expr_call, expression, is_positive) - } - ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => { - self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive) - } - ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive), - ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive), - _ => None, + self.evaluate_expr_compare(expr_compare, expression) + } + ast::Expr::Call(expr_call) => self.evaluate_expr_call(expr_call, expression), + ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => self + .evaluate_expression_node_predicate(&unary_op.operand, expression) + .swap_polarity(), + ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression), + ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named), + _ => DualNarrowingConstraintsBuilder::default(), } } @@ -572,38 +831,32 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, pattern_predicate_kind: &PatternPredicateKind<'db>, subject: Expression<'db>, - is_positive: bool, - ) -> Option> { + ) -> DualNarrowingConstraintsBuilder<'db> { match pattern_predicate_kind { PatternPredicateKind::Singleton(singleton) => { - self.evaluate_match_pattern_singleton(subject, *singleton, is_positive) + self.evaluate_match_pattern_singleton(subject, *singleton) } PatternPredicateKind::Class(cls, kind) => { - self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive) - } - PatternPredicateKind::Value(expr) => { - self.evaluate_match_pattern_value(subject, *expr, is_positive) + self.evaluate_match_pattern_class(subject, *cls, *kind) } 
+ PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr), PatternPredicateKind::Or(predicates) => { - self.evaluate_match_pattern_or(subject, predicates, is_positive) + self.evaluate_match_pattern_or(subject, predicates) } PatternPredicateKind::As(pattern, _) => pattern .as_deref() - .and_then(|p| self.evaluate_pattern_predicate_kind(p, subject, is_positive)), - PatternPredicateKind::Unsupported => None, + .map_or_else(DualNarrowingConstraintsBuilder::default, |p| { + self.evaluate_pattern_predicate_kind(p, subject) + }), + PatternPredicateKind::Unsupported => DualNarrowingConstraintsBuilder::default(), } } fn evaluate_pattern_predicate( &mut self, pattern: PatternPredicate<'db>, - is_positive: bool, - ) -> Option> { - self.evaluate_pattern_predicate_kind( - pattern.kind(self.db), - pattern.subject(self.db), - is_positive, - ) + ) -> DualNarrowingConstraintsBuilder<'db> { + self.evaluate_pattern_predicate_kind(pattern.kind(self.db), pattern.subject(self.db)) } fn places(&self) -> &'db PlaceTable { @@ -724,32 +977,29 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { } } - fn evaluate_simple_expr( - &mut self, - expr: &ast::Expr, - is_positive: bool, - ) -> Option> { - let target = PlaceExpr::try_from_expr(expr)?; - let place = self.expect_place(&target); - - let ty = if is_positive { - Type::AlwaysFalsy.negate(self.db) - } else { - Type::AlwaysTruthy.negate(self.db) + fn evaluate_simple_expr(&mut self, expr: &ast::Expr) -> DualNarrowingConstraintsBuilder<'db> { + let Some(target) = PlaceExpr::try_from_expr(expr) else { + return DualNarrowingConstraintsBuilder::default(); }; + let place = self.expect_place(&target); - Some(NarrowingConstraints::from_iter([( + let positive = NarrowingConstraintBuilders::from_iter([( place, - NarrowingConstraint::intersection(ty), - )])) + NarrowingConstraintBuilder::intersection(Type::AlwaysFalsy.negate(self.db)), + )]); + let negative = NarrowingConstraintBuilders::from_iter([( + place, + 
NarrowingConstraintBuilder::intersection(Type::AlwaysTruthy.negate(self.db)), + )]); + + DualNarrowingConstraintsBuilder::from_sides(Some(positive), Some(negative)) } fn evaluate_expr_named( &mut self, expr_named: &ast::ExprNamed, - is_positive: bool, - ) -> Option> { - self.evaluate_simple_expr(&expr_named.target, is_positive) + ) -> DualNarrowingConstraintsBuilder<'db> { + self.evaluate_simple_expr(&expr_named.target) } fn evaluate_expr_eq(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option> { @@ -959,10 +1209,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { lhs_ty: Type<'db>, rhs_ty: Type<'db>, op: ast::CmpOp, - is_positive: bool, ) -> Option> { - let op = if is_positive { op } else { op.negate() }; - // `Divergent` shows up as an initial value in cycle recovery. If it appears on either side // of a potentially narrowing comparison, we don't want it to turn that comparison into a // no-op (e.g. because `Divergent` is not a singleton in the `IsNot` branch below), because @@ -1007,8 +1254,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>, + ) -> DualNarrowingConstraintsBuilder<'db> { + let inference = infer_expression_types(self.db, expression, TypeContext::default()); + DualNarrowingConstraintsBuilder::from_sides( + self.evaluate_expr_compare_for_polarity(expr_compare, inference, true), + self.evaluate_expr_compare_for_polarity(expr_compare, inference, false), + ) + } + + fn evaluate_expr_compare_for_polarity( + &mut self, + expr_compare: &ast::ExprCompare, + inference: &ExpressionInference<'db>, is_positive: bool, - ) -> Option> { + ) -> Option> { fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool { matches!( expr, @@ -1081,12 +1340,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let inference = infer_expression_types(self.db, expression, TypeContext::default()); - let comparator_tuples = std::iter::once(&**left) 
.chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); - let mut constraints = NarrowingConstraints::default(); + let mut constraints = NarrowingConstraintBuilders::default(); // Narrow unions of tuples based on element checks. For example: // @@ -1123,7 +1380,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }); if filtered != Type::Union(union) { let place = self.expect_place(&subscript_place_expr); - constraints.insert(place, NarrowingConstraint::replacement(filtered)); + constraints.insert(place, NarrowingConstraintBuilder::replacement(filtered)); } } @@ -1158,7 +1415,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { constraints .entry(place) .and_modify(|existing| { - *existing = existing.merge_constraint_and(constraint.clone(), self.db); + *existing = existing.merge_constraint_and(&constraint); }) .or_insert(constraint); } else if let Some((place, constraint)) = self.narrow_tuple_subscript( @@ -1171,7 +1428,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { constraints .entry(place) .and_modify(|existing| { - *existing = existing.merge_constraint_and(constraint.clone(), self.db); + *existing = existing.merge_constraint_and(&constraint); }) .or_insert(constraint); } @@ -1252,7 +1509,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if narrowed != resolved_rhs_type { let place = self.expect_place(&rhs_place_expr); - constraints.insert(place, NarrowingConstraint::replacement(narrowed)); + constraints.insert(place, NarrowingConstraintBuilder::replacement(narrowed)); } } } @@ -1307,7 +1564,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let place = self.expect_place(&target); constraints.insert( place, - NarrowingConstraint::intersection( + NarrowingConstraintBuilder::intersection( Type::instance(self.db, other_class.top_materialization(self.db)) .negate_if(self.db, !is_positive), ), @@ -1326,14 +1583,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // - 
`if x not in y` if narrowable_ast(left) && let Some(narrowable) = PlaceExpr::try_from_expr(left) - && let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, *op, is_positive) + && let Some(ty) = self.evaluate_expr_compare_op( + lhs_ty, + rhs_ty, + if is_positive { *op } else { op.negate() }, + ) { let place = self.expect_place(&narrowable); - let constraint = NarrowingConstraint::intersection(ty); + let constraint = NarrowingConstraintBuilder::intersection(ty); constraints .entry(place) .and_modify(|existing| { - *existing = existing.merge_constraint_and(constraint.clone(), self.db); + *existing = existing.merge_constraint_and(&constraint); }) .or_insert(constraint); } @@ -1348,19 +1609,23 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if !matches!(op, ast::CmpOp::In | ast::CmpOp::NotIn) && narrowable_ast(right) && let Some(narrowable) = PlaceExpr::try_from_expr(right) - && let Some(ty) = self.evaluate_expr_compare_op(rhs_ty, lhs_ty, *op, is_positive) + && let Some(ty) = self.evaluate_expr_compare_op( + rhs_ty, + lhs_ty, + if is_positive { *op } else { op.negate() }, + ) { let place = self.expect_place(&narrowable); - let constraint = NarrowingConstraint::intersection(ty); + let constraint = NarrowingConstraintBuilder::intersection(ty); constraints .entry(place) .and_modify(|existing| { - *existing = existing.merge_constraint_and(constraint.clone(), self.db); + *existing = existing.merge_constraint_and(&constraint); }) .or_insert(constraint); // Use the narrowed type for subsequent comparisons in a chain. 
- last_rhs_ty = Some(IntersectionType::from_elements(self.db, [rhs_ty, ty])); + last_rhs_ty = Some(IntersectionType::from_two_elements(self.db, rhs_ty, ty)); } else { last_rhs_ty = Some(rhs_ty); } @@ -1372,18 +1637,45 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, expr_call: &ast::ExprCall, expression: Expression<'db>, - is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { + ) -> DualNarrowingConstraintsBuilder<'db> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); - if let Some(type_guard_call_constraints) = - self.evaluate_type_guard_call(inference, expr_call, is_positive) + // If the return type of expr_call is TypeGuard (positive) / TypeIs: + if let Some(positive_constraints) = + self.evaluate_type_guard_call_for_polarity(inference, expr_call, true) { - return Some(type_guard_call_constraints); + let negative_constraints = + self.evaluate_type_guard_call_for_polarity(inference, expr_call, false); + return DualNarrowingConstraintsBuilder::from_sides( + Some(positive_constraints), + negative_constraints, + ); } let callable_ty = inference.expression_type(&*expr_call.func); + if let Type::ClassLiteral(class_type) = callable_ty + && expr_call.arguments.args.len() == 1 + && expr_call.arguments.keywords.is_empty() + && class_type.is_known(self.db, KnownClass::Bool) + { + return self + .evaluate_expression_node_predicate(&expr_call.arguments.args[0], expression); + } + + DualNarrowingConstraintsBuilder::from_sides( + self.evaluate_expr_call_for_polarity(expr_call, inference, callable_ty, true), + self.evaluate_expr_call_for_polarity(expr_call, inference, callable_ty, false), + ) + } + + fn evaluate_expr_call_for_polarity( + &mut self, + expr_call: &ast::ExprCall, + inference: &ExpressionInference<'db>, + callable_ty: Type<'db>, + is_positive: bool, + ) -> Option<NarrowingConstraintBuilders<'db>> { match callable_ty { // For the expression `len(E)`, we narrow the type based on whether len(E) is truthy // (i.e., whether E is non-empty).
We only narrow the parts of the type where we know @@ -1401,9 +1693,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if let Some(narrowed_ty) = Self::narrow_type_by_len(self.db, arg_ty, is_positive) { let target = PlaceExpr::try_from_expr(arg)?; let place = self.expect_place(&target); - Some(NarrowingConstraints::from_iter([( + Some(NarrowingConstraintBuilders::from_iter([( place, - NarrowingConstraint::intersection(narrowed_ty), + NarrowingConstraintBuilder::intersection(narrowed_ty), )])) } else { None @@ -1432,9 +1724,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let constraint = Type::protocol_with_readonly_members(self.db, [(attr, Type::object())]); - return Some(NarrowingConstraints::from_iter([( + return Some(NarrowingConstraintBuilders::from_iter([( place, - NarrowingConstraint::intersection( + NarrowingConstraintBuilder::intersection( constraint.negate_if(self.db, !is_positive), ), )])); @@ -1447,38 +1739,26 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { function .generate_constraint(self.db, class_info_ty) .map(|constraint| { - NarrowingConstraints::from_iter([( + NarrowingConstraintBuilders::from_iter([( place, - NarrowingConstraint::intersection( + NarrowingConstraintBuilder::intersection( constraint.negate_if(self.db, !is_positive), ), )]) }) } - // for the expression `bool(E)`, we further narrow the type based on `E` - Type::ClassLiteral(class_type) - if expr_call.arguments.args.len() == 1 - && expr_call.arguments.keywords.is_empty() - && class_type.is_known(self.db, KnownClass::Bool) => - { - self.evaluate_expression_node_predicate( - &expr_call.arguments.args[0], - expression, - is_positive, - ) - } _ => None, } } // Helper to evaluate TypeGuard/TypeIs narrowing for a call expression. // This is based on the call expression's return type, so it applies to any callable type. 
- fn evaluate_type_guard_call( + fn evaluate_type_guard_call_for_polarity( &mut self, inference: &ExpressionInference<'db>, expr_call: &ast::ExprCall, is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { + ) -> Option<NarrowingConstraintBuilders<'db>> { let return_ty = inference.expression_type(expr_call); let place_and_constraint = match return_ty { @@ -1486,7 +1766,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let (_, place) = type_is.place_info(self.db)?; Some(( place, - NarrowingConstraint::intersection( + NarrowingConstraintBuilder::intersection( type_is .return_type(self.db) .negate_if(self.db, !is_positive), @@ -1498,21 +1778,34 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let (_, place) = type_guard.place_info(self.db)?; Some(( place, - NarrowingConstraint::replacement(type_guard.return_type(self.db)), + NarrowingConstraintBuilder::replacement(type_guard.return_type(self.db)), )) } _ => None, }?; - Some(NarrowingConstraints::from_iter([place_and_constraint])) + Some(NarrowingConstraintBuilders::from_iter([ + place_and_constraint, + ])) } fn evaluate_match_pattern_singleton( &mut self, subject: Expression<'db>, singleton: ast::Singleton, + ) -> DualNarrowingConstraintsBuilder<'db> { + DualNarrowingConstraintsBuilder::from_sides( + self.evaluate_match_pattern_singleton_for_polarity(subject, singleton, true), + self.evaluate_match_pattern_singleton_for_polarity(subject, singleton, false), + ) + } + + fn evaluate_match_pattern_singleton_for_polarity( + &mut self, + subject: Expression<'db>, + singleton: ast::Singleton, is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { + ) -> Option<NarrowingConstraintBuilders<'db>> { let subject = PlaceExpr::try_from_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); @@ -1522,9 +1815,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { ast::Singleton::False => Type::bool_literal(false), }; let ty = ty.negate_if(self.db, !is_positive); - Some(NarrowingConstraints::from_iter([( + Some(NarrowingConstraintBuilders::from_iter([( place, -
NarrowingConstraint::intersection(ty), + NarrowingConstraintBuilder::intersection(ty), )])) } @@ -1533,8 +1826,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subject: Expression<'db>, cls: Expression<'db>, kind: ClassPatternKind, + ) -> DualNarrowingConstraintsBuilder<'db> { + DualNarrowingConstraintsBuilder::from_sides( + self.evaluate_match_pattern_class_for_polarity(subject, cls, kind, true), + self.evaluate_match_pattern_class_for_polarity(subject, cls, kind, false), + ) + } + + fn evaluate_match_pattern_class_for_polarity( + &mut self, + subject: Expression<'db>, + cls: Expression<'db>, + kind: ClassPatternKind, is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { + ) -> Option<NarrowingConstraintBuilders<'db>> { if !kind.is_irrefutable() && !is_positive { // A class pattern like `case Point(x=0, y=0)` is not irrefutable. In the positive case, // we can still narrow the type of the match subject to `Point`. But in the negative case, @@ -1557,9 +1862,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { _ => return None, }; - Some(NarrowingConstraints::from_iter([( + Some(NarrowingConstraintBuilders::from_iter([( place, - NarrowingConstraint::intersection(narrowed_type), + NarrowingConstraintBuilder::intersection(narrowed_type), )])) } @@ -1567,8 +1872,19 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, value: Expression<'db>, + ) -> DualNarrowingConstraintsBuilder<'db> { + DualNarrowingConstraintsBuilder::from_sides( + self.evaluate_match_pattern_value_for_polarity(subject, value, true), + self.evaluate_match_pattern_value_for_polarity(subject, value, false), + ) + } + + fn evaluate_match_pattern_value_for_polarity( + &mut self, + subject: Expression<'db>, + value: Expression<'db>, is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { + ) -> Option<NarrowingConstraintBuilders<'db>> { let subject_node = subject.node_ref(self.db, self.module); let place = { let subject = PlaceExpr::try_from_expr(subject_node)?; @@ -1581,9 +1897,20 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); let mut constraints = self - .evaluate_expr_compare_op(subject_ty, value_ty, ast::CmpOp::Eq, is_positive) + .evaluate_expr_compare_op( + subject_ty, + value_ty, + if is_positive { + ast::CmpOp::Eq + } else { + ast::CmpOp::NotEq + }, + ) .map(|ty| { - NarrowingConstraints::from_iter([(place, NarrowingConstraint::intersection(ty))]) + NarrowingConstraintBuilders::from_iter([( + place, + NarrowingConstraintBuilder::intersection(ty), + )]) }) .unwrap_or_default(); @@ -1629,39 +1956,45 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { &mut self, subject: Expression<'db>, predicates: &Vec<PatternPredicateKind<'db>>, - is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { - let db = self.db; + ) -> DualNarrowingConstraintsBuilder<'db> { + let mut positive: Option<NarrowingConstraintBuilders<'db>> = None; + let mut negative: Option<NarrowingConstraintBuilders<'db>> = None; + + for predicate in predicates { + let (sub_positive, sub_negative) = self + .evaluate_pattern_predicate_kind(predicate, subject) + .into_sides(); + + if let Some(sub_positive) = sub_positive { + if let Some(ref mut aggregated) = positive { + merge_constraints_or(aggregated, sub_positive); + } else { + positive = Some(sub_positive); + } + } - // DeMorgan's law---if the overall `or` is negated, we need to `and` the negated sub-constraints.
- let merge_constraints = if is_positive { - merge_constraints_or - } else { - merge_constraints_and - }; + if let Some(sub_negative) = sub_negative { + if let Some(ref mut aggregated) = negative { + merge_constraints_and(aggregated, sub_negative); + } else { + negative = Some(sub_negative); + } + } + } - predicates - .iter() - .filter_map(|predicate| { - self.evaluate_pattern_predicate_kind(predicate, subject, is_positive) - }) - .reduce(|mut constraints, constraints_| { - merge_constraints(&mut constraints, constraints_, db); - constraints - }) + DualNarrowingConstraintsBuilder::from_sides(positive, negative) } fn evaluate_bool_op( &mut self, expr_bool_op: &ExprBoolOp, expression: Expression<'db>, - is_positive: bool, - ) -> Option<NarrowingConstraints<'db>> { + ) -> DualNarrowingConstraintsBuilder<'db> { let inference = infer_expression_types(self.db, expression, TypeContext::default()); let sub_constraints = expr_bool_op .values .iter() - // filter our arms with statically known truthiness + // Filter out arms with statically known truthiness.
.filter(|expr| { inference.expression_type(*expr).bool(self.db) != match expr_bool_op.op { @@ -1669,40 +2002,27 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { BoolOp::Or => Truthiness::AlwaysFalse, } }) - .map(|sub_expr| { - self.evaluate_expression_node_predicate(sub_expr, expression, is_positive) - }) + .map(|sub_expr| self.evaluate_expression_node_predicate(sub_expr, expression)) .collect::<Vec<_>>(); - match (expr_bool_op.op, is_positive) { - (BoolOp::And, true) | (BoolOp::Or, false) => { - let mut aggregation: Option<NarrowingConstraints> = None; - for sub_constraint in sub_constraints.into_iter().flatten() { - if let Some(ref mut some_aggregation) = aggregation { - merge_constraints_and(some_aggregation, sub_constraint, self.db); - } else { - aggregation = Some(sub_constraint); - } - } - aggregation - } - (BoolOp::Or, true) | (BoolOp::And, false) => { - let (mut first, rest) = { - let mut it = sub_constraints.into_iter(); - (it.next()?, it) - }; - if let Some(ref mut first) = first { - for rest_constraint in rest { - if let Some(rest_constraint) = rest_constraint { - merge_constraints_or(first, rest_constraint, self.db); - } else { - return None; - } - } - } - first - } - } + let (positive_sub_constraints, negative_sub_constraints): (Vec<_>, Vec<_>) = + sub_constraints + .into_iter() + .map(DualNarrowingConstraintsBuilder::into_sides) + .unzip(); + + let (positive, negative) = match expr_bool_op.op { + BoolOp::And => ( + Self::merge_constraints_and_sequence(positive_sub_constraints), + Self::merge_constraints_or_sequence(negative_sub_constraints), + ), + BoolOp::Or => ( + Self::merge_constraints_or_sequence(positive_sub_constraints), + Self::merge_constraints_and_sequence(negative_sub_constraints), + ), + }; + + DualNarrowingConstraintsBuilder::from_sides(positive, negative) } /// Narrow tagged unions of `TypedDict`s with `Literal` keys.
@@ -1719,7 +2039,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subscript_key_type: Type<'db>, rhs_type: Type<'db>, constrain_with_equality: bool, - ) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> { + ) -> Option<(ScopedPlaceId, NarrowingConstraintBuilder<'db>)> { // Check preconditions: we need a TypedDict, a string key, and a supported tag literal. if !is_or_contains_typeddict(self.db, subscript_value_type) { return None; @@ -1767,7 +2087,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // As mentioned above, the synthesized `TypedDict` is always negated. let intersection = Type::TypedDict(synthesized_typeddict).negate(self.db); let place = self.expect_place(&subscript_place_expr); - Some((place, NarrowingConstraint::intersection(intersection))) + Some(( + place, + NarrowingConstraintBuilder::intersection(intersection), + )) } /// Narrow tagged unions of tuples with `Literal` elements. @@ -1791,7 +2114,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { subscript_index_type: Type<'db>, rhs_type: Type<'db>, constrain_with_equality: bool, - ) -> Option<(ScopedPlaceId, NarrowingConstraint<'db>)> { + ) -> Option<(ScopedPlaceId, NarrowingConstraintBuilder<'db>)> { // We need a union type for narrowing to be useful. let Type::Union(union) = subscript_value_type.resolve_type_alias(self.db) else { return None; @@ -1841,7 +2164,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // Only create a constraint if we actually narrowed something. 
if filtered != Type::Union(union) { let place = self.expect_place(&subscript_place_expr); - Some((place, NarrowingConstraint::replacement(filtered))) + Some((place, NarrowingConstraintBuilder::replacement(filtered))) } else { None } diff --git a/crates/ty_python_semantic/src/types/relation.rs b/crates/ty_python_semantic/src/types/relation.rs index fe4f96ed88df83..6492ca2d57083a 100644 --- a/crates/ty_python_semantic/src/types/relation.rs +++ b/crates/ty_python_semantic/src/types/relation.rs @@ -313,6 +313,39 @@ impl<'db> Type<'db> { return true; } + // Fast path for intersection types: use set-based subset check instead of + // the full `has_relation_to` machinery. This is critical for narrowing where + // many intersection types with overlapping positive elements are produced. + if let (Type::Intersection(self_inter), Type::Intersection(other_inter)) = (self, other) { + let self_pos = self_inter.positive(db); + let other_pos = other_inter.positive(db); + let self_neg = self_inter.negative(db); + let other_neg = other_inter.negative(db); + let other_pos_subset = other_pos.iter().all(|p| self_pos.contains(p)); + let other_neg_subset = other_neg.iter().all(|n| self_neg.contains(n)); + + // Intersection(pos_self, neg_self) is redundant with Intersection(pos_other, neg_other) if: + // pos_other ⊆ pos_self AND neg_other ⊆ neg_self + // Conversely, if pos_other ⊄ pos_self (some positive of other is missing from self), + // then self is NOT redundant with other. + if other_pos_subset && other_neg_subset { + return true; + } + + // If all positive elements are `NominalInstance` types and some positive + // of `other` is not contained in `self`'s positives, we can assume + // non-redundancy without the full `has_relation_to` check. + // This is not strictly correct for classes with inheritance relationship, + // but such a false negative is safe: it only causes the union to retain + // an extra element without affecting correctness. 
+ if !other_pos_subset + && self_pos.iter().all(|t| t.is_nominal_instance()) + && other_pos.iter().all(|t| t.is_nominal_instance()) + { + return false; + } + } + is_redundant_with_impl(db, self, other) } diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index b32d4009ec89f4..8cacff51d213e0 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -2136,11 +2136,11 @@ impl<'db> TupleSpecBuilder<'db> { && suffix.len() == var.suffix_elements().len() { for (existing, new) in prefix.iter_mut().zip(var.prefix_elements()) { - *existing = IntersectionType::from_elements(db, [*existing, *new]); + *existing = IntersectionType::from_two_elements(db, *existing, *new); } - *variable = IntersectionType::from_elements(db, [*variable, var.variable()]); + *variable = IntersectionType::from_two_elements(db, *variable, var.variable()); for (existing, new) in suffix.iter_mut().zip(var.suffix_elements()) { - *existing = IntersectionType::from_elements(db, [*existing, *new]); + *existing = IntersectionType::from_two_elements(db, *existing, *new); } return Some(self); }