Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 186 additions & 51 deletions crates/plotnik-lib/src/analyze/type_check/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ use crate::parser::ast::{
use crate::parser::cst::SyntaxKind;
use crate::query::source_map::SourceId;

/// What the user wrote after `::` in a `@capture :: Type` annotation.
///
/// This is only the parsed form of the annotation; no type is created here.
/// The code that consumes it picks the resulting type from the capture's context:
/// - [`AnnotationKind::String`] turns the capture into the string type.
/// - [`AnnotationKind::TypeName`] carries the interned name, which is either attached
///   to an existing struct/enum type or used to intern a `TypeShape::Custom` type.
#[derive(Debug, Clone, Copy)]
enum AnnotationKind {
    /// `:: string` — the capture's text is extracted as a string.
    String,
    /// `:: TypeName` — a user-chosen type name, stored as an interned symbol.
    TypeName(Symbol),
}

/// Inference context for a single pass over the AST.
pub struct InferenceVisitor<'a, 'd> {
pub ctx: &'a mut TypeContext,
Expand Down Expand Up @@ -80,29 +93,32 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
}
}

/// Named node: matches one position, bubbles up child captures.
/// Named node: matches one position, bubbles up child captures or propagates output.
fn infer_named_node(&mut self, node: &NamedNode) -> TermInfo {
let mut merged_fields: BTreeMap<Symbol, FieldInfo> = BTreeMap::new();
let mut output_children: Vec<(TextRange, TypeId)> = Vec::new();

for child in node.children() {
let child_info = self.infer_expr(&child);

if let TypeFlow::Bubble(type_id) = child_info.flow
&& let Some(fields) = self.ctx.get_struct_fields(type_id)
{
for (name, info) in fields {
// Named nodes merge fields silently (union behavior)
merged_fields.entry(*name).or_insert(*info);
match &child_info.flow {
TypeFlow::Bubble(type_id) => {
if let Some(fields) = self.ctx.get_struct_fields(*type_id) {
for (name, info) in fields {
merged_fields.entry(*name).or_insert(*info);
}
}
}
TypeFlow::Scalar(type_id) => {
if self.produces_output(*type_id) {
output_children.push((child.text_range(), *type_id));
}
}
TypeFlow::Void => {}
}
}

let flow = if merged_fields.is_empty() {
TypeFlow::Void
} else {
TypeFlow::Bubble(self.ctx.intern_struct(merged_fields))
};

let flow = self.compute_merged_flow(merged_fields, output_children, node.text_range());
TermInfo::new(Arity::One, flow)
}

Expand Down Expand Up @@ -135,7 +151,7 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
self.infer_expr(body)
}

/// Sequence: Arity aggregation and strict field merging (no duplicates).
/// Sequence: Arity aggregation, strict field merging, and output propagation.
fn infer_seq_expr(&mut self, seq: &SeqExpr) -> TermInfo {
let children: Vec<_> = seq.children().collect();

Expand All @@ -148,25 +164,27 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
};

let mut merged_fields: BTreeMap<Symbol, FieldInfo> = BTreeMap::new();
let mut output_children: Vec<(TextRange, TypeId)> = Vec::new();

for child in &children {
let child_info = self.infer_expr(child);

if let TypeFlow::Bubble(type_id) = child_info.flow {
// Clone fields to release immutable borrow on self.ctx,
// allowing mutable borrow of self for merge_seq_fields.
if let Some(fields) = self.ctx.get_struct_fields(type_id).cloned() {
self.merge_seq_fields(&mut merged_fields, &fields, child.text_range());
match &child_info.flow {
TypeFlow::Bubble(type_id) => {
if let Some(fields) = self.ctx.get_struct_fields(*type_id).cloned() {
self.merge_seq_fields(&mut merged_fields, &fields, child.text_range());
}
}
TypeFlow::Scalar(type_id) => {
if self.produces_output(*type_id) {
output_children.push((child.text_range(), *type_id));
}
}
TypeFlow::Void => {}
}
}

let flow = if merged_fields.is_empty() {
TypeFlow::Void
} else {
TypeFlow::Bubble(self.ctx.intern_struct(merged_fields))
};

let flow = self.compute_merged_flow(merged_fields, output_children, seq.text_range());
TermInfo::new(arity, flow)
}

Expand Down Expand Up @@ -286,10 +304,10 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
};
let capture_name = self.interner.intern(&name_tok.text()[1..]); // Strip @ prefix

let annotation_type = self.resolve_annotation(cap);
let annotation = self.resolve_annotation(cap);
let Some(inner) = cap.inner() else {
// Capture without inner -> creates a Node field
let type_id = annotation_type.unwrap_or(TYPE_NODE);
// Capture without inner -> creates a Node field with optional annotation
let type_id = self.annotation_to_alias(annotation, TYPE_NODE);
let field = FieldInfo::required(type_id);
return TermInfo::new(
Arity::One,
Expand All @@ -309,7 +327,7 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
if should_merge_fields {
// Named node/ref/etc with bubbling fields: capture adds a field,
// inner fields bubble up alongside.
let captured_type = self.determine_non_scope_captured_type(&inner, annotation_type);
let captured_type = self.determine_non_scope_captured_type(&inner, annotation);
let field_info = if is_optional {
FieldInfo::optional(captured_type)
} else {
Expand All @@ -334,7 +352,7 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
} else {
// All other cases: scope-creating captures, scalar flows, void flows.
// Inner becomes the captured type (if applicable).
let captured_type = self.determine_captured_type(&inner, &inner_info, annotation_type);
let captured_type = self.determine_captured_type(&inner, &inner_info, annotation);
let field_info = if is_optional {
FieldInfo::optional(captured_type)
} else {
Expand Down Expand Up @@ -369,33 +387,46 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
}

/// Determines captured type for non-scope-creating expressions.
///
/// For non-scope captures, fields bubble up alongside the capture field.
/// The annotation applies to the capture's type (usually Node or a recursive ref).
fn determine_non_scope_captured_type(
&mut self,
inner: &Expr,
annotation: Option<TypeId>,
annotation: Option<AnnotationKind>,
) -> TypeId {
if let Some(ref_type) = self.get_recursive_ref_type(inner) {
annotation.unwrap_or(ref_type)
} else {
annotation.unwrap_or(TYPE_NODE)
}
let base_type = self.get_recursive_ref_type(inner).unwrap_or(TYPE_NODE);
self.annotation_to_alias(annotation, base_type)
}

/// Resolves explicit type annotation like `@foo: string`.
fn resolve_annotation(&mut self, cap: &CapturedExpr) -> Option<TypeId> {
/// Resolves explicit type annotation like `@foo :: string` or `@foo :: TypeName`.
///
/// Returns the annotation kind without creating types - the caller decides
/// how to use the annotation based on the capture's flow.
fn resolve_annotation(&mut self, cap: &CapturedExpr) -> Option<AnnotationKind> {
cap.type_annotation().and_then(|t| {
t.name().map(|n| {
let text = n.text();
if text == "string" {
TYPE_STRING
AnnotationKind::String
} else {
let sym = self.interner.intern(text);
self.ctx.intern_type(TypeShape::Custom(sym))
AnnotationKind::TypeName(self.interner.intern(text))
}
})
})
}

/// Turns an optional annotation into a concrete `TypeId`, falling back to `base`.
///
/// - No annotation: `base` is returned unchanged.
/// - `:: string`: always `TYPE_STRING`.
/// - `:: TypeName`: interns `TypeShape::Custom(name)` in the type context.
///
/// Intended for captures that do not own a struct, where a custom name acts as an
/// alias rather than naming an existing type.
fn annotation_to_alias(&mut self, annotation: Option<AnnotationKind>, base: TypeId) -> TypeId {
    // Unannotated captures keep whatever type the caller computed.
    let Some(kind) = annotation else {
        return base;
    };

    match kind {
        AnnotationKind::String => TYPE_STRING,
        AnnotationKind::TypeName(name) => self.ctx.intern_type(TypeShape::Custom(name)),
    }
}

/// Logic for how quantifier on the inner expression affects the capture field.
/// Returns (Info, is_optional).
fn resolve_capture_inner(&mut self, inner: &Expr) -> (TermInfo, bool) {
Expand All @@ -415,34 +446,59 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
}

/// Transforms the inner flow into a specific TypeId for the field.
///
/// Handles type annotation semantics based on the flow:
/// - Void/Scalar + TypeName: creates a Node alias (current Custom behavior)
/// - Bubble + TypeName: names the struct type instead of replacing it
fn determine_captured_type(
&mut self,
inner: &Expr,
inner_info: &TermInfo,
annotation: Option<TypeId>,
annotation: Option<AnnotationKind>,
) -> TypeId {
match &inner_info.flow {
TypeFlow::Void => {
if let Some(ref_type) = self.get_recursive_ref_type(inner) {
annotation.unwrap_or(ref_type)
} else {
annotation.unwrap_or(TYPE_NODE)
}
let base_type = self.get_recursive_ref_type(inner).unwrap_or(TYPE_NODE);
self.annotation_to_alias(annotation, base_type)
}
TypeFlow::Scalar(type_id) => {
// For array types with annotation, replace the element type
// e.g., `(identifier)* @names :: string` → string[] not string
if let Some(ann) = annotation
if let Some(AnnotationKind::String) = annotation
&& let Some(TypeShape::Array { non_empty, .. }) = self.ctx.get_type(*type_id)
{
return self.ctx.intern_type(TypeShape::Array {
element: ann,
element: TYPE_STRING,
non_empty: *non_empty,
});
}
annotation.unwrap_or(*type_id)
match annotation {
Some(AnnotationKind::String) => TYPE_STRING,
Some(AnnotationKind::TypeName(name)) => {
// For enum types, name the enum instead of creating an alias
if matches!(self.ctx.get_type(*type_id), Some(TypeShape::Enum(_))) {
self.ctx.set_type_name(*type_id, name);
*type_id
} else {
self.ctx.intern_type(TypeShape::Custom(name))
}
}
None => *type_id,
}
}
TypeFlow::Bubble(type_id) => {
// Bubble flow means inner has struct fields (scope-creating capture).
// TypeName annotation should NAME the struct, not replace it with an alias.
match annotation {
Some(AnnotationKind::String) => TYPE_STRING,
Some(AnnotationKind::TypeName(name)) => {
// Register the name for this struct type
self.ctx.set_type_name(*type_id, name);
*type_id
}
None => *type_id,
}
}
TypeFlow::Bubble(type_id) => annotation.unwrap_or(*type_id),
}
}

Expand Down Expand Up @@ -649,6 +705,85 @@ impl<'a, 'd> InferenceVisitor<'a, 'd> {
}
}

/// Reports whether `type_id` is an output worth propagating to the parent.
///
/// - Enums, structs and refs always count.
/// - Arrays and optionals count only when their wrapped type is not `TYPE_NODE`
///   and itself counts (checked recursively).
/// - Node, string, void and custom types never count, nor does a `type_id` that
///   the type context does not know.
fn produces_output(&self, type_id: TypeId) -> bool {
    match self.ctx.get_type(type_id) {
        // Unknown type id: nothing to propagate.
        None => false,
        Some(TypeShape::Enum(_) | TypeShape::Struct(_) | TypeShape::Ref(_)) => true,
        // Wrappers defer to the wrapped type; a wrapped plain node is never output.
        Some(TypeShape::Array { element: wrapped, .. } | TypeShape::Optional(wrapped)) => {
            *wrapped != TYPE_NODE && self.produces_output(*wrapped)
        }
        Some(
            TypeShape::Node | TypeShape::String | TypeShape::Void | TypeShape::Custom(_),
        ) => false,
    }
}

/// Derives the resulting flow from the bubbled fields and the uncaptured outputs
/// collected from an expression's children.
///
/// Without bubbled fields:
/// - no outputs → `TypeFlow::Void`
/// - exactly one output → `TypeFlow::Scalar` carrying that output's type
/// - several outputs → an "ambiguous outputs" diagnostic on `parent_range`, then
///   `TypeFlow::Void`
///
/// With bubbled fields the result is always `TypeFlow::Bubble` of the interned
/// struct; if any outputs are also present, each one is reported first because it
/// would need a capture.
fn compute_merged_flow(
    &mut self,
    merged_fields: BTreeMap<Symbol, FieldInfo>,
    output_children: Vec<(TextRange, TypeId)>,
    parent_range: TextRange,
) -> TypeFlow {
    if merged_fields.is_empty() {
        return match output_children.as_slice() {
            [] => TypeFlow::Void,
            // A single output is simply forwarded to the parent.
            [(_, only)] => TypeFlow::Scalar(*only),
            several => {
                self.report_ambiguous_outputs(parent_range, several);
                TypeFlow::Void
            }
        };
    }

    // Fields bubble up regardless; uncaptured outputs next to them are diagnosed.
    if !output_children.is_empty() {
        self.report_uncaptured_output_with_captures(&output_children);
    }
    TypeFlow::Bubble(self.ctx.intern_struct(merged_fields))
}

/// Emits one `AmbiguousUncapturedOutputs` diagnostic spanning `parent_range`.
///
/// The message states how many child expressions (`outputs.len()`) produce
/// output without being captured; the individual output ranges are not used.
fn report_ambiguous_outputs(
    &mut self,
    parent_range: TextRange,
    outputs: &[(TextRange, TypeId)],
) {
    let message = format!(
        "{} expressions produce output without capture",
        outputs.len()
    );

    self.diag
        .report(
            self.source_id,
            DiagnosticKind::AmbiguousUncapturedOutputs,
            parent_range,
        )
        .message(message)
        .emit();
}

/// Emits an `UncapturedOutputWithCaptures` diagnostic for every entry in `outputs`,
/// each located at that entry's own text range. The type ids are ignored.
fn report_uncaptured_output_with_captures(&mut self, outputs: &[(TextRange, TypeId)]) {
    for &(span, _) in outputs {
        let kind = DiagnosticKind::UncapturedOutputWithCaptures;
        self.diag.report(self.source_id, kind, span).emit();
    }
}

fn report_unify_error(&mut self, range: TextRange, err: &UnifyError) {
let (kind, msg, hint) = match err {
UnifyError::ScalarInUntagged => (
Expand Down
Loading