From c52a073c308605a9657f45a8d1a34ad462d56514 Mon Sep 17 00:00:00 2001 From: Kavi Gupta Date: Wed, 29 Nov 2023 18:58:42 -0500 Subject: [PATCH] WIP multi argument lambdas --- src/analysis.rs | 14 +++++++------- src/eval.rs | 3 ++- src/expr.rs | 42 +++++++++++++++++++++++------------------- src/parse_expr.rs | 31 +++++++++++++++++++++++++------ src/slow_types.rs | 3 ++- 5 files changed, 59 insertions(+), 34 deletions(-) diff --git a/src/analysis.rs b/src/analysis.rs index 49166e8..0038df4 100644 --- a/src/analysis.rs +++ b/src/analysis.rs @@ -74,7 +74,7 @@ impl Analysis for ExprCost { Node::App(f, x) => { analyzed.shared.cost_app + analyzed.nodes[*f] + analyzed.nodes[*x] } - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { analyzed.shared.cost_lam + analyzed.nodes[*b] } } @@ -91,7 +91,7 @@ impl Analysis for &ExprCost { Node::App(f, x) => { analyzed.shared.cost_app + analyzed.nodes[*f] + analyzed.nodes[*x] } - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { analyzed.shared.cost_lam + analyzed.nodes[*b] } } @@ -111,7 +111,7 @@ impl Analysis for DepthAnalysis { Node::App(f, x) => { 1 + std::cmp::max(analyzed.nodes[*f], analyzed.nodes[*x]) } - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { 1 + analyzed.nodes[*b] } } @@ -134,10 +134,10 @@ impl Analysis for FreeVarAnalysis { free.extend(analyzed[*f].iter()); free.extend(analyzed[*x].iter()); } - Node::Lam(b, _) => { + Node::Lam(b, count, _) => { free.extend(analyzed[*b].iter() - .filter(|i| **i > 0) - .map(|i| i - 1) + .filter(|i| **i > count - 1) + .map(|i| i - count) ); } } @@ -161,7 +161,7 @@ impl Analysis for IVarAnalysis { free.extend(analyzed[*f].iter()); free.extend(analyzed[*x].iter()); } - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { free.extend(analyzed[*b].iter()); } } diff --git a/src/eval.rs b/src/eval.rs index e4f089c..5cdd520 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -229,7 +229,8 @@ impl<'a, D: Domain> Evaluator<'a,D> { None => panic!("Prim `{}` not found",p), } } - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { + // tODO Val::LamClosure(*b, env.clone()) } }; diff --git a/src/expr.rs b/src/expr.rs index 7c37ce5..4a32de1 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -26,7 +26,7 @@ pub enum Node where Var(i32, Tag), // db index ($i), tag IVar(i32), // abstraction ("invention") variable App(Idx,Idx), // f, x - Lam(Idx, Tag), // body, tag + Lam(Idx, i32, Tag), // body, lam count, tag } /// An untyped lambda calculus expression or set of expressions @@ -85,7 +85,7 @@ impl ExprOwned { Node::Var(_, _) => cost_fn.cost_var, Node::Prim(p) => *cost_fn.cost_prim.get(p).unwrap_or(&cost_fn.cost_prim_default), Node::App(_, _) => cost_fn.cost_app, - Node::Lam(_, _) => cost_fn.cost_lam, + Node::Lam(_, _, _) => cost_fn.cost_lam, }).sum::() } pub fn depth(&self) -> usize { @@ -147,7 +147,7 @@ impl ExprSet { let span = match node { Node::Var(_, _) | Node::Prim(_) | Node::IVar(_) => idx .. idx+1, Node::App(f, x) => min(min(spans[f].start,spans[x].start),idx) .. max(max(spans[f].end,spans[x].end),idx+1), - Node::Lam(b, _) => min(spans[b].start,idx) .. max(spans[b].end,idx+1) + Node::Lam(b, _, _) => min(spans[b].start,idx) .. max(spans[b].end,idx+1) }; spans.push(span); } @@ -215,7 +215,7 @@ impl<'a> Expr<'a> { match self.node() { Node::Var(_, _) | Node::Prim(_) | Node::IVar(_) => vec![].into_iter(), Node::App(f, x) => vec![*f, *x].into_iter(), - Node::Lam(b, _) => vec![*b].into_iter() + Node::Lam(b, _, _) => vec![*b].into_iter() } } /// assuming this is an App, get the subexpression to the left @@ -238,7 +238,7 @@ impl<'a> Expr<'a> { #[inline(always)] pub fn body(&self) -> Self { match self.node() { - Node::Lam(b, _) => self.get(*b), + Node::Lam(b, _, _) => self.get(*b), _ => panic!("get_body called on non-Lam") } } @@ -260,7 +260,7 @@ impl<'a> Expr<'a> { Node::Var(_, _) => cost_fn.cost_var, Node::Prim(p) => *cost_fn.cost_prim.get(p).unwrap_or(&cost_fn.cost_prim_default), Node::App(_, _) => cost_fn.cost_app, - Node::Lam(_, _) => cost_fn.cost_lam, + Node::Lam(_, _, _) => cost_fn.cost_lam, }).sum::(); debug_assert_eq!(res, self.cost_rec(cost_fn)); res @@ -277,7 +277,7 @@ impl<'a> Expr<'a> { Node::App(f, x) => { cost_fn.cost_app + self.get(*f).cost_rec(cost_fn) + self.get(*x).cost_rec(cost_fn) } - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { cost_fn.cost_lam + self.get(*b).cost_rec(cost_fn) } } @@ -293,7 +293,7 @@ impl<'a> Expr<'a> { match node { Node::Prim(_) | Node::Var(_, _) | Node::IVar(_) => node.clone(), Node::App(f, x) => Node::App((*f as i32 + shift) as usize, (*x as i32 + shift) as usize), - Node::Lam(b, tag) => Node::Lam((*b as i32 + shift) as usize, *tag), + Node::Lam(b, arity, tag) => Node::Lam((*b as i32 + shift) as usize, *arity, *tag), } })); @@ -341,9 +341,9 @@ impl<'a> Expr<'a> { let x = helper(e.get(*x), other_set); other_set.add(Node::App(f, x)) } - Node::Lam(b, tag) => { + Node::Lam(b, arity, tag) => { let b = helper(e.get(*b), other_set); - other_set.add(Node::Lam(b, *tag)) + other_set.add(Node::Lam(b, *arity, *tag)) } } } @@ -360,7 +360,7 @@ impl<'a> Expr<'a> { Order::ParentFirst => (*f == HOLE || *f > self.idx) && (*x == HOLE || *x > self.idx), Order::Any => *f != self.idx && *x != self.idx, }, - Node::Lam(b, _) => match self.set.order { + Node::Lam(b, _, _) => match self.set.order { Order::ChildFirst => *b == HOLE || *b < self.idx, Order::ParentFirst => *b == HOLE || *b > self.idx, Order::Any => *b != self.idx, @@ -411,7 +411,7 @@ impl<'a> ExprMut<'a> { *y = idx; } }, - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { assert_eq!(*b, HOLE, "invalid expand() on non-hole"); *b = idx } @@ -428,7 +428,7 @@ impl<'a> ExprMut<'a> { assert_eq!(*y, HOLE, "invalid expand_right() on non-hole"); *y = idx; }, - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { assert_eq!(*b, HOLE, "invalid expand_right() on non-hole"); *b = idx } @@ -447,7 +447,7 @@ impl<'a> ExprMut<'a> { *x = HOLE; } }, - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { *b = HOLE } _ => panic!("invalid unexpand() on non-lam non-app: {:?}", self.node()) @@ -464,7 +464,7 @@ impl<'a> ExprMut<'a> { *y = HOLE; } }, - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { *b = HOLE } _ => panic!("invalid unexpand_right() on non-lam non-app: {:?}", self.node()) @@ -490,9 +490,9 @@ impl<'a> ExprMut<'a> { let x = self.get(x).shift(incr_by, init_depth, analyzed_free_vars); self.set.add(Node::App(f, x)) }, - Node::Lam(b, tag) => { - let b = self.get(b).shift(incr_by, init_depth+1, analyzed_free_vars); - self.set.add(Node::Lam(b, tag)) + Node::Lam(b, arity, tag) => { + let b = self.get(b).shift(incr_by, init_depth+arity, analyzed_free_vars); + self.set.add(Node::Lam(b, arity, tag)) }, } } @@ -623,7 +623,7 @@ mod tests { let app1 = e.add(Node::App(HOLE,HOLE)); let app2 = e.add(Node::App(HOLE,HOLE)); let plus = e.add(Node::Prim("+".into())); - let lam = e.add(Node::Lam(HOLE, -1)); + let lam = e.add(Node::Lam(HOLE, 1, -1)); e.get_mut(app1).expand(app2); e.get_mut(app2).expand(plus); e.get_mut(app1).expand(lam); @@ -672,6 +672,10 @@ mod tests { assert_eq!(AnalyzedExpr::new(FreeVarAnalysis).analyze_get(e.get(idx)), &vec![0].into_iter().collect::>()); assert_eq!(AnalyzedExpr::new(IVarAnalysis).analyze_get(e.get(idx)), &vec![0].into_iter().collect::>()); + let idx_2 = e.parse_extend("(lam (lam ($1 #0 $7 $8 $3)))").unwrap(); + assert_eq!(AnalyzedExpr::new(FreeVarAnalysis).analyze_get(e.get(idx_2)), &vec![1, 5, 6].into_iter().collect::>()); + let idx_3 = e.parse_extend("(lam:2 ($1 #0 $7 $8 $3))").unwrap(); + assert_eq!(AnalyzedExpr::new(FreeVarAnalysis).analyze_get(e.get(idx_3)), &vec![1, 5, 6].into_iter().collect::>()); } } diff --git a/src/parse_expr.rs b/src/parse_expr.rs index c80b818..186af0d 100644 --- a/src/parse_expr.rs +++ b/src/parse_expr.rs @@ -25,8 +25,11 @@ impl Display for Node { }, Self::Prim(p) => write!(f,"{}",p), Self::App(_,_) => write!(f,"app"), - Self::Lam(_, tag) => { + Self::Lam(_, arity, tag) => { write!(f,"lam")?; + if *arity != 1 { + write!(f, ":{}", arity)?; + } if *tag != -1 { write!(f, "_{}", tag)?; } @@ -54,8 +57,11 @@ impl<'a> Display for Expr<'a> { fmt_local(e.get(*x), false, f)?; if !left_of_app { write!(f,")") } else { Ok(()) } }, - Node::Lam(b, tag) => { + Node::Lam(b, arity, tag) => { write!(f,"(lam")?; + if *arity != 1 { + write!(f, ":{}", arity)?; + } if *tag != -1 { write!(f, "_{}", tag)?; } @@ -143,9 +149,22 @@ impl ExprSet { s = &s[..start]; if item_str == "lam" || item_str == "lambda" - || item_str.starts_with("lam_") || item_str.starts_with("lambda_") { + || item_str.starts_with("lam_") || item_str.starts_with("lambda_") + || item_str.starts_with("lam:") || item_str.starts_with("lambda:") { // split on _ and parse the number let mut tag = -1; + let mut arity = 1; + if item_str.contains(':') { + // the number after : but potentially before _, or just after : if no _ + let mut split = item_str.split(':'); + split.next().unwrap(); // strip "lam" + let mut arity_str = split.next().unwrap(); + if arity_str.contains('_') { + let mut split = arity_str.split('_'); + arity_str = split.next().unwrap(); + } + arity = arity_str.parse::().map_err(|e|e.to_string())?; + } if item_str.contains('_') { let mut split = item_str.split('_'); split.next().unwrap(); // strip "lam" @@ -170,7 +189,7 @@ impl ExprSet { return Err(format!("ExprSet parse error: `lam` must always be applied to exactly one argument, like `(lam (foo bar))`: {}",s_init)) } let b: Idx = items.pop().unwrap(); - items.push(self.add(Node::Lam(b, tag))); + items.push(self.add(Node::Lam(b, arity, tag))); // println!("added lam"); if eof { if items.len() != 1 { @@ -283,8 +302,8 @@ mod tests { assert_eq!(set.get(e).node(), &Node::Var(23, 0)); let e = set.parse_extend("(lam_123 (+ $0_1 $1_1))").unwrap(); - let (_, tag) = match set.get(e).node() { - Node::Lam(body, tag) => (*body, *tag), + let (_, _, tag) = match set.get(e).node() { + Node::Lam(body, arity, tag) => (*body, *arity, *tag), x => panic!("expected lam, got {}", x) }; assert_eq!(tag, 123); diff --git a/src/slow_types.rs b/src/slow_types.rs index 1c5255e..5d074f5 100644 --- a/src/slow_types.rs +++ b/src/slow_types.rs @@ -405,7 +405,8 @@ impl<'a> Expr<'a> { ctx.unify(&f_tp, &SlowType::arrow(x_tp, return_tp.clone()))?; Ok(return_tp.apply(ctx)) }, - Node::Lam(b, _) => { + Node::Lam(b, _, _) => { + // TOOD add arity to types let var_tp = ctx.fresh_type_var(); // todo maybe optimize by making this a vecdeque for faster insert/remove at the zero index env.push_front(var_tp.clone());