Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl Analysis for ExprCost {
Node::App(f, x) => {
analyzed.shared.cost_app + analyzed.nodes[*f] + analyzed.nodes[*x]
}
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
analyzed.shared.cost_lam + analyzed.nodes[*b]
}
}
Expand All @@ -91,7 +91,7 @@ impl Analysis for &ExprCost {
Node::App(f, x) => {
analyzed.shared.cost_app + analyzed.nodes[*f] + analyzed.nodes[*x]
}
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
analyzed.shared.cost_lam + analyzed.nodes[*b]
}
}
Expand All @@ -111,7 +111,7 @@ impl Analysis for DepthAnalysis {
Node::App(f, x) => {
1 + std::cmp::max(analyzed.nodes[*f], analyzed.nodes[*x])
}
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
1 + analyzed.nodes[*b]
}
}
Expand All @@ -134,10 +134,10 @@ impl Analysis for FreeVarAnalysis {
free.extend(analyzed[*f].iter());
free.extend(analyzed[*x].iter());
}
Node::Lam(b, _) => {
Node::Lam(b, count, _) => {
free.extend(analyzed[*b].iter()
.filter(|i| **i > 0)
.map(|i| i - 1)
.filter(|i| **i > count - 1)
.map(|i| i - count)
);
}
}
Expand All @@ -161,7 +161,7 @@ impl Analysis for IVarAnalysis {
free.extend(analyzed[*f].iter());
free.extend(analyzed[*x].iter());
}
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
free.extend(analyzed[*b].iter());
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ impl<'a, D: Domain> Evaluator<'a,D> {
None => panic!("Prim `{}` not found",p),
}
}
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
// tODO
Val::LamClosure(*b, env.clone())
}
};
Expand Down
42 changes: 23 additions & 19 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub enum Node where
Var(i32, Tag), // db index ($i), tag
IVar(i32), // abstraction ("invention") variable
App(Idx,Idx), // f, x
Lam(Idx, Tag), // body, tag
Lam(Idx, i32, Tag), // body, lam count, tag
}

/// An untyped lambda calculus expression or set of expressions
Expand Down Expand Up @@ -85,7 +85,7 @@ impl ExprOwned {
Node::Var(_, _) => cost_fn.cost_var,
Node::Prim(p) => *cost_fn.cost_prim.get(p).unwrap_or(&cost_fn.cost_prim_default),
Node::App(_, _) => cost_fn.cost_app,
Node::Lam(_, _) => cost_fn.cost_lam,
Node::Lam(_, _, _) => cost_fn.cost_lam,
}).sum::<i32>()
}
pub fn depth(&self) -> usize {
Expand Down Expand Up @@ -147,7 +147,7 @@ impl ExprSet {
let span = match node {
Node::Var(_, _) | Node::Prim(_) | Node::IVar(_) => idx .. idx+1,
Node::App(f, x) => min(min(spans[f].start,spans[x].start),idx) .. max(max(spans[f].end,spans[x].end),idx+1),
Node::Lam(b, _) => min(spans[b].start,idx) .. max(spans[b].end,idx+1)
Node::Lam(b, _, _) => min(spans[b].start,idx) .. max(spans[b].end,idx+1)
};
spans.push(span);
}
Expand Down Expand Up @@ -215,7 +215,7 @@ impl<'a> Expr<'a> {
match self.node() {
Node::Var(_, _) | Node::Prim(_) | Node::IVar(_) => vec![].into_iter(),
Node::App(f, x) => vec![*f, *x].into_iter(),
Node::Lam(b, _) => vec![*b].into_iter()
Node::Lam(b, _, _) => vec![*b].into_iter()
}
}
/// assuming this is an App, get the subexpression to the left
Expand All @@ -238,7 +238,7 @@ impl<'a> Expr<'a> {
#[inline(always)]
pub fn body(&self) -> Self {
match self.node() {
Node::Lam(b, _) => self.get(*b),
Node::Lam(b, _, _) => self.get(*b),
_ => panic!("get_body called on non-Lam")
}
}
Expand All @@ -260,7 +260,7 @@ impl<'a> Expr<'a> {
Node::Var(_, _) => cost_fn.cost_var,
Node::Prim(p) => *cost_fn.cost_prim.get(p).unwrap_or(&cost_fn.cost_prim_default),
Node::App(_, _) => cost_fn.cost_app,
Node::Lam(_, _) => cost_fn.cost_lam,
Node::Lam(_, _, _) => cost_fn.cost_lam,
}).sum::<i32>();
debug_assert_eq!(res, self.cost_rec(cost_fn));
res
Expand All @@ -277,7 +277,7 @@ impl<'a> Expr<'a> {
Node::App(f, x) => {
cost_fn.cost_app + self.get(*f).cost_rec(cost_fn) + self.get(*x).cost_rec(cost_fn)
}
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
cost_fn.cost_lam + self.get(*b).cost_rec(cost_fn)
}
}
Expand All @@ -293,7 +293,7 @@ impl<'a> Expr<'a> {
match node {
Node::Prim(_) | Node::Var(_, _) | Node::IVar(_) => node.clone(),
Node::App(f, x) => Node::App((*f as i32 + shift) as usize, (*x as i32 + shift) as usize),
Node::Lam(b, tag) => Node::Lam((*b as i32 + shift) as usize, *tag),
Node::Lam(b, arity, tag) => Node::Lam((*b as i32 + shift) as usize, *arity, *tag),
}
}));

