diff --git a/crates/ty_python_semantic/resources/mdtest/enums.md b/crates/ty_python_semantic/resources/mdtest/enums.md index c3b1e55c53e85..8f7088c8eaf4a 100644 --- a/crates/ty_python_semantic/resources/mdtest/enums.md +++ b/crates/ty_python_semantic/resources/mdtest/enums.md @@ -100,6 +100,175 @@ class Answer(Enum): reveal_type(enum_members(Answer)) ``` +### Declared `_value_` annotation + +If a `_value_` annotation is defined on an `Enum` class, all enum member values must be compatible +with the declared type: + +```pyi +from enum import Enum + +class Color(Enum): + _value_: int + RED = 1 + GREEN = "green" # error: [invalid-assignment] + BLUE = ... + YELLOW = None # error: [invalid-assignment] + # In stub files, `[]` is not exempt from type checking (only `...` is). + PURPLE = [] # error: [invalid-assignment] +``` + +When `_value_` is annotated, `.value` and `._value_` are inferred as the declared type: + +```py +from enum import Enum + +class Color2(Enum): + _value_: int + RED = 1 + GREEN = 2 + +reveal_type(Color2.RED.value) # revealed: int +reveal_type(Color2.RED._value_) # revealed: int +``` + +### `_value_` annotation with `__init__` + +When `__init__` is defined, member values are validated by synthesizing a call to `__init__`. The +`_value_` annotation still constrains assignments to `self._value_` inside `__init__`: + +```py +from enum import Enum + +class Planet(Enum): + _value_: int + + def __init__(self, value: int, mass: float, radius: float): + self._value_ = value + + MERCURY = (1, 3.303e23, 2.4397e6) + SATURN = "saturn" # error: [invalid-assignment] + +reveal_type(Planet.MERCURY.value) # revealed: int +reveal_type(Planet.MERCURY._value_) # revealed: int +``` + +### `_value_` annotation incompatible with `__init__` + +When `_value_` and `__init__` disagree, the assignment inside `__init__` is flagged: + +```py +from enum import Enum + +class Planet(Enum): + _value_: str + + def __init__(self, value: int, mass: float, radius: float): + self._value_ = value # error: [invalid-assignment] + + MERCURY = (1, 3.303e23, 2.4397e6) + SATURN = "saturn" # error: [invalid-assignment] + +reveal_type(Planet.MERCURY.value) # revealed: str +reveal_type(Planet.MERCURY._value_) # revealed: str +``` + +### `__init__` without `_value_` annotation + +When `__init__` is defined but no explicit `_value_` annotation exists, member values are validated +against the `__init__` signature. Values that are incompatible with `__init__` are flagged: + +```py +from enum import Enum + +class Planet2(Enum): + def __init__(self, mass: float, radius: float): + self.mass = mass + self.radius = radius + + MERCURY = (3.303e23, 2.4397e6) + VENUS = (4.869e24, 6.0518e6) + INVALID = "not a planet" # error: [invalid-assignment] + +reveal_type(Planet2.MERCURY.value) # revealed: Any +reveal_type(Planet2.MERCURY._value_) # revealed: Any +``` + +### Inherited `_value_` annotation + +A `_value_` annotation on a parent enum is inherited by subclasses. Member values are validated +against the inherited annotation, and `.value` uses the declared type: + +```py +from enum import Enum + +class Base(Enum): + _value_: int + +class Child(Base): + A = 1 + B = "not an int" # error: [invalid-assignment] + +reveal_type(Child.A.value) # revealed: int +``` + +This also works through multiple levels of inheritance, where `_value_` is declared on an +intermediate class: + +```py +from enum import Enum + +class Grandparent(Enum): + pass + +class Parent(Grandparent): + _value_: int + +class Child(Parent): + A = 1 + B = "not an int" # error: [invalid-assignment] + +reveal_type(Child.A.value) # revealed: int +``` + +### Inherited `__init__` + +A custom `__init__` on a parent enum is inherited by subclasses. Member values are validated against +the inherited `__init__` signature: + +```py +from enum import Enum + +class Base(Enum): + def __init__(self, a: int, b: str): + self._value_ = a + +class Child(Base): + A = (1, "foo") + B = "should be checked against __init__" # error: [invalid-assignment] + +reveal_type(Child.A.value) # revealed: Any +``` + +This also works through multiple levels of inheritance: + +```py +from enum import Enum + +class Grandparent(Enum): + def __init__(self, a: int, b: str): + self._value_ = a + +class Parent(Grandparent): + pass + +class Child(Parent): + A = (1, "foo") + B = "bad" # error: [invalid-assignment] + +reveal_type(Child.A.value) # revealed: Any +``` + ### Non-member attributes with disallowed type Methods, callables, descriptors (including properties), and nested classes that are defined in the @@ -358,7 +527,8 @@ class SingleMember(StrEnum): reveal_type(SingleMember.SINGLE.value) # revealed: Literal["single"] ``` -Using `auto()` with `IntEnum` also works as expected: +Using `auto()` with `IntEnum` also works as expected. `IntEnum` declares `_value_: int` in typeshed, +so `.value` is typed as `int` rather than a precise literal: ```py from enum import IntEnum, auto @@ -367,8 +537,8 @@ class Answer(IntEnum): YES = auto() NO = auto() -reveal_type(Answer.YES.value) # revealed: Literal[1] -reveal_type(Answer.NO.value) # revealed: Literal[2] +reveal_type(Answer.YES.value) # revealed: int +reveal_type(Answer.NO.value) # revealed: int ``` As does using `auto()` for other enums that use `int` as a mixin: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index a4d17e190a163..d621755c3e82f 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -3465,7 +3465,7 @@ impl<'db> Type<'db> { { let enum_literal = literal.as_enum().unwrap(); enum_metadata(db, enum_literal.enum_class(db)) - .and_then(|metadata| metadata.members.get(enum_literal.name(db))) + .and_then(|metadata| metadata.value_type(enum_literal.name(db))) .map_or_else(|| Place::Undefined, Place::bound) .into() } @@ -3489,10 +3489,10 @@ impl<'db> Type<'db> { { enum_metadata(db, instance.class_literal(db)) .and_then(|metadata| { - let (_, ty) = metadata.members.get_index(0)?; - Some(Place::bound(*ty)) + let (name, _) = metadata.members.get_index(0)?; + metadata.value_type(name) }) - .unwrap_or_default() + .map_or_else(Place::default, Place::bound) .into() } diff --git a/crates/ty_python_semantic/src/types/enums.rs b/crates/ty_python_semantic/src/types/enums.rs index 50b80465043b4..0cb4eb7632091 100644 --- a/crates/ty_python_semantic/src/types/enums.rs +++ b/crates/ty_python_semantic/src/types/enums.rs @@ -7,10 +7,10 @@ use crate::{ place::{ DefinedPlace, Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations, }, - semantic_index::{place_table, use_def_map}, + semantic_index::{place_table, scope::ScopeId, use_def_map}, types::{ ClassBase, ClassLiteral, DynamicType, EnumLiteralType, KnownClass, LiteralValueTypeKind, - MemberLookupPolicy, StaticClassLiteral, Type, TypeQualifiers, + MemberLookupPolicy, StaticClassLiteral, Type, TypeQualifiers, function::FunctionType, }, }; @@ -18,15 +18,43 @@ use crate::{ pub(crate) struct EnumMetadata<'db> { pub(crate) members: FxIndexMap>, pub(crate) aliases: FxHashMap, + + /// The explicit `_value_` annotation type, if declared. + pub(crate) value_annotation: Option>, + + /// The custom `__init__` function, if defined on this enum. + /// + /// When present, member values are validated by synthesizing a call to + /// `__init__` rather than by simple type assignability. + pub(crate) init_function: Option>, } impl get_size2::GetSize for EnumMetadata<'_> {} -impl EnumMetadata<'_> { +impl<'db> EnumMetadata<'db> { fn empty() -> Self { EnumMetadata { members: FxIndexMap::default(), aliases: FxHashMap::default(), + value_annotation: None, + init_function: None, + } + } + + /// Returns the type of `.value`/`._value_` for a given enum member. + /// + /// Priority: explicit `_value_` annotation, then `__init__` → `Any`, + /// then the inferred member value type. + pub(crate) fn value_type(&self, member_name: &Name) -> Option> { + if !self.members.contains_key(member_name) { + return None; + } + if let Some(annotation) = self.value_annotation { + Some(annotation) + } else if self.init_function.is_some() { + Some(Type::Dynamic(DynamicType::Any)) + } else { + self.members.get(member_name).copied() } } @@ -287,7 +315,93 @@ pub(crate) fn enum_metadata<'db>( return None; } - Some(EnumMetadata { members, aliases }) + // Look up an explicit `_value_` annotation, if present. Falls back to + // checking parent enum classes in the MRO. + let value_annotation = place_table(db, scope_id) + .symbol_id("_value_") + .and_then(|symbol_id| { + let declarations = use_def_map.end_of_scope_symbol_declarations(symbol_id); + place_from_declarations(db, declarations) + .ignore_conflicting_declarations() + .ignore_possibly_undefined() + }) + .or_else(|| inherited_value_annotation(db, class)); + + // Look up a custom `__init__`, falling back to parent enum classes. + let init_function = custom_init(db, scope_id).or_else(|| inherited_init(db, class)); + + Some(EnumMetadata { + members, + aliases, + value_annotation, + init_function, + }) +} + +/// Iterates over parent enum classes in the MRO, skipping known classes +/// (like `Enum`, `StrEnum`, etc.) that we handle specially. +fn iter_parent_enum_classes<'db>( + db: &'db dyn Db, + class: StaticClassLiteral<'db>, +) -> impl Iterator> + 'db { + class + .iter_mro(db, None) + .skip(1) + .filter_map(ClassBase::into_class) + .filter_map(move |class_type| { + let base = class_type.class_literal(db).as_static()?; + (base.known(db).is_none() && is_enum_class_by_inheritance(db, base)).then_some(base) + }) +} + +/// Looks up an inherited `_value_` annotation from parent enum classes in the MRO. +fn inherited_value_annotation<'db>( + db: &'db dyn Db, + class: StaticClassLiteral<'db>, +) -> Option> { + for base_class in iter_parent_enum_classes(db, class) { + let scope_id = base_class.body_scope(db); + let use_def = use_def_map(db, scope_id); + if let Some(symbol_id) = place_table(db, scope_id).symbol_id("_value_") { + let declarations = use_def.end_of_scope_symbol_declarations(symbol_id); + if let Some(ty) = place_from_declarations(db, declarations) + .ignore_conflicting_declarations() + .ignore_possibly_undefined() + { + return Some(ty); + } + } + } + None +} + +/// Looks up an inherited `__init__` from parent enum classes in the MRO. +fn inherited_init<'db>( + db: &'db dyn Db, + class: StaticClassLiteral<'db>, +) -> Option> { + for base_class in iter_parent_enum_classes(db, class) { + if let Some(f) = custom_init(db, base_class.body_scope(db)) { + return Some(f); + } + } + None +} + +/// Returns the custom `__init__` function type if one is defined on the enum. +fn custom_init<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Option> { + let init_symbol_id = place_table(db, scope).symbol_id("__init__")?; + let init_type = place_from_declarations( + db, + use_def_map(db, scope).end_of_scope_symbol_declarations(init_symbol_id), + ) + .ignore_conflicting_declarations() + .ignore_possibly_undefined()?; + + match init_type { + Type::FunctionLiteral(f) => Some(f), + _ => None, + } } pub(crate) fn enum_member_literals<'a, 'db: 'a>( diff --git a/crates/ty_python_semantic/src/types/overrides.rs b/crates/ty_python_semantic/src/types/overrides.rs index eea86fe686354..c573af5c3dd2d 100644 --- a/crates/ty_python_semantic/src/types/overrides.rs +++ b/crates/ty_python_semantic/src/types/overrides.rs @@ -23,17 +23,20 @@ use crate::{ }, types::{ CallableType, ClassBase, ClassType, KnownClass, Parameter, Parameters, Signature, - StaticClassLiteral, Type, TypeQualifiers, + StaticClassLiteral, Type, TypeContext, TypeQualifiers, + call::CallArguments, class::{CodeGeneratorKind, FieldKind}, context::InferContext, diagnostic::{ - INVALID_DATACLASS, INVALID_EXPLICIT_OVERRIDE, INVALID_METHOD_OVERRIDE, - INVALID_NAMED_TUPLE, OVERRIDE_OF_FINAL_METHOD, OVERRIDE_OF_FINAL_VARIABLE, - report_invalid_method_override, report_overridden_final_method, - report_overridden_final_variable, + INVALID_ASSIGNMENT, INVALID_DATACLASS, INVALID_EXPLICIT_OVERRIDE, + INVALID_METHOD_OVERRIDE, INVALID_NAMED_TUPLE, OVERRIDE_OF_FINAL_METHOD, + OVERRIDE_OF_FINAL_VARIABLE, report_invalid_method_override, + report_overridden_final_method, report_overridden_final_variable, }, + enums::{EnumMetadata, enum_metadata}, function::{FunctionDecorators, FunctionType, KnownFunction}, list_members::{Member, MemberWithDefinition, all_end_of_scope_members}, + tuple::Tuple, }, }; @@ -66,15 +69,24 @@ pub(super) fn check_class<'db>(context: &InferContext<'db, '_>, class: StaticCla let class_specialized = class.identity_specialization(db); let scope = class.body_scope(db); let own_class_members: FxHashSet<_> = all_end_of_scope_members(db, scope).collect(); + let enum_info = enum_metadata(db, class.into()); for member in own_class_members { - check_class_declaration(context, configuration, class_specialized, scope, &member); + check_class_declaration( + context, + configuration, + enum_info, + class_specialized, + scope, + &member, + ); } } fn check_class_declaration<'db>( context: &InferContext<'db, '_>, configuration: OverrideRulesConfig, + enum_info: Option<&EnumMetadata<'db>>, class: ClassType<'db>, class_scope: ScopeId<'db>, member: &MemberWithDefinition<'db>, @@ -173,6 +185,68 @@ fn check_class_declaration<'db>( Some(CodeGeneratorKind::TypedDict) | None => {} } + // Check for invalid Enum member values. + if let Some(enum_info) = enum_info { + if member.name != "_value_" + && matches!( + first_reachable_definition.kind(db), + DefinitionKind::Assignment(_) + ) + { + let is_enum_member = enum_info.resolve_member(&member.name).is_some(); + if is_enum_member { + let member_value_type = member.ty; + + // TODO ideally this would be a syntactic check that only matches on literal `...` + // in the source, rather than matching on the type. But this would require storing + // additional information in `EnumMetadata`. + let is_ellipsis = matches!( + member_value_type, + Type::NominalInstance(nominal_instance) + if nominal_instance.has_known_class(db, KnownClass::EllipsisType) + ); + // `auto()` values are computed at runtime by the enum metaclass, + // so we can't validate them against _value_ or __init__ at the type level. + let is_auto = matches!( + member_value_type, + Type::NominalInstance(nominal_instance) + if nominal_instance.has_known_class(db, KnownClass::Auto) + ); + let skip_type_check = (context.in_stub() && is_ellipsis) || is_auto; + + if !skip_type_check { + if let Some(init_function) = enum_info.init_function { + check_enum_member_against_init( + context, + init_function, + instance_of_class, + member_value_type, + &member.name, + *first_reachable_definition, + ); + } else if let Some(expected_type) = enum_info.value_annotation { + if !member_value_type.is_assignable_to(db, expected_type) { + if let Some(builder) = context.report_lint( + &INVALID_ASSIGNMENT, + first_reachable_definition.focus_range(db, context.module()), + ) { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Enum member `{}` value is not assignable to expected type", + &member.name + )); + diagnostic.info(format_args!( + "Expected `{}`, got `{}`", + expected_type.display(db), + member_value_type.display(db) + )); + } + } + } + } + } + } + } + let mut subclass_overrides_superclass_declaration = false; let mut has_dynamic_superclass = false; let mut has_typeddict_in_mro = false; @@ -653,3 +727,61 @@ fn check_post_init_signature<'db>( as positional-only parameters", ); } + +/// Validates an enum member value against the enum's `__init__` signature. +/// +/// The enum metaclass unpacks tuple values as positional arguments to `__init__`, +/// and passes non-tuple values as a single argument. This function synthesizes +/// a call to `__init__` with the appropriate arguments and reports a diagnostic +/// if the call would fail. +fn check_enum_member_against_init<'db>( + context: &InferContext<'db, '_>, + init_function: FunctionType<'db>, + self_type: Type<'db>, + member_value_type: Type<'db>, + member_name: &Name, + definition: Definition<'db>, +) { + let db = context.db(); + + // The enum metaclass unpacks tuple values as positional args: + // MEMBER = (a, b, c) → __init__(self, a, b, c) + // MEMBER = x → __init__(self, x) + let args: Vec> = if let Type::NominalInstance(instance) = member_value_type { + if let Some(spec) = instance.tuple_spec(db) { + if let Tuple::Fixed(fixed) = &*spec { + fixed.all_elements().to_vec() + } else { + // Variable-length tuples: can't determine exact args, skip validation. + return; + } + } else { + vec![member_value_type] + } + } else { + vec![member_value_type] + }; + + let call_args = CallArguments::positional(args); + let call_args = call_args.with_self(Some(self_type)); + + let result = Type::FunctionLiteral(init_function) + .bindings(db) + .match_parameters(db, &call_args) + .check_types(db, &call_args, TypeContext::default(), &[]); + + if result.is_err() { + if let Some(builder) = context.report_lint( + &INVALID_ASSIGNMENT, + definition.focus_range(db, context.module()), + ) { + let mut diagnostic = builder.into_diagnostic(format_args!( + "Enum member `{member_name}` is incompatible with `__init__`", + )); + diagnostic.info(format_args!( + "Expected compatible arguments for `{}`", + Type::FunctionLiteral(init_function).display(db), + )); + } + } +}