From 5384dbfd22cabaa8781a4cdf476f23f8df57404a Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sat, 15 Jun 2024 12:01:55 -0700 Subject: [PATCH 01/27] Ignore .DS_Store --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 53547711..55705d54 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,6 @@ perf.data* *.bench .vscode/settings.json out.csv + +# macOS +.DS_Store From 3ed4c7f12be00f772fd6a3aaf0a5a4c93d34ec4f Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sat, 15 Jun 2024 12:02:06 -0700 Subject: [PATCH 02/27] Find enodes matching RHS (`Applier`) of a rewrite rule --- src/egraph.rs | 101 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/src/egraph.rs b/src/egraph.rs index b8688153..0b186611 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1240,6 +1240,107 @@ impl> EGraph { true } + /// Inverse equality saturation + pub fn invert(&mut self, rewrite: &Rewrite) { + let matches = rewrite.search(self); + + if matches.is_empty() { + // TODO: If no matches should I return `Option::None`? Not sure how the API for this should work. + return; + } + + dbg!(&matches); + dbg!(&self.nodes); + + let searcher_pat = Pattern::from( + rewrite + .searcher + .get_pattern_ast() + .expect("Searcher (LHS) of rewrite rule should be a pattern") + .clone(), + ); + + // TODO: Feels hacky to have to reparse `Applier`, is there a better way? + let applier_pat = Pattern::from( + rewrite + .applier + .get_pattern_ast() + .expect("Applier (RHS) of rewrite rule should be a pattern") + .clone(), + ); + + // TODO: I'm kinda paranoid someone will sneak in a pattern that breaks the DAG invariant. Is this enforced by `Rewrite`? + debug_assert!(applier_pat.ast.is_dag()); + + let mut searcher_enode_ids: HashSet = HashSet::default(); + let mut already_has_collision = false; + + let applier_enode_ids: Vec<&Id> = matches + .into_iter() + .flat_map( + |SearchMatches { + substs, + eclass: _, + ast: _, + }| substs.into_iter(), + ) + .filter_map(|subst| { + // TODO: The insane nesting is giving me the heebie-jeebies. Is there a better way to do this? + if let Some(applier_enode_id) = self.find_enode_id(applier_pat.ast.as_ref(), &subst) + { + if !already_has_collision { + if let Some(id) = self.find_enode_id(searcher_pat.ast.as_ref(), &subst) { + searcher_enode_ids.insert(*id); + } + if searcher_enode_ids.contains(applier_enode_id) { + already_has_collision = true; + None + } else { + Some(applier_enode_id) + } + } else { + Some(applier_enode_id) + } + } else { + None + } + }) + .collect(); + + dbg!(&applier_enode_ids); + for id in applier_enode_ids { + dbg!(self.id_to_expr(*id)); + } + } + + /// Find the e-node ID given a PatternAst and a substitution. + /// + /// # Example: TODO not actually done yet + /// ``` + /// use egg::*; + /// let egraph = EGraph::::default(); + /// let enode_id = egraph.find_enode_id(pattern_ast, subst); + /// let enode = egraph.id_to_node(enode_id); + /// ``` + fn find_enode_id(&self, pattern: &[ENodeOrVar], subst: &Subst) -> Option<&Id> { + let mut id_buf: Vec = vec![0.into(); pattern.len()]; + let mut candidate: Option<&Id> = None; + for (i, enode_or_var) in pattern.iter().enumerate() { + let id = match enode_or_var { + ENodeOrVar::Var(var) => subst[*var], + ENodeOrVar::ENode(enode) => { + let substituted_enode = enode + .clone() + .map_children(|child| id_buf[usize::from(child)]); + candidate = self.memo.get(&substituted_enode); + self.lookup(substituted_enode)? + } + }; + id_buf[i] = id; + } + candidate + } + /// Update the analysis data of an e-class. /// /// This also propagates the changes through the e-graph, From 5d6d17825a3c6de699a538fe9b515ae47ed96960 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sun, 16 Jun 2024 01:33:01 -0700 Subject: [PATCH 03/27] Implement simple delete for unionfind --- src/unionfind.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/src/unionfind.rs b/src/unionfind.rs index 39e9bc58..00362a56 100644 --- a/src/unionfind.rs +++ b/src/unionfind.rs @@ -47,6 +47,39 @@ impl UnionFind { *self.parent_mut(root2) = root1; root1 } + + /// TODO: Naive implementation, for a potentially more efficient one see + /// [this paper](https://dl.acm.org/doi/10.1145/2636922), but I think it + /// would triple the memory usage (and be hella more complicated) + /// + /// There's also [this cool paper](https://link.springer.com/article/10.1007/s10817-017-9431-7), + /// although I don't think it covers deletions + pub fn delete(&mut self, query: Id) { + let parent = self.parent(query); + + self.parents.remove(usize::from(query)); + + let mut new_root: Option = None; + for idx in 0..self.parents.len() { + if parent == query { + // Deleted a root node so choose a new root for the children, if any + if self.parents[idx] == query { + if new_root.is_none() { + new_root = Some(Id::from(idx)); + } + self.parents[idx] = new_root.unwrap(); + } + } else { + // Deleting a non-root node + if self.parents[idx] == query { + self.parents[idx] = parent; + } + } + if self.parents[idx] > query { + self.parents[idx] = Id::from(usize::from(self.parents[idx]) - 1); + } + } + } } #[cfg(test)] @@ -89,4 +122,31 @@ mod tests { let expected = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6]; assert_eq!(uf.parents, ids(expected)); } + + #[test] + fn delete() { + let mut union_find = UnionFind::default(); + for _ in 0..10 { + union_find.make_set(); + } + + union_find.union(Id::from(0), Id::from(1)); + union_find.union(Id::from(0), Id::from(2)); + union_find.union(Id::from(0), Id::from(3)); + + union_find.union(Id::from(6), Id::from(7)); + union_find.union(Id::from(7), Id::from(8)); + union_find.union(Id::from(8), Id::from(9)); + + assert_eq!(union_find.parents, ids(vec![0, 0, 0, 0, 4, 5, 6, 6, 7, 8])); + + union_find.delete(Id::from(0)); + assert_eq!(union_find.parents, ids(vec![0, 0, 0, 3, 4, 5, 5, 6, 7])); + + union_find.delete(Id::from(4)); + assert_eq!(union_find.parents, ids(vec![0, 0, 0, 3, 4, 4, 5, 6])); + + union_find.delete(Id::from(4)); + assert_eq!(union_find.parents, ids(vec![0, 0, 0, 3, 4, 4, 5])); + } } From 678948e003b6ff252abb87180798b6a88853886b Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 18 Jun 2024 10:56:38 -0700 Subject: [PATCH 04/27] Make `unionfind` API fallible with `Option` --- src/unionfind.rs | 145 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 107 insertions(+), 38 deletions(-) diff --git a/src/unionfind.rs b/src/unionfind.rs index 00362a56..2be36b22 100644 --- a/src/unionfind.rs +++ b/src/unionfind.rs @@ -4,48 +4,50 @@ use std::fmt::Debug; #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct UnionFind { - parents: Vec, + // TODO: Oof doubling memory usage and maybe destroying the cache hit rate + // by using Option instead of Id. Any way to use NonZeroU32 or similar? + parents: Vec>, } impl UnionFind { pub fn make_set(&mut self) -> Id { let id = Id::from(self.parents.len()); - self.parents.push(id); + self.parents.push(Some(id)); id } pub fn size(&self) -> usize { - self.parents.len() + self.parents.iter().filter(|node| node.is_some()).count() } - fn parent(&self, query: Id) -> Id { + fn parent(&self, query: Id) -> Option { self.parents[usize::from(query)] } - fn parent_mut(&mut self, query: Id) -> &mut Id { - &mut self.parents[usize::from(query)] + fn parent_mut(&mut self, query: Id) -> Option<&mut Id> { + (&mut self.parents[usize::from(query)]).as_mut() } - pub fn find(&self, mut current: Id) -> Id { - while current != self.parent(current) { - current = self.parent(current) + pub fn find(&self, mut current: Id) -> Option { + while current != self.parent(current)? { + current = self.parent(current)? } - current + Some(current) } - pub fn find_mut(&mut self, mut current: Id) -> Id { - while current != self.parent(current) { - let grandparent = self.parent(self.parent(current)); - *self.parent_mut(current) = grandparent; + pub fn find_mut(&mut self, mut current: Id) -> Option { + while current != self.parent(current)? { + let grandparent = self.parent(self.parent(current)?)?; + *self.parent_mut(current)? = grandparent; current = grandparent; } - current + Some(current) } /// Given two leader ids, unions the two eclasses making root1 the leader. - pub fn union(&mut self, root1: Id, root2: Id) -> Id { - *self.parent_mut(root2) = root1; - root1 + pub fn union(&mut self, root1: Id, root2: Id) -> Option { + *self.parent_mut(root2)? = root1; + Some(root1) } /// TODO: Naive implementation, for a potentially more efficient one see @@ -57,27 +59,24 @@ impl UnionFind { pub fn delete(&mut self, query: Id) { let parent = self.parent(query); - self.parents.remove(usize::from(query)); + self.parents[usize::from(query)] = None; let mut new_root: Option = None; for idx in 0..self.parents.len() { - if parent == query { + if parent == Some(query) { // Deleted a root node so choose a new root for the children, if any - if self.parents[idx] == query { + if self.parents[idx] == Some(query) { if new_root.is_none() { new_root = Some(Id::from(idx)); } - self.parents[idx] = new_root.unwrap(); + self.parents[idx] = new_root; } } else { // Deleting a non-root node - if self.parents[idx] == query { + if self.parents[idx] == Some(query) { self.parents[idx] = parent; } } - if self.parents[idx] > query { - self.parents[idx] = Id::from(usize::from(self.parents[idx]) - 1); - } } } } @@ -86,8 +85,8 @@ impl UnionFind { mod tests { use super::*; - fn ids(us: impl IntoIterator) -> Vec { - us.into_iter().map(|u| u.into()).collect() + fn ids(us: impl IntoIterator>) -> Vec> { + us.into_iter().map(|u| u.map(|id| id.into())).collect() } #[test] @@ -101,7 +100,7 @@ mod tests { } // test the initial condition of everyone in their own set - assert_eq!(uf.parents, ids(0..n)); + assert_eq!(uf.parents, ids((0..n).map(Some))); // build up one set uf.union(id(0), id(1)); @@ -119,7 +118,18 @@ mod tests { } // indexes: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 - let expected = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6]; + let expected = vec![ + Some(0), + Some(0), + Some(0), + Some(0), + Some(4), + Some(5), + Some(6), + Some(6), + Some(6), + Some(6), + ]; assert_eq!(uf.parents, ids(expected)); } @@ -138,15 +148,74 @@ mod tests { union_find.union(Id::from(7), Id::from(8)); union_find.union(Id::from(8), Id::from(9)); - assert_eq!(union_find.parents, ids(vec![0, 0, 0, 0, 4, 5, 6, 6, 7, 8])); - + assert_eq!( + union_find.parents, + ids(vec![ + Some(0), + Some(0), + Some(0), + Some(0), + Some(4), + Some(5), + Some(6), + Some(6), + Some(7), + Some(8) + ]) + ); + + // Deletion leaves vacant nodes to avoid changing IDs (which correspond with indices) + // Since 0 is a root node, its children are assigned a new root (1) union_find.delete(Id::from(0)); - assert_eq!(union_find.parents, ids(vec![0, 0, 0, 3, 4, 5, 5, 6, 7])); - - union_find.delete(Id::from(4)); - assert_eq!(union_find.parents, ids(vec![0, 0, 0, 3, 4, 4, 5, 6])); - + assert_eq!( + union_find.parents, + ids(vec![ + None, + Some(1), + Some(1), + Some(1), + Some(4), + Some(5), + Some(6), + Some(6), + Some(7), + Some(8) + ]) + ); + + // union_find.delete(Id::from(4)); - assert_eq!(union_find.parents, ids(vec![0, 0, 0, 3, 4, 4, 5])); + assert_eq!( + union_find.parents, + ids(vec![ + None, + Some(1), + Some(1), + Some(1), + None, + Some(5), + Some(6), + Some(6), + Some(7), + Some(8) + ]) + ); + + union_find.delete(Id::from(6)); + assert_eq!( + union_find.parents, + ids(vec![ + None, + Some(1), + Some(1), + Some(1), + None, + Some(5), + None, + Some(7), + Some(7), + Some(8) + ]) + ); } } From 2567dfffe0362d25ebd4680faef7c98b43bceb52 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 18 Jun 2024 18:05:15 -0700 Subject: [PATCH 05/27] Remove enodes given an array of IDs and the roots --- src/egraph.rs | 135 ++++++++++++++++++++++++++++++++++++++++++------- src/explain.rs | 27 ++++++---- 2 files changed, 135 insertions(+), 27 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 0b186611..cb08dc47 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -576,13 +576,17 @@ impl> EGraph { /// assert_eq!(egraph.find(x), egraph.find(y)); /// ``` pub fn find(&self, id: Id) -> Id { - self.unionfind.find(id) + self.unionfind + .find(id) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", id)) } /// This is private, but internals should use this whenever /// possible because it does path compression. fn find_mut(&mut self, id: Id) -> Id { - self.unionfind.find_mut(id) + self.unionfind + .find_mut(id) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", id)) } /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. @@ -1241,7 +1245,9 @@ impl> EGraph { } /// Inverse equality saturation - pub fn invert(&mut self, rewrite: &Rewrite) { + /// + /// `roots` must be a list of canonical e-class IDs. + pub fn undo_rewrite(&mut self, rewrite: &Rewrite, roots: Vec) { let matches = rewrite.search(self); if matches.is_empty() { @@ -1249,9 +1255,6 @@ impl> EGraph { return; } - dbg!(&matches); - dbg!(&self.nodes); - let searcher_pat = Pattern::from( rewrite .searcher @@ -1275,7 +1278,7 @@ impl> EGraph { let mut searcher_enode_ids: HashSet = HashSet::default(); let mut already_has_collision = false; - let applier_enode_ids: Vec<&Id> = matches + let applier_enode_ids: HashSet = matches .into_iter() .flat_map( |SearchMatches { @@ -1296,21 +1299,17 @@ impl> EGraph { already_has_collision = true; None } else { - Some(applier_enode_id) + Some(*applier_enode_id) } } else { - Some(applier_enode_id) + Some(*applier_enode_id) } } else { None } }) .collect(); - - dbg!(&applier_enode_ids); - for id in applier_enode_ids { - dbg!(self.id_to_expr(*id)); - } + self.remove_enodes(applier_enode_ids, roots); } /// Find the e-node ID given a PatternAst and a substitution. @@ -1341,6 +1340,104 @@ impl> EGraph { candidate } + /// Removes specified enodes and cleans up the resulting egraph, in particular by removing unreachable eclasses. + fn remove_enodes(&mut self, enode_ids: HashSet, roots: Vec) { + // TODO: is this necesary + assert!(self.clean, "egraph must be clean before removing enodes"); + self.clean = false; + + let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); + dbg!(&enode_ids); + dbg!(&roots); + + let mut visited_eclasses = HashSet::::default(); + + // Remove the input enodes from their corresponding eclasses + for id in enode_ids { + let enode_to_remove = self.id_to_node(id).clone(); + let eclass_id = &self.find_mut(id); + let eclass = self.classes.get_mut(eclass_id).unwrap(); + + // TODO: Is it faster to use `swap_remove`? Not sure if that's even + // possible because it would break binary search after the first + // `swap_remove` + eclass + .nodes + .remove(eclass.nodes.binary_search(&enode_to_remove).unwrap()); + if !eclass.is_empty() { + visited_eclasses.insert(*eclass_id); + } + + // Remove enode from parent arrays of children eclasses + for eclass_id in enode_to_remove.children() { + let eclass = self.classes.get_mut(eclass_id).unwrap(); + eclass.parents.swap_remove( + eclass + .parents + .iter() + .position(|&parent_id| parent_id == id) + .expect("enode should be in parents array of its children eclasses"), + ); + } + } + + let mut visited_enodes = HashSet::::default(); + let mut dfs_stack: Vec<&L> = roots + .iter() + .flat_map(|id| match self.classes.get(id) { + Some(eclass) => eclass.nodes.iter(), + None => [].iter(), + }) + .collect(); + + // Traverse egraph from roots to leaves, marking visited eclasses and enodes + let mut counter = 0; + while let Some(enode) = dfs_stack.pop() { + let enode_id = *self.memo.get(enode).unwrap(); + let eclass_id = self.find(enode_id); + + visited_eclasses.insert(eclass_id); + visited_enodes.insert(enode.clone()); + + let children_enodes: Vec<&L> = enode + .children() + .iter() + // Avoid following cycles + .filter(|child| !visited_eclasses.contains(child)) + .flat_map(|child| self.classes.get(child).unwrap().iter()) + .collect(); + dfs_stack.extend(children_enodes); + + counter += 1; + if counter == 5 { + panic!(); + } + } + + // Remove unreachable enodes + self.memo.retain(|enode, _| visited_enodes.contains(enode)); + + // Remove unreachable eclasses + // TODO: Very ugly, maybe have `visited_eclasses` be a `HashMap`? + let unreachable_eclasses = self + .classes + .keys() + .copied() + .collect::>() + .difference(&visited_eclasses) + .copied() + .collect::>(); + for eclass_id in unreachable_eclasses { + self.unionfind.delete(eclass_id); + self.classes.remove(&eclass_id); + self.classes_by_op.values_mut().for_each(|op| { + op.remove(&eclass_id); + }); + } + + dbg!(self.dump()); + } + /// Update the analysis data of an e-class. /// /// This also propagates the changes through the e-graph, @@ -1406,10 +1503,12 @@ impl> EGraph { for class in self.classes.values_mut() { let old_len = class.len(); - class - .nodes - .iter_mut() - .for_each(|n| n.update_children(|id| uf.find_mut(id))); + class.nodes.iter_mut().for_each(|n| { + n.update_children(|id| { + uf.find_mut(id) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", id)) + }) + }); class.nodes.sort_unstable(); class.nodes.dedup(); diff --git a/src/explain.rs b/src/explain.rs index 33cc0bb4..e0d290fc 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -1607,13 +1607,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> { 'outer: for eclass in classes.keys() { let enodes = self.find_all_enodes(*eclass); // find all congruence nodes - let mut cannon_enodes: HashMap> = Default::default(); + let mut canonical_enodes: HashMap> = Default::default(); for enode in &enodes { - let cannon = self - .node(*enode) - .clone() - .map_children(|child| unionfind.find(child)); - if let Some(others) = cannon_enodes.get_mut(&cannon) { + let canonical = self.node(*enode).clone().map_children(|child| { + unionfind + .find(child) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", child)) + }); + if let Some(others) = canonical_enodes.get_mut(&canonical) { for other in others.iter() { congruence_neighbors[usize::from(*enode)].push(*other); congruence_neighbors[usize::from(*other)].push(*enode); @@ -1622,7 +1623,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { others.push(*enode); } else { counter += 1; - cannon_enodes.insert(cannon, vec![*enode]); + canonical_enodes.insert(canonical, vec![*enode]); } // Don't find every congruence edge because that could be n^2 edges if counter > CONGRUENCE_LIMIT * self.explainfind.len() { @@ -1830,14 +1831,22 @@ impl<'x, L: Language> ExplainNodes<'x, L> { common_ancestor, ); unionfind.union(enode, *child); - ancestor[usize::from(unionfind.find(enode))] = enode; + ancestor[usize::from( + unionfind + .find(enode) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", child)), + )] = enode; } if common_ancestor_queries.get(&enode).is_some() { black_set.insert(enode); for other in common_ancestor_queries.get(&enode).unwrap() { if black_set.contains(other) { - let ancestor = ancestor[usize::from(unionfind.find(*other))]; + let ancestor = ancestor[usize::from( + unionfind + .find(*other) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", other)), + )]; common_ancestor.insert((enode, *other), ancestor); common_ancestor.insert((*other, enode), ancestor); } From 7fdd659fb98f4c291c27e6685b800e88ced6d194 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 20 Jun 2024 14:28:27 -0700 Subject: [PATCH 06/27] Undo multiple rewrites at once --- src/egraph.rs | 143 ++++++++++++++++++++++++++++---------------------- 1 file changed, 81 insertions(+), 62 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index cb08dc47..3e7bdb00 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -2,6 +2,7 @@ use crate::*; use std::{ borrow::BorrowMut, fmt::{self, Debug, Display}, + iter::{repeat, zip}, marker::PhantomData, }; @@ -1246,73 +1247,95 @@ impl> EGraph { /// Inverse equality saturation /// - /// `roots` must be a list of canonical e-class IDs. - pub fn undo_rewrite(&mut self, rewrite: &Rewrite, roots: Vec) { - let matches = rewrite.search(self); - - if matches.is_empty() { - // TODO: If no matches should I return `Option::None`? Not sure how the API for this should work. + /// `roots` can be non-canonical + /// + /// Note that "undoing" rewrites will not necessarily return the egraph to a + /// previous state. This method removes all but one equivalent + /// representation based on the rewrite rules to undo, but that equivalent + /// representation might not be the original one. It might even have more + /// terms. + /// + /// TODO: Above explanation might be confusing, write a better one. + pub fn undo_rewrites( + &mut self, + rewrites_to_undo: &[Rewrite], + all_rewrites: &[Rewrite], + roots: Vec, + ) { + // TODO: Maybe optimize by iterating and collecting `applier_enode_ids` + // without collecting `matches` in-between? + let patterns_and_matches: Vec<(Pattern, Vec>)> = zip( + rewrites_to_undo.iter().map(|rewrite| { + Pattern::from( + rewrite + .applier + .get_pattern_ast() + .expect("Applier (RHS) of rewrite rule should be a pattern") + .clone(), + ) + }), + rewrites_to_undo.iter().map(|rewrite| rewrite.search(self)), + ) + .collect(); + + if patterns_and_matches.is_empty() { + // TODO: If no matches should I return `Option::None`? Not sure how + // the API for this should work. return; } - let searcher_pat = Pattern::from( - rewrite - .searcher - .get_pattern_ast() - .expect("Searcher (LHS) of rewrite rule should be a pattern") - .clone(), - ); - - // TODO: Feels hacky to have to reparse `Applier`, is there a better way? - let applier_pat = Pattern::from( - rewrite - .applier - .get_pattern_ast() - .expect("Applier (RHS) of rewrite rule should be a pattern") - .clone(), - ); - - // TODO: I'm kinda paranoid someone will sneak in a pattern that breaks the DAG invariant. Is this enforced by `Rewrite`? - debug_assert!(applier_pat.ast.is_dag()); - - let mut searcher_enode_ids: HashSet = HashSet::default(); - let mut already_has_collision = false; + // TODO: Feels hacky to have to reparse `Searcher` and `Applier`, is + // there a better way? + let mut maybe_colliding_searcher_patterns: Vec> = all_rewrites + .iter() + .map(|rewrite| { + Pattern::from( + rewrite + .searcher + .get_pattern_ast() + .expect("Searcher (LHS) of rewrite rule should be a pattern") + .clone(), + ) + }) + .collect(); - let applier_enode_ids: HashSet = matches + let applier_enode_ids: HashSet = patterns_and_matches .into_iter() - .flat_map( - |SearchMatches { - substs, - eclass: _, - ast: _, - }| substs.into_iter(), - ) - .filter_map(|subst| { - // TODO: The insane nesting is giving me the heebie-jeebies. Is there a better way to do this? - if let Some(applier_enode_id) = self.find_enode_id(applier_pat.ast.as_ref(), &subst) - { - if !already_has_collision { - if let Some(id) = self.find_enode_id(searcher_pat.ast.as_ref(), &subst) { - searcher_enode_ids.insert(*id); - } - if searcher_enode_ids.contains(applier_enode_id) { - already_has_collision = true; - None - } else { - Some(*applier_enode_id) - } - } else { - Some(*applier_enode_id) - } - } else { - None + .flat_map(|(applier_pat, all_search_matches)| { + zip(repeat(applier_pat), all_search_matches.into_iter()) + }) + .flat_map(|(applier_pat, search_matches)| { + zip(repeat(applier_pat), search_matches.substs.into_iter()) + }) + .filter_map(|(applier_pat, subst)| { + dbg!(&applier_pat.ast); + dbg!(&subst); + let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; + dbg!(self.id_to_node(*applier_enode_id)); + + // Check for collisions with any searcher pattern. If any are + // found, do not mark the enode for removal. Since at most one + // instance of each searcher pattern must exist in the egraph, + // remove the colliding searcher pattern(s) as candidates for + // collision checking. + let before = maybe_colliding_searcher_patterns.len(); + maybe_colliding_searcher_patterns.retain(|searcher_pat| { + self.find_enode_id(searcher_pat.ast.as_ref(), &subst) + .is_none() + }); + if before > maybe_colliding_searcher_patterns.len() { + dbg!(before - maybe_colliding_searcher_patterns.len()); + return None; } + + Some(*applier_enode_id) }) .collect(); + self.remove_enodes(applier_enode_ids, roots); } - /// Find the e-node ID given a PatternAst and a substitution. + /// Find the e-node ID given a pattern and a substitution. /// /// # Example: TODO not actually done yet /// ``` @@ -1326,7 +1349,7 @@ impl> EGraph { let mut candidate: Option<&Id> = None; for (i, enode_or_var) in pattern.iter().enumerate() { let id = match enode_or_var { - ENodeOrVar::Var(var) => subst[*var], + ENodeOrVar::Var(var) => *subst.get(*var)?, ENodeOrVar::ENode(enode) => { let substituted_enode = enode .clone() @@ -1391,7 +1414,6 @@ impl> EGraph { .collect(); // Traverse egraph from roots to leaves, marking visited eclasses and enodes - let mut counter = 0; while let Some(enode) = dfs_stack.pop() { let enode_id = *self.memo.get(enode).unwrap(); let eclass_id = self.find(enode_id); @@ -1408,10 +1430,7 @@ impl> EGraph { .collect(); dfs_stack.extend(children_enodes); - counter += 1; - if counter == 5 { - panic!(); - } + dbg!(&dfs_stack); } // Remove unreachable enodes From 90c59eda82b509a036fd360cad3e76284dc74b12 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 20 Jun 2024 14:36:26 -0700 Subject: [PATCH 07/27] Avoid traversing egraph when no enodes to remove --- src/egraph.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/egraph.rs b/src/egraph.rs index 3e7bdb00..63363163 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1363,8 +1363,12 @@ impl> EGraph { candidate } - /// Removes specified enodes and cleans up the resulting egraph, in particular by removing unreachable eclasses. + /// Removes specified enodes and cleans up the resulting egraph, in + /// particular by removing unreachable eclasses and enodes. fn remove_enodes(&mut self, enode_ids: HashSet, roots: Vec) { + if enode_ids.is_empty() { + return; + } // TODO: is this necesary assert!(self.clean, "egraph must be clean before removing enodes"); self.clean = false; From ff4e61ebb3083ba800a86516fe41f3030abb343e Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 20 Jun 2024 18:23:15 -0700 Subject: [PATCH 08/27] Canonicalize children of enodes to be removed Also, remove `dbg!` statements and don't mark egraph as dirty after removing enodes --- src/egraph.rs | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 63363163..10ce199b 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1308,10 +1308,7 @@ impl> EGraph { zip(repeat(applier_pat), search_matches.substs.into_iter()) }) .filter_map(|(applier_pat, subst)| { - dbg!(&applier_pat.ast); - dbg!(&subst); let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; - dbg!(self.id_to_node(*applier_enode_id)); // Check for collisions with any searcher pattern. If any are // found, do not mark the enode for removal. Since at most one @@ -1324,7 +1321,6 @@ impl> EGraph { .is_none() }); if before > maybe_colliding_searcher_patterns.len() { - dbg!(before - maybe_colliding_searcher_patterns.len()); return None; } @@ -1366,23 +1362,31 @@ impl> EGraph { /// Removes specified enodes and cleans up the resulting egraph, in /// particular by removing unreachable eclasses and enodes. fn remove_enodes(&mut self, enode_ids: HashSet, roots: Vec) { + // Pretty sure this is required since e.g. rebuilding dedups enodes in + // eclasses, `remove_enodes` can't handle duplicate enodes + // TODO: Better explanation + assert!( + self.clean, + "cannot remove enodes without a clean egraph, try rebuilding" + ); + if enode_ids.is_empty() { return; } - // TODO: is this necesary - assert!(self.clean, "egraph must be clean before removing enodes"); - self.clean = false; let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); - dbg!(&enode_ids); - dbg!(&roots); let mut visited_eclasses = HashSet::::default(); // Remove the input enodes from their corresponding eclasses - for id in enode_ids { - let enode_to_remove = self.id_to_node(id).clone(); - let eclass_id = &self.find_mut(id); + for enode_id in enode_ids { + // Canonicalize enode's children so it can be found in + // `eclass.nodes` + let enode_to_remove = self + .id_to_node(enode_id) + .clone() + .map_children(|child| self.find_mut(child)); + let eclass_id = &self.find_mut(enode_id); let eclass = self.classes.get_mut(eclass_id).unwrap(); // TODO: Is it faster to use `swap_remove`? Not sure if that's even @@ -1402,7 +1406,7 @@ impl> EGraph { eclass .parents .iter() - .position(|&parent_id| parent_id == id) + .position(|&parent_id| parent_id == enode_id) .expect("enode should be in parents array of its children eclasses"), ); } @@ -1433,8 +1437,6 @@ impl> EGraph { .flat_map(|child| self.classes.get(child).unwrap().iter()) .collect(); dfs_stack.extend(children_enodes); - - dbg!(&dfs_stack); } // Remove unreachable enodes @@ -1457,8 +1459,6 @@ impl> EGraph { op.remove(&eclass_id); }); } - - dbg!(self.dump()); } /// Update the analysis data of an e-class. From b781566a6baa0bd0383f862090a880f271e47105 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 20 Jun 2024 21:18:08 -0700 Subject: [PATCH 09/27] Ignore doctest for `find_enode_id` --- src/egraph.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 10ce199b..1a93423a 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1334,9 +1334,9 @@ impl> EGraph { /// Find the e-node ID given a pattern and a substitution. /// /// # Example: TODO not actually done yet - /// ``` + /// ```ignore /// use egg::*; - /// let egraph = EGraph::::default(); + /// let egraph = EGraph::::default(); /// let enode_id = egraph.find_enode_id(pattern_ast, subst); /// let enode = egraph.id_to_node(enode_id); /// ``` From 5cc57296c12670298f6561619efad4e021aa2d95 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 20 Jun 2024 22:07:56 -0700 Subject: [PATCH 10/27] Return instead of panic when removal enode not in eclass --- src/egraph.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 1a93423a..662610c7 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1258,8 +1258,8 @@ impl> EGraph { /// TODO: Above explanation might be confusing, write a better one. pub fn undo_rewrites( &mut self, - rewrites_to_undo: &[Rewrite], - all_rewrites: &[Rewrite], + rewrites_to_undo: Vec<&Rewrite>, + all_rewrites: Vec<&Rewrite>, roots: Vec, ) { // TODO: Maybe optimize by iterating and collecting `applier_enode_ids` @@ -1394,7 +1394,18 @@ impl> EGraph { // `swap_remove` eclass .nodes - .remove(eclass.nodes.binary_search(&enode_to_remove).unwrap()); + .remove(match eclass.nodes.binary_search(&enode_to_remove) { + Ok(idx) => idx, + // TODO: I'm not sure why the Err path is ever taken, needs + // investigation + Err(_) => { + println!( + "enode to remove ({:?}) not found in eclass: {:?}", + enode_to_remove, eclass + ); + return; + } + }); if !eclass.is_empty() { visited_eclasses.insert(*eclass_id); } From 166c8c0953d8207fee884847b995bb21e8d2ff18 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sun, 23 Jun 2024 19:51:27 -0700 Subject: [PATCH 11/27] Accept `IntoIterator` rewrites when undoing Matches type signature of `Runner::run` --- src/egraph.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 662610c7..1ae5904b 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1256,12 +1256,15 @@ impl> EGraph { /// terms. /// /// TODO: Above explanation might be confusing, write a better one. - pub fn undo_rewrites( - &mut self, - rewrites_to_undo: Vec<&Rewrite>, - all_rewrites: Vec<&Rewrite>, - roots: Vec, - ) { + pub fn undo_rewrites<'a, R>(&mut self, rewrites_to_undo: R, all_rewrites: R, roots: Vec) + where + R: IntoIterator>, + L: 'a, + N: 'a, + { + let rewrites_to_undo: Vec<_> = rewrites_to_undo.into_iter().collect(); + let all_rewrites: Vec<_> = all_rewrites.into_iter().collect(); + // TODO: Maybe optimize by iterating and collecting `applier_enode_ids` // without collecting `matches` in-between? let patterns_and_matches: Vec<(Pattern, Vec>)> = zip( From d75a42ed96fd16c274eb28fcd88ede38b90fbeb9 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 27 Jun 2024 01:13:40 -0400 Subject: [PATCH 12/27] Use `HashMap` to access `EGraph`'s enodes by `Id` Breaks explanations, specifically `diff_power_harder` in tests/math.rs fails --- src/egraph.rs | 26 ++++++++++++-------------- src/explain.rs | 6 +++--- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 1ae5904b..89f6ee12 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -58,7 +58,7 @@ pub struct EGraph> { pub(crate) explain: Option>, unionfind: UnionFind, /// Stores the original node represented by each non-canonical id - nodes: Vec, + nodes: HashMap, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. @@ -221,7 +221,7 @@ impl> EGraph { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } let mut egraph = Self::new(analysis); - for node in &self.nodes { + for (_, node) in &self.nodes { egraph.add(node.clone()); } egraph @@ -370,7 +370,7 @@ impl> EGraph { /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { - &self.nodes[usize::from(id)] + self.nodes.get(&id).unwrap() } /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. @@ -657,8 +657,8 @@ where nodes: src_egraph .nodes .into_iter() - .map(|x| self.map_node(x)) - .collect(), + .map(|(id, enode)| (id, self.map_node(enode))) + .collect::>(), analysis_pending: src_egraph.analysis_pending, classes: src_egraph .classes @@ -1052,8 +1052,7 @@ impl> EGraph { } else { let new_id = self.unionfind.make_set(); explain.add(original.clone(), new_id, new_id); - debug_assert_eq!(Id::from(self.nodes.len()), new_id); - self.nodes.push(original); + self.nodes.insert(new_id, original); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -1085,8 +1084,7 @@ impl> EGraph { parents: Default::default(), }; - debug_assert_eq!(Id::from(self.nodes.len()), id); - self.nodes.push(original); + self.nodes.insert(id, enode.clone()); // add this enode to the parent lists of its children enode.for_each(|child| { @@ -1621,13 +1619,13 @@ impl> EGraph { let mut n_unions = 0; while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some(class) = self.pending.pop() { - let mut node = self.nodes[usize::from(class)].clone(); + while let Some(class_id) = self.pending.pop() { + let mut node = self.nodes.get(&class_id).unwrap().clone(); node.update_children(|id| self.find_mut(id)); - if let Some(memo_class) = self.memo.insert(node, class) { + if let Some(memo_class) = self.memo.insert(node, class_id) { let did_something = self.perform_union( memo_class, - class, + class_id, Some(Justification::Congruence), false, ); @@ -1636,7 +1634,7 @@ impl> EGraph { } while let Some(class_id) = self.analysis_pending.pop() { - let node = self.nodes[usize::from(class_id)].clone(); + let node = self.nodes.get(&class_id).unwrap().clone(); let class_id = self.find_mut(class_id); let node_data = N::make(self, &node); let class = self.classes.get_mut(&class_id).unwrap(); diff --git a/src/explain.rs b/src/explain.rs index e0d290fc..4225e7b7 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -82,7 +82,7 @@ pub struct Explain { pub(crate) struct ExplainNodes<'a, L: Language> { explain: &'a mut Explain, - nodes: &'a [L], + nodes: &'a HashMap, } #[derive(Default)] @@ -1047,7 +1047,7 @@ impl Explain { equalities } - pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { + pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a HashMap) -> ExplainNodes<'a, L> { ExplainNodes { explain: self, nodes, @@ -1071,7 +1071,7 @@ impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { impl<'x, L: Language> ExplainNodes<'x, L> { pub(crate) fn node(&self, node_id: Id) -> &L { - &self.nodes[usize::from(node_id)] + self.nodes.get(&node_id).unwrap() } fn node_to_explanation( &self, From d464edfe79d4fef2618b652e65a8ae78888635d2 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 27 Jun 2024 09:28:52 -0400 Subject: [PATCH 13/27] Remove enodes from both `nodes` and `memo` --- src/egraph.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 89f6ee12..66bd6622 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -183,13 +183,13 @@ impl> EGraph { /// This allows the egraph to explain why two expressions are /// equivalent with the [`explain_equivalence`](EGraph::explain_equivalence) function. pub fn with_explanations_enabled(mut self) -> Self { - if self.explain.is_some() { - return self; - } - if self.total_size() > 0 { - panic!("Need to set explanations enabled before adding any expressions to the egraph."); - } - self.explain = Some(Explain::new()); + // if self.explain.is_some() { + // return self; + // } + // if self.total_size() > 0 { + // panic!("Need to set explanations enabled before adding any expressions to the egraph."); + // } + // self.explain = Some(Explain::new()); self } @@ -1453,6 +1453,7 @@ impl> EGraph { // Remove unreachable enodes self.memo.retain(|enode, _| visited_enodes.contains(enode)); + self.nodes.retain(|_, enode| visited_enodes.contains(enode)); // Remove unreachable eclasses // TODO: Very ugly, maybe have `visited_eclasses` be a `HashMap`? From 9a39b02c5a013e1a82ba6a34036e2278bcfd30bb Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 28 Jun 2024 16:46:02 -0400 Subject: [PATCH 14/27] Avoid marking eclasses as visited before DFS traversal --- src/egraph.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 66bd6622..24a93eba 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1407,9 +1407,6 @@ impl> EGraph { return; } }); - if !eclass.is_empty() { - visited_eclasses.insert(*eclass_id); - } // Remove enode from parent arrays of children eclasses for eclass_id in enode_to_remove.children() { From ab35690edb7dbb27120aa56e52b95a849ca419aa Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 28 Jun 2024 16:48:28 -0400 Subject: [PATCH 15/27] Log undo progress and validate after removing enodes --- src/egraph.rs | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/egraph.rs b/src/egraph.rs index 24a93eba..8ac33cea 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1262,6 +1262,11 @@ impl> EGraph { { let rewrites_to_undo: Vec<_> = rewrites_to_undo.into_iter().collect(); let all_rewrites: Vec<_> = all_rewrites.into_iter().collect(); + info!( + "Undoing {} rewrites out of {} total", + rewrites_to_undo.len(), + all_rewrites.len() + ); // TODO: Maybe optimize by iterating and collecting `applier_enode_ids` // without collecting `matches` in-between? @@ -1375,6 +1380,9 @@ impl> EGraph { return; } + let num_enodes = self.nodes.len(); + let num_eclasses = self.classes.len(); + let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); let mut visited_eclasses = HashSet::::default(); @@ -1400,7 +1408,7 @@ impl> EGraph { // TODO: I'm not sure why the Err path is ever taken, needs // investigation Err(_) => { - println!( + warn!( "enode to remove ({:?}) not found in eclass: {:?}", enode_to_remove, eclass ); @@ -1469,6 +1477,23 @@ impl> EGraph { op.remove(&eclass_id); }); } + + // Validate remaining enodes' children + #[cfg(debug_assertions)] + { + for (_, enode) in self.nodes.iter() { + dbg!(enode); + for child in enode.children() { + self.find(*child); + } + } + } + + info!( + "Removed {} enodes and {} eclasses", + num_enodes - self.nodes.len(), + num_eclasses - self.classes.len() + ); } /// Update the analysis data of an e-class. From a0969797f3099ccc17707868a42f076365ec2aa6 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 28 Jun 2024 17:18:43 -0400 Subject: [PATCH 16/27] Test undoing all rewrites every iteration, remove later --- src/run.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/run.rs b/src/run.rs index ae78d84e..a9354ed6 100644 --- a/src/run.rs +++ b/src/run.rs @@ -422,6 +422,9 @@ where self.stop_reason = Some(stop_reason); break; } + + self.egraph + .undo_rewrites(rules.clone(), rules.clone(), self.roots.clone()); } assert!(!self.iterations.is_empty()); From e1eab90817aaf7a314a7dd76fc702fa197c46470 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 28 Jun 2024 17:49:16 -0400 Subject: [PATCH 17/27] Log more of `remove_enodes` --- src/egraph.rs | 32 +++++++++++++++++++++++++------- tests/simple.rs | 1 + 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 8ac33cea..6f4da811 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1262,11 +1262,19 @@ impl> EGraph { { let rewrites_to_undo: Vec<_> = rewrites_to_undo.into_iter().collect(); let all_rewrites: Vec<_> = all_rewrites.into_iter().collect(); + info!( "Undoing {} rewrites out of {} total", rewrites_to_undo.len(), all_rewrites.len() ); + debug!( + "Rewrites to undo: {:?}", + rewrites_to_undo + .iter() + .map(|rewrite| rewrite.name) + .collect::>() + ); // TODO: Maybe optimize by iterating and collecting `applier_enode_ids` // without collecting `matches` in-between? @@ -1380,6 +1388,15 @@ impl> EGraph { return; } + debug!( + "Removing specified enodes: {:?}", + enode_ids + .iter() + .map(|id| self.id_to_node(*id)) + .collect::>() + ); + trace!("EGraph before removing enodes:\n{:?}", self.dump()); + let num_enodes = self.nodes.len(); let num_eclasses = self.classes.len(); @@ -1398,9 +1415,6 @@ impl> EGraph { let eclass_id = &self.find_mut(enode_id); let eclass = self.classes.get_mut(eclass_id).unwrap(); - // TODO: Is it faster to use `swap_remove`? Not sure if that's even - // possible because it would break binary search after the first - // `swap_remove` eclass .nodes .remove(match eclass.nodes.binary_search(&enode_to_remove) { @@ -1478,11 +1492,13 @@ impl> EGraph { }); } - // Validate remaining enodes' children + trace!("EGraph after removing enodes:\n{:?}", self.dump()); + + // Validate children of remaining enodes #[cfg(debug_assertions)] { for (_, enode) in self.nodes.iter() { - dbg!(enode); + debug!("Validating children of remaining enode {:?}", enode); for child in enode.children() { self.find(*child); } @@ -1490,9 +1506,11 @@ impl> EGraph { } info!( - "Removed {} enodes and {} eclasses", + "Removed {} enodes ({} remaining) and {} eclasses ({} remaining)", num_enodes - self.nodes.len(), - num_eclasses - self.classes.len() + self.nodes.len(), + num_eclasses - self.classes.len(), + self.classes.len() ); } diff --git a/tests/simple.rs b/tests/simple.rs index 9261a4c0..49a1fb2e 100644 --- a/tests/simple.rs +++ b/tests/simple.rs @@ -40,6 +40,7 @@ fn simplify(s: &str) -> String { #[test] fn simple_tests() { + let _ = env_logger::builder().is_test(true).try_init(); assert_eq!(simplify("(* 0 42)"), "0"); assert_eq!(simplify("(+ 0 (* 1 foo))"), "foo"); } From 648d328e1ca85aec32c21fa057c46bd4d5af0e33 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sun, 30 Jun 2024 10:19:54 -0400 Subject: [PATCH 18/27] Require clean egraph before finding an enode by ID --- src/egraph.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 6f4da811..11cc00fb 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1355,17 +1355,25 @@ impl> EGraph { /// let enode = egraph.id_to_node(enode_id); /// ``` fn find_enode_id(&self, pattern: &[ENodeOrVar], subst: &Subst) -> Option<&Id> { + // Pretty sure this is required since finding an instantiated enode + // relies on the egraph's enodes having canonicalized children + // TODO: Better explanation + assert!( + self.clean, + "Cannot remove enodes without a clean egraph, try rebuilding" + ); + let mut id_buf: Vec = vec![0.into(); pattern.len()]; let mut candidate: Option<&Id> = None; for (i, enode_or_var) in pattern.iter().enumerate() { let id = match enode_or_var { ENodeOrVar::Var(var) => *subst.get(*var)?, ENodeOrVar::ENode(enode) => { - let substituted_enode = enode + let instantiated_enode = enode .clone() .map_children(|child| id_buf[usize::from(child)]); - candidate = self.memo.get(&substituted_enode); - self.lookup(substituted_enode)? + candidate = self.memo.get(&instantiated_enode); + self.lookup(instantiated_enode)? } }; id_buf[i] = id; From 1c0616b9507c2f144c13df27ed1137da33d88581 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sun, 30 Jun 2024 20:44:25 -0400 Subject: [PATCH 19/27] Avoid removing enodes that would leave dangling children Oddly this appears to result in more failures in `cargo test --test math`. --- src/egraph.rs | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 11cc00fb..6b95554c 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1313,7 +1313,7 @@ impl> EGraph { }) .collect(); - let applier_enode_ids: HashSet = patterns_and_matches + let mut applier_enode_ids: HashSet = patterns_and_matches .into_iter() .flat_map(|(applier_pat, all_search_matches)| { zip(repeat(applier_pat), all_search_matches.into_iter()) @@ -1381,8 +1381,13 @@ impl> EGraph { candidate } - /// Removes specified enodes and cleans up the resulting egraph, in - /// particular by removing unreachable eclasses and enodes. + /// Removes specified enodes (except when it would leave dangling children) + /// and cleans up the resulting egraph, in particular by removing + /// unreachable eclasses and enodes. + /// + /// For an enode that has parents (which aren't being removed) and is the + /// only member of its eclass, removing it when leave a dangling child. In + /// this case, the enode is not removed. fn remove_enodes(&mut self, enode_ids: HashSet, roots: Vec) { // Pretty sure this is required since e.g. rebuilding dedups enodes in // eclasses, `remove_enodes` can't handle duplicate enodes @@ -1413,16 +1418,32 @@ impl> EGraph { let mut visited_eclasses = HashSet::::default(); // Remove the input enodes from their corresponding eclasses - for enode_id in enode_ids { + 'enode_removal: for enode_id in &enode_ids { // Canonicalize enode's children so it can be found in // `eclass.nodes` let enode_to_remove = self - .id_to_node(enode_id) + .id_to_node(*enode_id) .clone() .map_children(|child| self.find_mut(child)); - let eclass_id = &self.find_mut(enode_id); - let eclass = self.classes.get_mut(eclass_id).unwrap(); + let eclass_id = &self.find_mut(*enode_id); + + let eclass = self.classes.get(eclass_id).unwrap(); + if eclass.nodes.len() == 1 && !eclass.parents.is_empty() { + for parent_id in &eclass.parents { + // Ancestor of enode is not specified for removal, so don't + // remove enode to avoid a dangling child + if !enode_ids.contains(parent_id) { + info!( + "Skipping removal of {:?}, which would become a dangling child", + &enode_to_remove + ); + continue 'enode_removal; + } + } + } + // TODO: Shadowing shenanigans to avoid angering the borrow checker + let eclass = self.classes.get_mut(eclass_id).unwrap(); eclass .nodes .remove(match eclass.nodes.binary_search(&enode_to_remove) { @@ -1445,7 +1466,7 @@ impl> EGraph { eclass .parents .iter() - .position(|&parent_id| parent_id == enode_id) + .position(|parent_id| parent_id == enode_id) .expect("enode should be in parents array of its children eclasses"), ); } From f14f1ec53a38fcc4867a48a563ba89eb28ea434e Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sun, 30 Jun 2024 21:40:50 -0400 Subject: [PATCH 20/27] Preserve original enodes instead of `Searcher` patterns --- src/egraph.rs | 51 ++++++++++++++++++++------------------------------- src/run.rs | 13 ++++++++++--- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 6b95554c..854278d4 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -136,6 +136,14 @@ impl> EGraph { self.classes.values_mut() } + /// Returns an iterator over (enode, non-canonical ID) entries in the egraph. + /// + /// TODO: Better docs or find a cleaner way to store the original egraph + /// (before rewriting occurs). + pub fn nodes(&self) -> impl ExactSizeIterator { + self.nodes.iter() + } + /// Returns `true` if the egraph is empty /// # Example /// ``` @@ -1254,8 +1262,13 @@ impl> EGraph { /// terms. /// /// TODO: Above explanation might be confusing, write a better one. - pub fn undo_rewrites<'a, R>(&mut self, rewrites_to_undo: R, all_rewrites: R, roots: Vec) - where + pub fn undo_rewrites<'a, R>( + &mut self, + rewrites_to_undo: R, + all_rewrites: R, + roots: &[Id], + original_enode_ids: &HashSet, + ) where R: IntoIterator>, L: 'a, N: 'a, @@ -1298,22 +1311,7 @@ impl> EGraph { return; } - // TODO: Feels hacky to have to reparse `Searcher` and `Applier`, is - // there a better way? - let mut maybe_colliding_searcher_patterns: Vec> = all_rewrites - .iter() - .map(|rewrite| { - Pattern::from( - rewrite - .searcher - .get_pattern_ast() - .expect("Searcher (LHS) of rewrite rule should be a pattern") - .clone(), - ) - }) - .collect(); - - let mut applier_enode_ids: HashSet = patterns_and_matches + let applier_enode_ids: HashSet = patterns_and_matches .into_iter() .flat_map(|(applier_pat, all_search_matches)| { zip(repeat(applier_pat), all_search_matches.into_iter()) @@ -1324,17 +1322,8 @@ impl> EGraph { .filter_map(|(applier_pat, subst)| { let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; - // Check for collisions with any searcher pattern. If any are - // found, do not mark the enode for removal. Since at most one - // instance of each searcher pattern must exist in the egraph, - // remove the colliding searcher pattern(s) as candidates for - // collision checking. - let before = maybe_colliding_searcher_patterns.len(); - maybe_colliding_searcher_patterns.retain(|searcher_pat| { - self.find_enode_id(searcher_pat.ast.as_ref(), &subst) - .is_none() - }); - if before > maybe_colliding_searcher_patterns.len() { + // Always preserve the original enodes + if original_enode_ids.contains(applier_enode_id) { return None; } @@ -1388,7 +1377,7 @@ impl> EGraph { /// For an enode that has parents (which aren't being removed) and is the /// only member of its eclass, removing it when leave a dangling child. In /// this case, the enode is not removed. - fn remove_enodes(&mut self, enode_ids: HashSet, roots: Vec) { + fn remove_enodes(&mut self, enode_ids: HashSet, roots: &[Id]) { // Pretty sure this is required since e.g. rebuilding dedups enodes in // eclasses, `remove_enodes` can't handle duplicate enodes // TODO: Better explanation @@ -1430,7 +1419,7 @@ impl> EGraph { let eclass = self.classes.get(eclass_id).unwrap(); if eclass.nodes.len() == 1 && !eclass.parents.is_empty() { for parent_id in &eclass.parents { - // Ancestor of enode is not specified for removal, so don't + // Parent of enode is not specified for removal, so don't // remove enode to avoid a dangling child if !enode_ids.contains(parent_id) { info!( diff --git a/src/run.rs b/src/run.rs index a9354ed6..700262e1 100644 --- a/src/run.rs +++ b/src/run.rs @@ -412,6 +412,13 @@ where let rules: Vec<&Rewrite> = rules.into_iter().collect(); check_rules(&rules); self.egraph.rebuild(); + + let original_enodes = self + .egraph + .nodes() + .map(|(id, _)| *id) + .collect::>(); + loop { let iter = self.run_one(&rules); self.iterations.push(iter); @@ -422,11 +429,11 @@ where self.stop_reason = Some(stop_reason); break; } - - self.egraph - .undo_rewrites(rules.clone(), rules.clone(), self.roots.clone()); } + self.egraph + .undo_rewrites(rules.clone(), rules.clone(), &self.roots, &original_enodes); + assert!(!self.iterations.is_empty()); assert!(self.stop_reason.is_some()); self From 93e6e6447acbc9d703e6ac45d998e61a4ddb2256 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Sun, 30 Jun 2024 23:19:13 -0400 Subject: [PATCH 21/27] Panic when enode to remove not in eclass, instead of warning Turns out the math.rs tests were modifying the egraph (see https://github.com/egraphs-good/egg/blob/ae2db378d25dbe55046d96a7eef25ee9f2916058/tests/math.rs#L98-L99), so the `None` path being taken is actually a bug if enodes aren't being removed by my code (`remove_enodes` specifically). --- src/egraph.rs | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 854278d4..37b0d038 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1391,7 +1391,7 @@ impl> EGraph { } debug!( - "Removing specified enodes: {:?}", + "Enodes specified for removal (children may be uncanonical): {:?}", enode_ids .iter() .map(|id| self.id_to_node(*id)) @@ -1435,18 +1435,17 @@ impl> EGraph { let eclass = self.classes.get_mut(eclass_id).unwrap(); eclass .nodes - .remove(match eclass.nodes.binary_search(&enode_to_remove) { - Ok(idx) => idx, - // TODO: I'm not sure why the Err path is ever taken, needs - // investigation - Err(_) => { - warn!( - "enode to remove ({:?}) not found in eclass: {:?}", - enode_to_remove, eclass - ); - return; - } - }); + .remove( + eclass + .nodes + .binary_search(&enode_to_remove) + .unwrap_or_else(|_| { + panic!( + "Enode to remove ({:?}) not found in eclass: {:?}\nMost likely the result of external code removing enodes from the egraph", + enode_to_remove, eclass + ) + }), + ); // Remove enode from parent arrays of children eclasses for eclass_id in enode_to_remove.children() { @@ -1459,6 +1458,8 @@ impl> EGraph { .expect("enode should be in parents array of its children eclasses"), ); } + + debug!("Removed {:?}", self.id_to_node(*enode_id)); } let mut visited_enodes = HashSet::::default(); From 304a901832ab3d244b170cd08af08a59c3493f1a Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Mon, 1 Jul 2024 01:14:33 -0400 Subject: [PATCH 22/27] Fix skipping removal of original enodes Compare by value instead of ID, since an enode may have multiple IDs (at least in `egraph.nodes`). --- src/egraph.rs | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 37b0d038..0a6a597c 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1265,22 +1265,16 @@ impl> EGraph { pub fn undo_rewrites<'a, R>( &mut self, rewrites_to_undo: R, - all_rewrites: R, roots: &[Id], - original_enode_ids: &HashSet, + original_enodes: &HashSet, ) where R: IntoIterator>, L: 'a, N: 'a, { let rewrites_to_undo: Vec<_> = rewrites_to_undo.into_iter().collect(); - let all_rewrites: Vec<_> = all_rewrites.into_iter().collect(); - info!( - "Undoing {} rewrites out of {} total", - rewrites_to_undo.len(), - all_rewrites.len() - ); + info!("Undoing {} rewrites", rewrites_to_undo.len(),); debug!( "Rewrites to undo: {:?}", rewrites_to_undo @@ -1322,8 +1316,22 @@ impl> EGraph { .filter_map(|(applier_pat, subst)| { let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; - // Always preserve the original enodes - if original_enode_ids.contains(applier_enode_id) { + let canonicalized_enode = self + .id_to_node(*applier_enode_id) + .clone() + .map_children(|child| self.find(child)); + + // Always preserve the original enodes. Does not compare by ID + // since an enode may have multiple IDs (at least in + // `self.nodes`). + // + // TODO: Is there a way to use IDs? Or a way this would have + // false negatives? + if original_enodes.contains(&canonicalized_enode) { + debug!( + "Skip marking {:?} for removal since it's in the original egraph", + canonicalized_enode + ); return None; } @@ -1391,7 +1399,7 @@ impl> EGraph { } debug!( - "Enodes specified for removal (children may be uncanonical): {:?}", + "Enodes which may be removed (with uncanonical children): {:?}", enode_ids .iter() .map(|id| self.id_to_node(*id)) @@ -1416,6 +1424,7 @@ impl> EGraph { .map_children(|child| self.find_mut(child)); let eclass_id = &self.find_mut(*enode_id); + // Skip if removing enode would leave a dangling child let eclass = self.classes.get(eclass_id).unwrap(); if eclass.nodes.len() == 1 && !eclass.parents.is_empty() { for parent_id in &eclass.parents { @@ -1423,7 +1432,7 @@ impl> EGraph { // remove enode to avoid a dangling child if !enode_ids.contains(parent_id) { info!( - "Skipping removal of {:?}, which would become a dangling child", + "Skipping removal of {:?} to avoid becoming a dangling child", &enode_to_remove ); continue 'enode_removal; @@ -1458,8 +1467,6 @@ impl> EGraph { .expect("enode should be in parents array of its children eclasses"), ); } - - debug!("Removed {:?}", self.id_to_node(*enode_id)); } let mut visited_enodes = HashSet::::default(); From a05863a7876d16b7015af97caf8530c1dce2c3e5 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Mon, 1 Jul 2024 22:29:50 -0400 Subject: [PATCH 23/27] Extract removing unreachable enodes into function Also, hardcode running inverse equality saturation every iteration for now. --- src/egraph.rs | 62 ++++++++++++++++++++++++--------------------------- src/run.rs | 11 +++++---- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 0a6a597c..8fd4ce74 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -191,13 +191,13 @@ impl> EGraph { /// This allows the egraph to explain why two expressions are /// equivalent with the [`explain_equivalence`](EGraph::explain_equivalence) function. pub fn with_explanations_enabled(mut self) -> Self { - // if self.explain.is_some() { - // return self; - // } - // if self.total_size() > 0 { - // panic!("Need to set explanations enabled before adding any expressions to the egraph."); - // } - // self.explain = Some(Explain::new()); + if self.explain.is_some() { + return self; + } + if self.total_size() > 0 { + panic!("Need to set explanations enabled before adding any expressions to the egraph."); + } + self.explain = Some(Explain::new()); self } @@ -1251,17 +1251,10 @@ impl> EGraph { true } - /// Inverse equality saturation - /// - /// `roots` can be non-canonical + /// Inverse equality saturation. /// - /// Note that "undoing" rewrites will not necessarily return the egraph to a - /// previous state. This method removes all but one equivalent - /// representation based on the rewrite rules to undo, but that equivalent - /// representation might not be the original one. It might even have more - /// terms. - /// - /// TODO: Above explanation might be confusing, write a better one. + /// Removes RHS instances of the input rewrites, but will not remove enodes + /// belonging to the original egraph (before rewriting). pub fn undo_rewrites<'a, R>( &mut self, rewrites_to_undo: R, @@ -1283,8 +1276,6 @@ impl> EGraph { .collect::>() ); - // TODO: Maybe optimize by iterating and collecting `applier_enode_ids` - // without collecting `matches` in-between? let patterns_and_matches: Vec<(Pattern, Vec>)> = zip( rewrites_to_undo.iter().map(|rewrite| { Pattern::from( @@ -1410,10 +1401,6 @@ impl> EGraph { let num_enodes = self.nodes.len(); let num_eclasses = self.classes.len(); - let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); - - let mut visited_eclasses = HashSet::::default(); - // Remove the input enodes from their corresponding eclasses 'enode_removal: for enode_id in &enode_ids { // Canonicalize enode's children so it can be found in @@ -1469,7 +1456,24 @@ impl> EGraph { } } + self.remove_unreachable(roots); + + info!( + "Removed {} enodes ({} remaining) and {} eclasses ({} remaining)", + num_enodes - self.nodes.len(), + self.nodes.len(), + num_eclasses - self.classes.len(), + self.classes.len() + ); + } + + fn remove_unreachable(&mut self, roots: &[Id]) { + // Canonicalize roots' children + let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); + + let mut visited_eclasses = HashSet::::default(); let mut visited_enodes = HashSet::::default(); + let mut dfs_stack: Vec<&L> = roots .iter() .flat_map(|id| match self.classes.get(id) { @@ -1501,7 +1505,7 @@ impl> EGraph { self.nodes.retain(|_, enode| visited_enodes.contains(enode)); // Remove unreachable eclasses - // TODO: Very ugly, maybe have `visited_eclasses` be a `HashMap`? + // TODO: Verbose and possibly slow, maybe have `visited_eclasses` be a `HashMap`? let unreachable_eclasses = self .classes .keys() @@ -1520,7 +1524,7 @@ impl> EGraph { trace!("EGraph after removing enodes:\n{:?}", self.dump()); - // Validate children of remaining enodes + // Check for existence of remaining enodes' children #[cfg(debug_assertions)] { for (_, enode) in self.nodes.iter() { @@ -1530,14 +1534,6 @@ impl> EGraph { } } } - - info!( - "Removed {} enodes ({} remaining) and {} eclasses ({} remaining)", - num_enodes - self.nodes.len(), - self.nodes.len(), - num_eclasses - self.classes.len(), - self.classes.len() - ); } /// Update the analysis data of an e-class. diff --git a/src/run.rs b/src/run.rs index 700262e1..b1450a18 100644 --- a/src/run.rs +++ b/src/run.rs @@ -415,9 +415,9 @@ where let original_enodes = self .egraph - .nodes() - .map(|(id, _)| *id) - .collect::>(); + .classes() + .flat_map(|eclass| eclass.nodes.clone()) + .collect::>(); loop { let iter = self.run_one(&rules); @@ -429,11 +429,10 @@ where self.stop_reason = Some(stop_reason); break; } + self.egraph + .undo_rewrites(rules.clone(), &self.roots, &original_enodes); } - self.egraph - .undo_rewrites(rules.clone(), rules.clone(), &self.roots, &original_enodes); - assert!(!self.iterations.is_empty()); assert!(self.stop_reason.is_some()); self From c2d688e939f51256d46fb8bcc37458ed38e69b1b Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Mon, 1 Jul 2024 22:40:56 -0400 Subject: [PATCH 24/27] Remove unused `nodes` method and add TODO for `undo_rewrites` --- src/egraph.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 8fd4ce74..0b3b1e7a 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -136,14 +136,6 @@ impl> EGraph { self.classes.values_mut() } - /// Returns an iterator over (enode, non-canonical ID) entries in the egraph. - /// - /// TODO: Better docs or find a cleaner way to store the original egraph - /// (before rewriting occurs). - pub fn nodes(&self) -> impl ExactSizeIterator { - self.nodes.iter() - } - /// Returns `true` if the egraph is empty /// # Example /// ``` @@ -1255,6 +1247,8 @@ impl> EGraph { /// /// Removes RHS instances of the input rewrites, but will not remove enodes /// belonging to the original egraph (before rewriting). + /// + /// TODO: Example pub fn undo_rewrites<'a, R>( &mut self, rewrites_to_undo: R, From 2c87e4ccf389d964f392d3746d4d71edbe3ff7d8 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 25 Jul 2024 21:15:16 -0400 Subject: [PATCH 25/27] Remove unreachable eclass parents when cleaning egraph Also improve checking egraph validity when `cfg(debug_assertions)` --- src/egraph.rs | 93 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 38 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index 0b3b1e7a..a51eef48 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -183,6 +183,7 @@ impl> EGraph { /// This allows the egraph to explain why two expressions are /// equivalent with the [`explain_equivalence`](EGraph::explain_equivalence) function. pub fn with_explanations_enabled(mut self) -> Self { + return self; if self.explain.is_some() { return self; } @@ -1293,10 +1294,10 @@ impl> EGraph { let applier_enode_ids: HashSet = patterns_and_matches .into_iter() .flat_map(|(applier_pat, all_search_matches)| { - zip(repeat(applier_pat), all_search_matches.into_iter()) + zip(repeat(applier_pat), all_search_matches) }) .flat_map(|(applier_pat, search_matches)| { - zip(repeat(applier_pat), search_matches.substs.into_iter()) + zip(repeat(applier_pat), search_matches.substs) }) .filter_map(|(applier_pat, subst)| { let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; @@ -1405,22 +1406,24 @@ impl> EGraph { .map_children(|child| self.find_mut(child)); let eclass_id = &self.find_mut(*enode_id); - // Skip if removing enode would leave a dangling child + // Skip if eclass is a singleton and has parents + // + // TODO: Might be skipping removals that are valid? let eclass = self.classes.get(eclass_id).unwrap(); if eclass.nodes.len() == 1 && !eclass.parents.is_empty() { - for parent_id in &eclass.parents { - // Parent of enode is not specified for removal, so don't - // remove enode to avoid a dangling child - if !enode_ids.contains(parent_id) { - info!( - "Skipping removal of {:?} to avoid becoming a dangling child", - &enode_to_remove - ); - continue 'enode_removal; - } - } + info!( + "Skipping removal of {:?} to avoid becoming a dangling child", + &enode_to_remove + ); + continue 'enode_removal; } + trace!( + "Removing enode {:?} with id {:?}", + self.id_to_node(*enode_id), + enode_id + ); + // TODO: Shadowing shenanigans to avoid angering the borrow checker let eclass = self.classes.get_mut(eclass_id).unwrap(); eclass @@ -1466,7 +1469,7 @@ impl> EGraph { let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); let mut visited_eclasses = HashSet::::default(); - let mut visited_enodes = HashSet::::default(); + let mut visited_enodes = HashSet::::default(); let mut dfs_stack: Vec<&L> = roots .iter() @@ -1482,7 +1485,7 @@ impl> EGraph { let eclass_id = self.find(enode_id); visited_eclasses.insert(eclass_id); - visited_enodes.insert(enode.clone()); + visited_enodes.insert(enode_id); let children_enodes: Vec<&L> = enode .children() @@ -1495,38 +1498,52 @@ impl> EGraph { } // Remove unreachable enodes - self.memo.retain(|enode, _| visited_enodes.contains(enode)); - self.nodes.retain(|_, enode| visited_enodes.contains(enode)); - - // Remove unreachable eclasses - // TODO: Verbose and possibly slow, maybe have `visited_eclasses` be a `HashMap`? - let unreachable_eclasses = self - .classes - .keys() - .copied() - .collect::>() - .difference(&visited_eclasses) - .copied() - .collect::>(); - for eclass_id in unreachable_eclasses { - self.unionfind.delete(eclass_id); - self.classes.remove(&eclass_id); - self.classes_by_op.values_mut().for_each(|op| { - op.remove(&eclass_id); - }); + self.memo + .retain(|_, enode_id| visited_enodes.contains(enode_id)); + self.nodes + .retain(|enode_id, _| visited_enodes.contains(enode_id)); + + // Clean up eclasses + for eclass_id in self.classes.keys().copied().collect::>() { + if visited_eclasses.contains(&eclass_id) { + // Remove unreachable parents + let eclass = self.classes.get_mut(&eclass_id).unwrap(); + eclass + .parents + .retain(|parent_id| visited_enodes.contains(parent_id)); + } else { + // Remove unreachable eclasses + self.unionfind.delete(eclass_id); + self.classes.remove(&eclass_id); + self.classes_by_op.values_mut().for_each(|op| { + op.remove(&eclass_id); + }); + } } trace!("EGraph after removing enodes:\n{:?}", self.dump()); - // Check for existence of remaining enodes' children #[cfg(debug_assertions)] { - for (_, enode) in self.nodes.iter() { - debug!("Validating children of remaining enode {:?}", enode); + // Check for existence of remaining enodes' children + for (enode_id, enode) in self.nodes.iter() { + debug!( + "Validating children of remaining enode {:?} with id {:?}", + enode, enode_id + ); for child in enode.children() { + trace!("Looking for child {:?}", child); self.find(*child); } } + + // Check for existence of remaining eclasses' parents + for (eclass_id, eclass) in self.classes.iter() { + debug!("Validating parents of remaining eclass {:?}", eclass_id); + for parent in &eclass.parents { + self.find(*parent); + } + } } } From 6ea5716d56bfe723f54ae144c63a8218b43734c0 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 30 Jul 2024 21:03:54 -0400 Subject: [PATCH 26/27] Add `InverseScheduler` --- src/egraph.rs | 1 - src/run.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++--- src/unionfind.rs | 2 +- 3 files changed, 105 insertions(+), 8 deletions(-) diff --git a/src/egraph.rs b/src/egraph.rs index a51eef48..d45aa579 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -183,7 +183,6 @@ impl> EGraph { /// This allows the egraph to explain why two expressions are /// equivalent with the [`explain_equivalence`](EGraph::explain_equivalence) function. pub fn with_explanations_enabled(mut self) -> Self { - return self; if self.explain.is_some() { return self; } diff --git a/src/run.rs b/src/run.rs index b1450a18..14ca17c9 100644 --- a/src/run.rs +++ b/src/run.rs @@ -413,7 +413,7 @@ where check_rules(&rules); self.egraph.rebuild(); - let original_enodes = self + let _original_enodes = self .egraph .classes() .flat_map(|eclass| eclass.nodes.clone()) @@ -429,8 +429,6 @@ where self.stop_reason = Some(stop_reason); break; } - self.egraph - .undo_rewrites(rules.clone(), &self.roots, &original_enodes); } assert!(!self.iterations.is_empty()); @@ -549,7 +547,7 @@ where let mut applied = IndexMap::default(); result = result.and_then(|_| { rules.iter().try_for_each(|rw| { - let ms = self.scheduler.search_rewrite(i, &self.egraph, rw); + let ms = self.scheduler.search_rewrite(i, &mut self.egraph, rw); matches.push(ms); self.check_limits() }) @@ -695,7 +693,7 @@ where fn search_rewrite<'a>( &mut self, iteration: usize, - egraph: &EGraph, + egraph: &mut EGraph, rewrite: &'a Rewrite, ) -> Vec> { rewrite.search(egraph) @@ -875,7 +873,7 @@ where fn search_rewrite<'a>( &mut self, iteration: usize, - egraph: &EGraph, + egraph: &mut EGraph, rewrite: &'a Rewrite, ) -> Vec> { let stats = self.rule_stats(rewrite.name); @@ -915,6 +913,106 @@ where } } +#[derive(Debug)] +struct UndoStats { + times_applied: usize, + times_undone: usize, + match_limit: usize, +} + +/// TODO: docs +pub struct InverseScheduler { + default_match_limit: usize, + stats: IndexMap, + roots: Vec, + original_enodes: HashSet, +} + +impl InverseScheduler { + /// TODO: docs + pub fn new(default_match_limit: usize, roots: T, original_enodes: HashSet) -> Self + where + T: IntoIterator, + { + Self { + default_match_limit, + stats: Default::default(), + roots: roots.into_iter().collect(), + original_enodes, + } + } + + /// Set the initial match limit after which a rule will be undone. + /// Default: 1,000 + pub fn with_initial_match_limit(mut self, limit: usize) -> Self { + self.default_match_limit = limit; + self + } + + fn rule_stats(&mut self, name: Symbol) -> &mut UndoStats { + if self.stats.contains_key(&name) { + &mut self.stats[&name] + } else { + self.stats.entry(name).or_insert(UndoStats { + times_applied: 0, + times_undone: 0, + match_limit: self.default_match_limit, + }) + } + } + + /// Never undo a particular rule. + pub fn do_not_undo(mut self, name: impl Into) -> Self { + self.rule_stats(name.into()).match_limit = usize::MAX; + self + } + + /// Set the initial match limit for a rule. + pub fn rule_match_limit(mut self, name: impl Into, limit: usize) -> Self { + self.rule_stats(name.into()).match_limit = limit; + self + } +} + +impl RewriteScheduler for InverseScheduler +where + L: Language, + N: Analysis, +{ + fn can_stop(&mut self, _iteration: usize) -> bool { + // TODO: Reapply rules that were undone "too many times"? How to define that? + true + } + + fn search_rewrite<'a>( + &mut self, + _iteration: usize, + egraph: &mut EGraph, + rewrite: &'a Rewrite, + ) -> Vec> { + let stats = self.rule_stats(rewrite.name); + + let threshold = stats + .match_limit + .checked_shl(stats.times_undone as u32) + .unwrap(); + let matches = rewrite.search_with_limit(egraph, threshold.saturating_add(1)); + let total_len: usize = matches.iter().map(|m| m.substs.len()).sum(); + if total_len > threshold { + stats.times_undone += 1; + info!( + "Undoing {} (applied {}, undone{})", + rewrite.name, stats.times_applied, stats.times_undone, + ); + egraph.undo_rewrites([rewrite], self.roots.as_slice(), &self.original_enodes); + vec![] + } else { + stats.times_applied += 1; + matches + } + } +} + /// Custom data to inject into the [`Iteration`]s recorded by a [`Runner`] /// /// This trait allows you to add custom data to the [`Iteration`]s diff --git a/src/unionfind.rs b/src/unionfind.rs index 2be36b22..3e213274 100644 --- a/src/unionfind.rs +++ b/src/unionfind.rs @@ -25,7 +25,7 @@ impl UnionFind { } fn parent_mut(&mut self, query: Id) -> Option<&mut Id> { - (&mut self.parents[usize::from(query)]).as_mut() + self.parents[usize::from(query)].as_mut() } pub fn find(&self, mut current: Id) -> Option { From 8caa4606026b0b7779cd2e1758e1932daf978b74 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Tue, 13 Aug 2024 20:09:26 -0400 Subject: [PATCH 27/27] wip --- math_baseline.csv | 85 ++++++++++++++++++++++++++++++++++++++++++ math_inverse.csv | 95 +++++++++++++++++++++++++++++++++++++++++++++++ src/egraph.rs | 3 +- src/run.rs | 6 --- src/test.rs | 11 +++++- tests/math.rs | 2 +- 6 files changed, 193 insertions(+), 9 deletions(-) create mode 100644 math_baseline.csv create mode 100644 math_inverse.csv diff --git a/math_baseline.csv b/math_baseline.csv new file mode 100644 index 00000000..962e4b82 --- /dev/null +++ b/math_baseline.csv @@ -0,0 +1,85 @@ +diff_power_simple,0.001590461 +integ_one,0.000020708000000000002 +integ_part1,0.001072126 +integ_part3,0.000442666 +integ_sin,0.000019125 +integ_x,0.000025583999999999998 +math_associate_adds,0.005160467 +math_diff_different,0.00002425 +math_diff_ln,0.000018791 +math_diff_same,0.000017459 +math_diff_simple1,0.000317501 +math_diff_simple2,0.00026687600000000004 +math_powers,0.000028751 +math_simplify_add,0.000265459 +math_simplify_const,0.000117041 +math_simplify_factor,0.009269597000000001 +math_simplify_root,0.026495581000000004 +diff_power_simple,0.001554334 +integ_one,0.0000215 +integ_part1,0.001085791 +integ_part3,0.00046200099999999997 +integ_sin,0.000018624 +integ_x,0.000024042 +math_associate_adds,0.005166336 +math_diff_different,0.000019375 +math_diff_ln,0.000019208 +math_diff_same,0.000017625 +math_diff_simple1,0.000315541 +math_diff_simple2,0.00028537500000000004 +math_powers,0.000029749999999999998 +math_simplify_add,0.00027008399999999997 +math_simplify_const,0.00011716700000000001 +math_simplify_factor,0.009123505 +math_simplify_root,0.026413933000000004 +diff_power_simple,0.001597454 +integ_one,0.000021458 +integ_part1,0.001163789 +integ_part3,0.000451706 +integ_sin,0.000019584 +integ_x,0.000025625 +math_associate_adds,0.005248073000000001 +math_diff_different,0.000022709 +math_diff_ln,0.000024376000000000002 +math_diff_same,0.000018999000000000002 +math_diff_simple1,0.000323416 +math_diff_simple2,0.00027429200000000004 +math_powers,0.000028583 +math_simplify_add,0.000304208 +math_simplify_const,0.000122458 +math_simplify_factor,0.009213397 +math_simplify_root,0.026768194 +diff_power_simple,0.001610165 +integ_one,0.000018707999999999998 +integ_part1,0.001082832 +integ_part3,0.000467209 +integ_sin,0.000019916 +integ_x,0.000025583 +math_associate_adds,0.005180661 +math_diff_different,0.00002175 +math_diff_ln,0.000021167000000000002 +math_diff_same,0.000016791999999999998 +math_diff_simple1,0.000325459 +math_diff_simple2,0.00027145800000000003 +math_powers,0.000031459 +math_simplify_add,0.000271375 +math_simplify_const,0.000121082 +math_simplify_factor,0.009118949 +math_simplify_root,0.026393884 +diff_power_simple,0.002519455 +integ_one,0.000023375000000000002 +integ_part1,0.0011277070000000001 +integ_part3,0.00046670800000000003 +integ_sin,0.000020917000000000003 +integ_x,0.000043083 +math_associate_adds,0.0052744079999999995 +math_diff_different,0.000023209 +math_diff_ln,0.000016833 +math_diff_same,0.000016624999999999998 +math_diff_simple1,0.000667998 +math_diff_simple2,0.00028612500000000003 +math_powers,0.000031791999999999996 +math_simplify_add,0.000274334 +math_simplify_const,0.000114417 +math_simplify_factor,0.009354527 +math_simplify_root,0.027854118 diff --git a/math_inverse.csv b/math_inverse.csv new file mode 100644 index 00000000..97647060 --- /dev/null +++ b/math_inverse.csv @@ -0,0 +1,95 @@ +diff_power_harder,0.018302689 +diff_power_simple,0.000444457 +integ_one,0.000022042 +integ_part1,0.000722582 +integ_part2,0.002940712 +integ_part3,0.000278835 +integ_sin,0.000015583 +integ_x,0.000022583 +math_associate_adds,0.005222257 +math_diff_different,0.000025125 +math_diff_ln,0.000016917 +math_diff_same,0.000016042 +math_diff_simple1,0.000292376 +math_diff_simple2,0.000197 +math_powers,0.000029 +math_simplify_add,0.00018079199999999999 +math_simplify_const,0.000077376 +math_simplify_factor,0.000772084 +math_simplify_root,0.001093752 +diff_power_harder,0.018272403 +diff_power_simple,0.00045491400000000006 +integ_one,0.000018125 +integ_part1,0.000731414 +integ_part2,0.003110741 +integ_part3,0.000285458 +integ_sin,0.000018083 +integ_x,0.000023958 +math_associate_adds,0.005174609 +math_diff_different,0.000020958 +math_diff_ln,0.000015623999999999998 +math_diff_same,0.000020001 +math_diff_simple1,0.000267166 +math_diff_simple2,0.000203709 +math_powers,0.000025708 +math_simplify_add,0.00016270800000000001 +math_simplify_const,0.000067958 +math_simplify_factor,0.0006973719999999999 +math_simplify_root,0.0009685389999999999 +diff_power_harder,0.018521787 +diff_power_simple,0.00046241699999999996 +integ_one,0.000022417 +integ_part1,0.000742333 +integ_part2,0.003101416 +integ_part3,0.000294126 +integ_sin,0.000017083999999999998 +integ_x,0.000023833 +math_associate_adds,0.005226875 +math_diff_different,0.000021623999999999998 +math_diff_ln,0.000019792 +math_diff_same,0.000017 +math_diff_simple1,0.000269874 +math_diff_simple2,0.000193749 +math_powers,0.000025625999999999998 +math_simplify_add,0.000155751 +math_simplify_const,0.000067251 +math_simplify_factor,0.0006932509999999999 +math_simplify_root,0.000994665 +diff_power_harder,0.018488011999999998 +diff_power_simple,0.000447208 +integ_one,0.000017875 +integ_part1,0.000715626 +integ_part2,0.0030908290000000002 +integ_part3,0.000296832 +integ_sin,0.000017208000000000002 +integ_x,0.000023792 +math_associate_adds,0.005237906 +math_diff_different,0.000021417 +math_diff_ln,0.000020917 +math_diff_same,0.00001825 +math_diff_simple1,0.000268541 +math_diff_simple2,0.00019629100000000002 +math_powers,0.000025 +math_simplify_add,0.00014754 +math_simplify_const,0.000075416 +math_simplify_factor,0.000753583 +math_simplify_root,0.0011019559999999999 +diff_power_harder,0.018789224000000004 +diff_power_simple,0.00045216699999999996 +integ_one,0.000018167 +integ_part1,0.0007439170000000001 +integ_part2,0.0030491669999999998 +integ_part3,0.00027429 +integ_sin,0.000015791000000000002 +integ_x,0.000029 +math_associate_adds,0.005239588 +math_diff_different,0.000027541999999999998 +math_diff_ln,0.000017624999999999998 +math_diff_same,0.00001675 +math_diff_simple1,0.000271458 +math_diff_simple2,0.000198835 +math_powers,0.000026999 +math_simplify_add,0.00015270800000000002 +math_simplify_const,0.00006533400000000001 +math_simplify_factor,0.0006894179999999999 +math_simplify_root,0.001014711 diff --git a/src/egraph.rs b/src/egraph.rs index d45aa579..8f98d6a9 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -1261,7 +1261,7 @@ impl> EGraph { { let rewrites_to_undo: Vec<_> = rewrites_to_undo.into_iter().collect(); - info!("Undoing {} rewrites", rewrites_to_undo.len(),); + info!("Undoing {} rewrites", rewrites_to_undo.len()); debug!( "Rewrites to undo: {:?}", rewrites_to_undo @@ -1298,6 +1298,7 @@ impl> EGraph { .flat_map(|(applier_pat, search_matches)| { zip(repeat(applier_pat), search_matches.substs) }) + .skip(1) .filter_map(|(applier_pat, subst)| { let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; diff --git a/src/run.rs b/src/run.rs index 14ca17c9..2c4fd569 100644 --- a/src/run.rs +++ b/src/run.rs @@ -413,12 +413,6 @@ where check_rules(&rules); self.egraph.rebuild(); - let _original_enodes = self - .egraph - .classes() - .flat_map(|eclass| eclass.nodes.clone()) - .collect::>(); - loop { let iter = self.run_one(&rules); self.iterations.push(iter); diff --git a/src/test.rs b/src/test.rs index 4784e816..7c1b1191 100644 --- a/src/test.rs +++ b/src/test.rs @@ -171,7 +171,6 @@ where ); let mut runner = Runner::default() - .with_scheduler(SimpleScheduler) .with_hook(move |runner| { let n_nodes = runner.egraph.total_number_of_nodes(); eprintln!("Iter {}, {} nodes", runner.iterations.len(), n_nodes); @@ -185,6 +184,16 @@ where .with_node_limit(node_limit) .with_time_limit(Duration::from_secs(time_limit)); + let original_enodes = runner + .egraph + .classes() + .flat_map(|eclass| eclass.nodes.clone()) + .collect::>(); + + let inverse_scheduler = InverseScheduler::new(1000, runner.roots.clone(), original_enodes); + + let mut runner = runner.with_scheduler(SimpleScheduler); + for expr in exprs { runner = runner.with_expr(&expr.parse().unwrap()); } diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..621bb657 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -96,7 +96,7 @@ impl Analysis for ConstantFold { egraph.union(id, added); } // to not prune, comment this out - egraph[id].nodes.retain(|n| n.is_leaf()); + // egraph[id].nodes.retain(|n| n.is_leaf()); #[cfg(debug_assertions)] egraph[id].assert_unique_leaves();