Expand Down Expand Up @@ -341,9 +341,9 @@ impl<'a> Expr<'a> {
let x = helper(e.get(*x), other_set);
other_set.add(Node::App(f, x))
}
Node::Lam(b, tag) => {
Node::Lam(b, arity, tag) => {
let b = helper(e.get(*b), other_set);
other_set.add(Node::Lam(b, *tag))
other_set.add(Node::Lam(b, *arity, *tag))
}
}
}
Expand All @@ -360,7 +360,7 @@ impl<'a> Expr<'a> {
Order::ParentFirst => (*f == HOLE || *f > self.idx) && (*x == HOLE || *x > self.idx),
Order::Any => *f != self.idx && *x != self.idx,
},
Node::Lam(b, _) => match self.set.order {
Node::Lam(b, _, _) => match self.set.order {
Order::ChildFirst => *b == HOLE || *b < self.idx,
Order::ParentFirst => *b == HOLE || *b > self.idx,
Order::Any => *b != self.idx,
Expand Down Expand Up @@ -411,7 +411,7 @@ impl<'a> ExprMut<'a> {
*y = idx;
}
},
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
assert_eq!(*b, HOLE, "invalid expand() on non-hole");
*b = idx
}
Expand All @@ -428,7 +428,7 @@ impl<'a> ExprMut<'a> {
assert_eq!(*y, HOLE, "invalid expand_right() on non-hole");
*y = idx;
},
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
assert_eq!(*b, HOLE, "invalid expand_right() on non-hole");
*b = idx
}
Expand All @@ -447,7 +447,7 @@ impl<'a> ExprMut<'a> {
*x = HOLE;
}
},
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
*b = HOLE
}
_ => panic!("invalid unexpand() on non-lam non-app: {:?}", self.node())
Expand All @@ -464,7 +464,7 @@ impl<'a> ExprMut<'a> {
*y = HOLE;
}
},
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
*b = HOLE
}
_ => panic!("invalid unexpand_right() on non-lam non-app: {:?}", self.node())
Expand All @@ -490,9 +490,9 @@ impl<'a> ExprMut<'a> {
let x = self.get(x).shift(incr_by, init_depth, analyzed_free_vars);
self.set.add(Node::App(f, x))
},
Node::Lam(b, tag) => {
let b = self.get(b).shift(incr_by, init_depth+1, analyzed_free_vars);
self.set.add(Node::Lam(b, tag))
Node::Lam(b, arity, tag) => {
let b = self.get(b).shift(incr_by, init_depth+arity, analyzed_free_vars);
self.set.add(Node::Lam(b, arity, tag))
},
}
}
Expand Down Expand Up @@ -623,7 +623,7 @@ mod tests {
let app1 = e.add(Node::App(HOLE,HOLE));
let app2 = e.add(Node::App(HOLE,HOLE));
let plus = e.add(Node::Prim("+".into()));
let lam = e.add(Node::Lam(HOLE, -1));
let lam = e.add(Node::Lam(HOLE, 1, -1));
e.get_mut(app1).expand(app2);
e.get_mut(app2).expand(plus);
e.get_mut(app1).expand(lam);
Expand Down Expand Up @@ -672,6 +672,10 @@ mod tests {
assert_eq!(AnalyzedExpr::new(FreeVarAnalysis).analyze_get(e.get(idx)), &vec![0].into_iter().collect::<FxHashSet<i32>>());
assert_eq!(AnalyzedExpr::new(IVarAnalysis).analyze_get(e.get(idx)), &vec![0].into_iter().collect::<FxHashSet<i32>>());

let idx_2 = e.parse_extend("(lam (lam ($1 #0 $7 $8 $3)))").unwrap();
assert_eq!(AnalyzedExpr::new(FreeVarAnalysis).analyze_get(e.get(idx_2)), &vec![1, 5, 6].into_iter().collect::<FxHashSet<i32>>());

let idx_3 = e.parse_extend("(lam:2 ($1 #0 $7 $8 $3))").unwrap();
assert_eq!(AnalyzedExpr::new(FreeVarAnalysis).analyze_get(e.get(idx_3)), &vec![1, 5, 6].into_iter().collect::<FxHashSet<i32>>());
}
}
31 changes: 25 additions & 6 deletions src/parse_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ impl Display for Node {
},
Self::Prim(p) => write!(f,"{}",p),
Self::App(_,_) => write!(f,"app"),
Self::Lam(_, tag) => {
Self::Lam(_, arity, tag) => {
write!(f,"lam")?;
if *arity != 1 {
write!(f, ":{}", arity)?;
}
if *tag != -1 {
write!(f, "_{}", tag)?;
}
Expand Down Expand Up @@ -54,8 +57,11 @@ impl<'a> Display for Expr<'a> {
fmt_local(e.get(*x), false, f)?;
if !left_of_app { write!(f,")") } else { Ok(()) }
},
Node::Lam(b, tag) => {
Node::Lam(b, arity, tag) => {
write!(f,"(lam")?;
if *arity != 1 {
write!(f, ":{}", arity)?;
}
if *tag != -1 {
write!(f, "_{}", tag)?;
}
Expand Down Expand Up @@ -143,9 +149,22 @@ impl ExprSet {
s = &s[..start];

if item_str == "lam" || item_str == "lambda"
|| item_str.starts_with("lam_") || item_str.starts_with("lambda_") {
|| item_str.starts_with("lam_") || item_str.starts_with("lambda_")
|| item_str.starts_with("lam:") || item_str.starts_with("lambda:") {
// split on _ and parse the number
let mut tag = -1;
let mut arity = 1;
if item_str.contains(':') {
// the number after : but potentially before _, or just after : if no _
let mut split = item_str.split(':');
split.next().unwrap(); // strip "lam"
let mut arity_str = split.next().unwrap();
if arity_str.contains('_') {
let mut split = arity_str.split('_');
arity_str = split.next().unwrap();
}
arity = arity_str.parse::<i32>().map_err(|e|e.to_string())?;
}
if item_str.contains('_') {
let mut split = item_str.split('_');
split.next().unwrap(); // strip "lam"
Expand All @@ -170,7 +189,7 @@ impl ExprSet {
return Err(format!("ExprSet parse error: `lam` must always be applied to exactly one argument, like `(lam (foo bar))`: {}",s_init))
}
let b: Idx = items.pop().unwrap();
items.push(self.add(Node::Lam(b, tag)));
items.push(self.add(Node::Lam(b, arity, tag)));
// println!("added lam");
if eof {
if items.len() != 1 {
Expand Down Expand Up @@ -283,8 +302,8 @@ mod tests {
assert_eq!(set.get(e).node(), &Node::Var(23, 0));

let e = set.parse_extend("(lam_123 (+ $0_1 $1_1))").unwrap();
let (_, tag) = match set.get(e).node() {
Node::Lam(body, tag) => (*body, *tag),
let (_, _, tag) = match set.get(e).node() {
Node::Lam(body, arity, tag) => (*body, *arity, *tag),
x => panic!("expected lam, got {}", x)
};
assert_eq!(tag, 123);
Expand Down
3 changes: 2 additions & 1 deletion src/slow_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ impl<'a> Expr<'a> {
ctx.unify(&f_tp, &SlowType::arrow(x_tp, return_tp.clone()))?;
Ok(return_tp.apply(ctx))
},
Node::Lam(b, _) => {
Node::Lam(b, _, _) => {
// TOOD add arity to types
let var_tp = ctx.fresh_type_var();
// todo maybe optimize by making this a vecdeque for faster insert/remove at the zero index
env.push_front(var_tp.clone());
Expand Down