diff --git a/.gitignore b/.gitignore index 53547711..55705d54 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,6 @@ perf.data* *.bench .vscode/settings.json out.csv + +# macOS +.DS_Store diff --git a/math_baseline.csv b/math_baseline.csv new file mode 100644 index 00000000..962e4b82 --- /dev/null +++ b/math_baseline.csv @@ -0,0 +1,85 @@ +diff_power_simple,0.001590461 +integ_one,0.000020708000000000002 +integ_part1,0.001072126 +integ_part3,0.000442666 +integ_sin,0.000019125 +integ_x,0.000025583999999999998 +math_associate_adds,0.005160467 +math_diff_different,0.00002425 +math_diff_ln,0.000018791 +math_diff_same,0.000017459 +math_diff_simple1,0.000317501 +math_diff_simple2,0.00026687600000000004 +math_powers,0.000028751 +math_simplify_add,0.000265459 +math_simplify_const,0.000117041 +math_simplify_factor,0.009269597000000001 +math_simplify_root,0.026495581000000004 +diff_power_simple,0.001554334 +integ_one,0.0000215 +integ_part1,0.001085791 +integ_part3,0.00046200099999999997 +integ_sin,0.000018624 +integ_x,0.000024042 +math_associate_adds,0.005166336 +math_diff_different,0.000019375 +math_diff_ln,0.000019208 +math_diff_same,0.000017625 +math_diff_simple1,0.000315541 +math_diff_simple2,0.00028537500000000004 +math_powers,0.000029749999999999998 +math_simplify_add,0.00027008399999999997 +math_simplify_const,0.00011716700000000001 +math_simplify_factor,0.009123505 +math_simplify_root,0.026413933000000004 +diff_power_simple,0.001597454 +integ_one,0.000021458 +integ_part1,0.001163789 +integ_part3,0.000451706 +integ_sin,0.000019584 +integ_x,0.000025625 +math_associate_adds,0.005248073000000001 +math_diff_different,0.000022709 +math_diff_ln,0.000024376000000000002 +math_diff_same,0.000018999000000000002 +math_diff_simple1,0.000323416 +math_diff_simple2,0.00027429200000000004 +math_powers,0.000028583 +math_simplify_add,0.000304208 +math_simplify_const,0.000122458 +math_simplify_factor,0.009213397 +math_simplify_root,0.026768194 +diff_power_simple,0.001610165 +integ_one,0.000018707999999999998 +integ_part1,0.001082832 +integ_part3,0.000467209 +integ_sin,0.000019916 +integ_x,0.000025583 +math_associate_adds,0.005180661 +math_diff_different,0.00002175 +math_diff_ln,0.000021167000000000002 +math_diff_same,0.000016791999999999998 +math_diff_simple1,0.000325459 +math_diff_simple2,0.00027145800000000003 +math_powers,0.000031459 +math_simplify_add,0.000271375 +math_simplify_const,0.000121082 +math_simplify_factor,0.009118949 +math_simplify_root,0.026393884 +diff_power_simple,0.002519455 +integ_one,0.000023375000000000002 +integ_part1,0.0011277070000000001 +integ_part3,0.00046670800000000003 +integ_sin,0.000020917000000000003 +integ_x,0.000043083 +math_associate_adds,0.0052744079999999995 +math_diff_different,0.000023209 +math_diff_ln,0.000016833 +math_diff_same,0.000016624999999999998 +math_diff_simple1,0.000667998 +math_diff_simple2,0.00028612500000000003 +math_powers,0.000031791999999999996 +math_simplify_add,0.000274334 +math_simplify_const,0.000114417 +math_simplify_factor,0.009354527 +math_simplify_root,0.027854118 diff --git a/math_inverse.csv b/math_inverse.csv new file mode 100644 index 00000000..97647060 --- /dev/null +++ b/math_inverse.csv @@ -0,0 +1,95 @@ +diff_power_harder,0.018302689 +diff_power_simple,0.000444457 +integ_one,0.000022042 +integ_part1,0.000722582 +integ_part2,0.002940712 +integ_part3,0.000278835 +integ_sin,0.000015583 +integ_x,0.000022583 +math_associate_adds,0.005222257 +math_diff_different,0.000025125 +math_diff_ln,0.000016917 +math_diff_same,0.000016042 +math_diff_simple1,0.000292376 +math_diff_simple2,0.000197 +math_powers,0.000029 +math_simplify_add,0.00018079199999999999 +math_simplify_const,0.000077376 +math_simplify_factor,0.000772084 +math_simplify_root,0.001093752 +diff_power_harder,0.018272403 +diff_power_simple,0.00045491400000000006 +integ_one,0.000018125 +integ_part1,0.000731414 +integ_part2,0.003110741 +integ_part3,0.000285458 +integ_sin,0.000018083 +integ_x,0.000023958 +math_associate_adds,0.005174609 +math_diff_different,0.000020958 +math_diff_ln,0.000015623999999999998 +math_diff_same,0.000020001 +math_diff_simple1,0.000267166 +math_diff_simple2,0.000203709 +math_powers,0.000025708 +math_simplify_add,0.00016270800000000001 +math_simplify_const,0.000067958 +math_simplify_factor,0.0006973719999999999 +math_simplify_root,0.0009685389999999999 +diff_power_harder,0.018521787 +diff_power_simple,0.00046241699999999996 +integ_one,0.000022417 +integ_part1,0.000742333 +integ_part2,0.003101416 +integ_part3,0.000294126 +integ_sin,0.000017083999999999998 +integ_x,0.000023833 +math_associate_adds,0.005226875 +math_diff_different,0.000021623999999999998 +math_diff_ln,0.000019792 +math_diff_same,0.000017 +math_diff_simple1,0.000269874 +math_diff_simple2,0.000193749 +math_powers,0.000025625999999999998 +math_simplify_add,0.000155751 +math_simplify_const,0.000067251 +math_simplify_factor,0.0006932509999999999 +math_simplify_root,0.000994665 +diff_power_harder,0.018488011999999998 +diff_power_simple,0.000447208 +integ_one,0.000017875 +integ_part1,0.000715626 +integ_part2,0.0030908290000000002 +integ_part3,0.000296832 +integ_sin,0.000017208000000000002 +integ_x,0.000023792 +math_associate_adds,0.005237906 +math_diff_different,0.000021417 +math_diff_ln,0.000020917 +math_diff_same,0.00001825 +math_diff_simple1,0.000268541 +math_diff_simple2,0.00019629100000000002 +math_powers,0.000025 +math_simplify_add,0.00014754 +math_simplify_const,0.000075416 +math_simplify_factor,0.000753583 +math_simplify_root,0.0011019559999999999 +diff_power_harder,0.018789224000000004 +diff_power_simple,0.00045216699999999996 +integ_one,0.000018167 +integ_part1,0.0007439170000000001 +integ_part2,0.0030491669999999998 +integ_part3,0.00027429 +integ_sin,0.000015791000000000002 +integ_x,0.000029 +math_associate_adds,0.005239588 +math_diff_different,0.000027541999999999998 +math_diff_ln,0.000017624999999999998 +math_diff_same,0.00001675 +math_diff_simple1,0.000271458 +math_diff_simple2,0.000198835 +math_powers,0.000026999 +math_simplify_add,0.00015270800000000002 +math_simplify_const,0.00006533400000000001 +math_simplify_factor,0.0006894179999999999 +math_simplify_root,0.001014711 diff --git a/src/egraph.rs b/src/egraph.rs index b8688153..8f98d6a9 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -2,6 +2,7 @@ use crate::*; use std::{ borrow::BorrowMut, fmt::{self, Debug, Display}, + iter::{repeat, zip}, marker::PhantomData, }; @@ -57,7 +58,7 @@ pub struct EGraph> { pub(crate) explain: Option>, unionfind: UnionFind, /// Stores the original node represented by each non-canonical id - nodes: Vec, + nodes: HashMap, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. @@ -220,7 +221,7 @@ impl> EGraph { panic!("Use runner.with_explanations_enabled() or egraph.with_explanations_enabled() before running to get a copied egraph without unions"); } let mut egraph = Self::new(analysis); - for node in &self.nodes { + for (_, node) in &self.nodes { egraph.add(node.clone()); } egraph @@ -369,7 +370,7 @@ impl> EGraph { /// Like [`id_to_expr`](EGraph::id_to_expr) but only goes one layer deep pub fn id_to_node(&self, id: Id) -> &L { - &self.nodes[usize::from(id)] + self.nodes.get(&id).unwrap() } /// Like [`id_to_expr`](EGraph::id_to_expr), but creates a pattern instead of a term. @@ -576,13 +577,17 @@ impl> EGraph { /// assert_eq!(egraph.find(x), egraph.find(y)); /// ``` pub fn find(&self, id: Id) -> Id { - self.unionfind.find(id) + self.unionfind + .find(id) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", id)) } /// This is private, but internals should use this whenever /// possible because it does path compression. fn find_mut(&mut self, id: Id) -> Id { - self.unionfind.find_mut(id) + self.unionfind + .find_mut(id) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", id)) } /// Creates a [`Dot`] to visualize this egraph. See [`Dot`]. @@ -652,8 +657,8 @@ where nodes: src_egraph .nodes .into_iter() - .map(|x| self.map_node(x)) - .collect(), + .map(|(id, enode)| (id, self.map_node(enode))) + .collect::>(), analysis_pending: src_egraph.analysis_pending, classes: src_egraph .classes @@ -1047,8 +1052,7 @@ impl> EGraph { } else { let new_id = self.unionfind.make_set(); explain.add(original.clone(), new_id, new_id); - debug_assert_eq!(Id::from(self.nodes.len()), new_id); - self.nodes.push(original); + self.nodes.insert(new_id, original); self.unionfind.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); new_id @@ -1080,8 +1084,7 @@ impl> EGraph { parents: Default::default(), }; - debug_assert_eq!(Id::from(self.nodes.len()), id); - self.nodes.push(original); + self.nodes.insert(id, enode.clone()); // add this enode to the parent lists of its children enode.for_each(|child| { @@ -1240,6 +1243,310 @@ impl> EGraph { true } + /// Inverse equality saturation. + /// + /// Removes RHS instances of the input rewrites, but will not remove enodes + /// belonging to the original egraph (before rewriting). + /// + /// TODO: Example + pub fn undo_rewrites<'a, R>( + &mut self, + rewrites_to_undo: R, + roots: &[Id], + original_enodes: &HashSet, + ) where + R: IntoIterator>, + L: 'a, + N: 'a, + { + let rewrites_to_undo: Vec<_> = rewrites_to_undo.into_iter().collect(); + + info!("Undoing {} rewrites", rewrites_to_undo.len()); + debug!( + "Rewrites to undo: {:?}", + rewrites_to_undo + .iter() + .map(|rewrite| rewrite.name) + .collect::>() + ); + + let patterns_and_matches: Vec<(Pattern, Vec>)> = zip( + rewrites_to_undo.iter().map(|rewrite| { + Pattern::from( + rewrite + .applier + .get_pattern_ast() + .expect("Applier (RHS) of rewrite rule should be a pattern") + .clone(), + ) + }), + rewrites_to_undo.iter().map(|rewrite| rewrite.search(self)), + ) + .collect(); + + if patterns_and_matches.is_empty() { + // TODO: If no matches should I return `Option::None`? Not sure how + // the API for this should work. + return; + } + + let applier_enode_ids: HashSet = patterns_and_matches + .into_iter() + .flat_map(|(applier_pat, all_search_matches)| { + zip(repeat(applier_pat), all_search_matches) + }) + .flat_map(|(applier_pat, search_matches)| { + zip(repeat(applier_pat), search_matches.substs) + }) + .skip(1) + .filter_map(|(applier_pat, subst)| { + let applier_enode_id = self.find_enode_id(applier_pat.ast.as_ref(), &subst)?; + + let canonicalized_enode = self + .id_to_node(*applier_enode_id) + .clone() + .map_children(|child| self.find(child)); + + // Always preserve the original enodes. Does not compare by ID + // since an enode may have multiple IDs (at least in + // `self.nodes`). + // + // TODO: Is there a way to use IDs? Or a way this would have + // false negatives? + if original_enodes.contains(&canonicalized_enode) { + debug!( + "Skip marking {:?} for removal since it's in the original egraph", + canonicalized_enode + ); + return None; + } + + Some(*applier_enode_id) + }) + .collect(); + + self.remove_enodes(applier_enode_ids, roots); + } + + /// Find the e-node ID given a pattern and a substitution. + /// + /// # Example: TODO not actually done yet + /// ```ignore + /// use egg::*; + /// let egraph = EGraph::::default(); + /// let enode_id = egraph.find_enode_id(pattern_ast, subst); + /// let enode = egraph.id_to_node(enode_id); + /// ``` + fn find_enode_id(&self, pattern: &[ENodeOrVar], subst: &Subst) -> Option<&Id> { + // Pretty sure this is required since finding an instantiated enode + // relies on the egraph's enodes having canonicalized children + // TODO: Better explanation + assert!( + self.clean, + "Cannot remove enodes without a clean egraph, try rebuilding" + ); + + let mut id_buf: Vec = vec![0.into(); pattern.len()]; + let mut candidate: Option<&Id> = None; + for (i, enode_or_var) in pattern.iter().enumerate() { + let id = match enode_or_var { + ENodeOrVar::Var(var) => *subst.get(*var)?, + ENodeOrVar::ENode(enode) => { + let instantiated_enode = enode + .clone() + .map_children(|child| id_buf[usize::from(child)]); + candidate = self.memo.get(&instantiated_enode); + self.lookup(instantiated_enode)? + } + }; + id_buf[i] = id; + } + candidate + } + + /// Removes specified enodes (except when it would leave dangling children) + /// and cleans up the resulting egraph, in particular by removing + /// unreachable eclasses and enodes. + /// + /// For an enode that has parents (which aren't being removed) and is the + /// only member of its eclass, removing it when leave a dangling child. In + /// this case, the enode is not removed. + fn remove_enodes(&mut self, enode_ids: HashSet, roots: &[Id]) { + // Pretty sure this is required since e.g. rebuilding dedups enodes in + // eclasses, `remove_enodes` can't handle duplicate enodes + // TODO: Better explanation + assert!( + self.clean, + "cannot remove enodes without a clean egraph, try rebuilding" + ); + + if enode_ids.is_empty() { + return; + } + + debug!( + "Enodes which may be removed (with uncanonical children): {:?}", + enode_ids + .iter() + .map(|id| self.id_to_node(*id)) + .collect::>() + ); + trace!("EGraph before removing enodes:\n{:?}", self.dump()); + + let num_enodes = self.nodes.len(); + let num_eclasses = self.classes.len(); + + // Remove the input enodes from their corresponding eclasses + 'enode_removal: for enode_id in &enode_ids { + // Canonicalize enode's children so it can be found in + // `eclass.nodes` + let enode_to_remove = self + .id_to_node(*enode_id) + .clone() + .map_children(|child| self.find_mut(child)); + let eclass_id = &self.find_mut(*enode_id); + + // Skip if eclass is a singleton and has parents + // + // TODO: Might be skipping removals that are valid? + let eclass = self.classes.get(eclass_id).unwrap(); + if eclass.nodes.len() == 1 && !eclass.parents.is_empty() { + info!( + "Skipping removal of {:?} to avoid becoming a dangling child", + &enode_to_remove + ); + continue 'enode_removal; + } + + trace!( + "Removing enode {:?} with id {:?}", + self.id_to_node(*enode_id), + enode_id + ); + + // TODO: Shadowing shenanigans to avoid angering the borrow checker + let eclass = self.classes.get_mut(eclass_id).unwrap(); + eclass + .nodes + .remove( + eclass + .nodes + .binary_search(&enode_to_remove) + .unwrap_or_else(|_| { + panic!( + "Enode to remove ({:?}) not found in eclass: {:?}\nMost likely the result of external code removing enodes from the egraph", + enode_to_remove, eclass + ) + }), + ); + + // Remove enode from parent arrays of children eclasses + for eclass_id in enode_to_remove.children() { + let eclass = self.classes.get_mut(eclass_id).unwrap(); + eclass.parents.swap_remove( + eclass + .parents + .iter() + .position(|parent_id| parent_id == enode_id) + .expect("enode should be in parents array of its children eclasses"), + ); + } + } + + self.remove_unreachable(roots); + + info!( + "Removed {} enodes ({} remaining) and {} eclasses ({} remaining)", + num_enodes - self.nodes.len(), + self.nodes.len(), + num_eclasses - self.classes.len(), + self.classes.len() + ); + } + + fn remove_unreachable(&mut self, roots: &[Id]) { + // Canonicalize roots' children + let roots: Vec = roots.iter().map(|id| self.find_mut(*id)).collect(); + + let mut visited_eclasses = HashSet::::default(); + let mut visited_enodes = HashSet::::default(); + + let mut dfs_stack: Vec<&L> = roots + .iter() + .flat_map(|id| match self.classes.get(id) { + Some(eclass) => eclass.nodes.iter(), + None => [].iter(), + }) + .collect(); + + // Traverse egraph from roots to leaves, marking visited eclasses and enodes + while let Some(enode) = dfs_stack.pop() { + let enode_id = *self.memo.get(enode).unwrap(); + let eclass_id = self.find(enode_id); + + visited_eclasses.insert(eclass_id); + visited_enodes.insert(enode_id); + + let children_enodes: Vec<&L> = enode + .children() + .iter() + // Avoid following cycles + .filter(|child| !visited_eclasses.contains(child)) + .flat_map(|child| self.classes.get(child).unwrap().iter()) + .collect(); + dfs_stack.extend(children_enodes); + } + + // Remove unreachable enodes + self.memo + .retain(|_, enode_id| visited_enodes.contains(enode_id)); + self.nodes + .retain(|enode_id, _| visited_enodes.contains(enode_id)); + + // Clean up eclasses + for eclass_id in self.classes.keys().copied().collect::>() { + if visited_eclasses.contains(&eclass_id) { + // Remove unreachable parents + let eclass = self.classes.get_mut(&eclass_id).unwrap(); + eclass + .parents + .retain(|parent_id| visited_enodes.contains(parent_id)); + } else { + // Remove unreachable eclasses + self.unionfind.delete(eclass_id); + self.classes.remove(&eclass_id); + self.classes_by_op.values_mut().for_each(|op| { + op.remove(&eclass_id); + }); + } + } + + trace!("EGraph after removing enodes:\n{:?}", self.dump()); + + #[cfg(debug_assertions)] + { + // Check for existence of remaining enodes' children + for (enode_id, enode) in self.nodes.iter() { + debug!( + "Validating children of remaining enode {:?} with id {:?}", + enode, enode_id + ); + for child in enode.children() { + trace!("Looking for child {:?}", child); + self.find(*child); + } + } + + // Check for existence of remaining eclasses' parents + for (eclass_id, eclass) in self.classes.iter() { + debug!("Validating parents of remaining eclass {:?}", eclass_id); + for parent in &eclass.parents { + self.find(*parent); + } + } + } + } + /// Update the analysis data of an e-class. /// /// This also propagates the changes through the e-graph, @@ -1305,10 +1612,12 @@ impl> EGraph { for class in self.classes.values_mut() { let old_len = class.len(); - class - .nodes - .iter_mut() - .for_each(|n| n.update_children(|id| uf.find_mut(id))); + class.nodes.iter_mut().for_each(|n| { + n.update_children(|id| { + uf.find_mut(id) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", id)) + }) + }); class.nodes.sort_unstable(); class.nodes.dedup(); @@ -1384,13 +1693,13 @@ impl> EGraph { let mut n_unions = 0; while !self.pending.is_empty() || !self.analysis_pending.is_empty() { - while let Some(class) = self.pending.pop() { - let mut node = self.nodes[usize::from(class)].clone(); + while let Some(class_id) = self.pending.pop() { + let mut node = self.nodes.get(&class_id).unwrap().clone(); node.update_children(|id| self.find_mut(id)); - if let Some(memo_class) = self.memo.insert(node, class) { + if let Some(memo_class) = self.memo.insert(node, class_id) { let did_something = self.perform_union( memo_class, - class, + class_id, Some(Justification::Congruence), false, ); @@ -1399,7 +1708,7 @@ impl> EGraph { } while let Some(class_id) = self.analysis_pending.pop() { - let node = self.nodes[usize::from(class_id)].clone(); + let node = self.nodes.get(&class_id).unwrap().clone(); let class_id = self.find_mut(class_id); let node_data = N::make(self, &node); let class = self.classes.get_mut(&class_id).unwrap(); diff --git a/src/explain.rs b/src/explain.rs index 33cc0bb4..4225e7b7 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -82,7 +82,7 @@ pub struct Explain { pub(crate) struct ExplainNodes<'a, L: Language> { explain: &'a mut Explain, - nodes: &'a [L], + nodes: &'a HashMap, } #[derive(Default)] @@ -1047,7 +1047,7 @@ impl Explain { equalities } - pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a [L]) -> ExplainNodes<'a, L> { + pub(crate) fn with_nodes<'a>(&'a mut self, nodes: &'a HashMap) -> ExplainNodes<'a, L> { ExplainNodes { explain: self, nodes, @@ -1071,7 +1071,7 @@ impl<'a, L: Language> DerefMut for ExplainNodes<'a, L> { impl<'x, L: Language> ExplainNodes<'x, L> { pub(crate) fn node(&self, node_id: Id) -> &L { - &self.nodes[usize::from(node_id)] + self.nodes.get(&node_id).unwrap() } fn node_to_explanation( &self, @@ -1607,13 +1607,14 @@ impl<'x, L: Language> ExplainNodes<'x, L> { 'outer: for eclass in classes.keys() { let enodes = self.find_all_enodes(*eclass); // find all congruence nodes - let mut cannon_enodes: HashMap> = Default::default(); + let mut canonical_enodes: HashMap> = Default::default(); for enode in &enodes { - let cannon = self - .node(*enode) - .clone() - .map_children(|child| unionfind.find(child)); - if let Some(others) = cannon_enodes.get_mut(&cannon) { + let canonical = self.node(*enode).clone().map_children(|child| { + unionfind + .find(child) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", child)) + }); + if let Some(others) = canonical_enodes.get_mut(&canonical) { for other in others.iter() { congruence_neighbors[usize::from(*enode)].push(*other); congruence_neighbors[usize::from(*other)].push(*enode); @@ -1622,7 +1623,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> { others.push(*enode); } else { counter += 1; - cannon_enodes.insert(cannon, vec![*enode]); + canonical_enodes.insert(canonical, vec![*enode]); } // Don't find every congruence edge because that could be n^2 edges if counter > CONGRUENCE_LIMIT * self.explainfind.len() { @@ -1830,14 +1831,22 @@ impl<'x, L: Language> ExplainNodes<'x, L> { common_ancestor, ); unionfind.union(enode, *child); - ancestor[usize::from(unionfind.find(enode))] = enode; + ancestor[usize::from( + unionfind + .find(enode) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", child)), + )] = enode; } if common_ancestor_queries.get(&enode).is_some() { black_set.insert(enode); for other in common_ancestor_queries.get(&enode).unwrap() { if black_set.contains(other) { - let ancestor = ancestor[usize::from(unionfind.find(*other))]; + let ancestor = ancestor[usize::from( + unionfind + .find(*other) + .unwrap_or_else(|| panic!("eclass id {:?} not in egraph", other)), + )]; common_ancestor.insert((enode, *other), ancestor); common_ancestor.insert((*other, enode), ancestor); } diff --git a/src/run.rs b/src/run.rs index ae78d84e..2c4fd569 100644 --- a/src/run.rs +++ b/src/run.rs @@ -412,6 +412,7 @@ where let rules: Vec<&Rewrite> = rules.into_iter().collect(); check_rules(&rules); self.egraph.rebuild(); + loop { let iter = self.run_one(&rules); self.iterations.push(iter); @@ -540,7 +541,7 @@ where let mut applied = IndexMap::default(); result = result.and_then(|_| { rules.iter().try_for_each(|rw| { - let ms = self.scheduler.search_rewrite(i, &self.egraph, rw); + let ms = self.scheduler.search_rewrite(i, &mut self.egraph, rw); matches.push(ms); self.check_limits() }) @@ -686,7 +687,7 @@ where fn search_rewrite<'a>( &mut self, iteration: usize, - egraph: &EGraph, + egraph: &mut EGraph, rewrite: &'a Rewrite, ) -> Vec> { rewrite.search(egraph) @@ -866,7 +867,7 @@ where fn search_rewrite<'a>( &mut self, iteration: usize, - egraph: &EGraph, + egraph: &mut EGraph, rewrite: &'a Rewrite, ) -> Vec> { let stats = self.rule_stats(rewrite.name); @@ -906,6 +907,106 @@ where } } +#[derive(Debug)] +struct UndoStats { + times_applied: usize, + times_undone: usize, + match_limit: usize, +} + +/// TODO: docs +pub struct InverseScheduler { + default_match_limit: usize, + stats: IndexMap, + roots: Vec, + original_enodes: HashSet, +} + +impl InverseScheduler { + /// TODO: docs + pub fn new(default_match_limit: usize, roots: T, original_enodes: HashSet) -> Self + where + T: IntoIterator, + { + Self { + default_match_limit, + stats: Default::default(), + roots: roots.into_iter().collect(), + original_enodes, + } + } + + /// Set the initial match limit after which a rule will be undone. + /// Default: 1,000 + pub fn with_initial_match_limit(mut self, limit: usize) -> Self { + self.default_match_limit = limit; + self + } + + fn rule_stats(&mut self, name: Symbol) -> &mut UndoStats { + if self.stats.contains_key(&name) { + &mut self.stats[&name] + } else { + self.stats.entry(name).or_insert(UndoStats { + times_applied: 0, + times_undone: 0, + match_limit: self.default_match_limit, + }) + } + } + + /// Never undo a particular rule. + pub fn do_not_undo(mut self, name: impl Into) -> Self { + self.rule_stats(name.into()).match_limit = usize::MAX; + self + } + + /// Set the initial match limit for a rule. + pub fn rule_match_limit(mut self, name: impl Into, limit: usize) -> Self { + self.rule_stats(name.into()).match_limit = limit; + self + } +} + +impl RewriteScheduler for InverseScheduler +where + L: Language, + N: Analysis, +{ + fn can_stop(&mut self, _iteration: usize) -> bool { + // TODO: Reapply rules that were undone "too many times"? How to define that? + true + } + + fn search_rewrite<'a>( + &mut self, + _iteration: usize, + egraph: &mut EGraph, + rewrite: &'a Rewrite, + ) -> Vec> { + let stats = self.rule_stats(rewrite.name); + + let threshold = stats + .match_limit + .checked_shl(stats.times_undone as u32) + .unwrap(); + let matches = rewrite.search_with_limit(egraph, threshold.saturating_add(1)); + let total_len: usize = matches.iter().map(|m| m.substs.len()).sum(); + if total_len > threshold { + stats.times_undone += 1; + info!( + "Undoing {} (applied {}, undone{})", + rewrite.name, stats.times_applied, stats.times_undone, + ); + egraph.undo_rewrites([rewrite], self.roots.as_slice(), &self.original_enodes); + vec![] + } else { + stats.times_applied += 1; + matches + } + } +} + /// Custom data to inject into the [`Iteration`]s recorded by a [`Runner`] /// /// This trait allows you to add custom data to the [`Iteration`]s diff --git a/src/test.rs b/src/test.rs index 4784e816..7c1b1191 100644 --- a/src/test.rs +++ b/src/test.rs @@ -171,7 +171,6 @@ where ); let mut runner = Runner::default() - .with_scheduler(SimpleScheduler) .with_hook(move |runner| { let n_nodes = runner.egraph.total_number_of_nodes(); eprintln!("Iter {}, {} nodes", runner.iterations.len(), n_nodes); @@ -185,6 +184,16 @@ where .with_node_limit(node_limit) .with_time_limit(Duration::from_secs(time_limit)); + let original_enodes = runner + .egraph + .classes() + .flat_map(|eclass| eclass.nodes.clone()) + .collect::>(); + + let inverse_scheduler = InverseScheduler::new(1000, runner.roots.clone(), original_enodes); + + let mut runner = runner.with_scheduler(SimpleScheduler); + for expr in exprs { runner = runner.with_expr(&expr.parse().unwrap()); } diff --git a/src/unionfind.rs b/src/unionfind.rs index 39e9bc58..3e213274 100644 --- a/src/unionfind.rs +++ b/src/unionfind.rs @@ -4,48 +4,80 @@ use std::fmt::Debug; #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct UnionFind { - parents: Vec, + // TODO: Oof doubling memory usage and maybe destroying the cache hit rate + // by using Option instead of Id. Any way to use NonZeroU32 or similar? + parents: Vec>, } impl UnionFind { pub fn make_set(&mut self) -> Id { let id = Id::from(self.parents.len()); - self.parents.push(id); + self.parents.push(Some(id)); id } pub fn size(&self) -> usize { - self.parents.len() + self.parents.iter().filter(|node| node.is_some()).count() } - fn parent(&self, query: Id) -> Id { + fn parent(&self, query: Id) -> Option { self.parents[usize::from(query)] } - fn parent_mut(&mut self, query: Id) -> &mut Id { - &mut self.parents[usize::from(query)] + fn parent_mut(&mut self, query: Id) -> Option<&mut Id> { + self.parents[usize::from(query)].as_mut() } - pub fn find(&self, mut current: Id) -> Id { - while current != self.parent(current) { - current = self.parent(current) + pub fn find(&self, mut current: Id) -> Option { + while current != self.parent(current)? { + current = self.parent(current)? } - current + Some(current) } - pub fn find_mut(&mut self, mut current: Id) -> Id { - while current != self.parent(current) { - let grandparent = self.parent(self.parent(current)); - *self.parent_mut(current) = grandparent; + pub fn find_mut(&mut self, mut current: Id) -> Option { + while current != self.parent(current)? { + let grandparent = self.parent(self.parent(current)?)?; + *self.parent_mut(current)? = grandparent; current = grandparent; } - current + Some(current) } /// Given two leader ids, unions the two eclasses making root1 the leader. - pub fn union(&mut self, root1: Id, root2: Id) -> Id { - *self.parent_mut(root2) = root1; - root1 + pub fn union(&mut self, root1: Id, root2: Id) -> Option { + *self.parent_mut(root2)? = root1; + Some(root1) + } + + /// TODO: Naive implementation, for a potentially more efficient one see + /// [this paper](https://dl.acm.org/doi/10.1145/2636922), but I think it + /// would triple the memory usage (and be hella more complicated) + /// + /// There's also [this cool paper](https://link.springer.com/article/10.1007/s10817-017-9431-7), + /// although I don't think it covers deletions + pub fn delete(&mut self, query: Id) { + let parent = self.parent(query); + + self.parents[usize::from(query)] = None; + + let mut new_root: Option = None; + for idx in 0..self.parents.len() { + if parent == Some(query) { + // Deleted a root node so choose a new root for the children, if any + if self.parents[idx] == Some(query) { + if new_root.is_none() { + new_root = Some(Id::from(idx)); + } + self.parents[idx] = new_root; + } + } else { + // Deleting a non-root node + if self.parents[idx] == Some(query) { + self.parents[idx] = parent; + } + } + } } } @@ -53,8 +85,8 @@ impl UnionFind { mod tests { use super::*; - fn ids(us: impl IntoIterator) -> Vec { - us.into_iter().map(|u| u.into()).collect() + fn ids(us: impl IntoIterator>) -> Vec> { + us.into_iter().map(|u| u.map(|id| id.into())).collect() } #[test] @@ -68,7 +100,7 @@ mod tests { } // test the initial condition of everyone in their own set - assert_eq!(uf.parents, ids(0..n)); + assert_eq!(uf.parents, ids((0..n).map(Some))); // build up one set uf.union(id(0), id(1)); @@ -86,7 +118,104 @@ mod tests { } // indexes: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 - let expected = vec![0, 0, 0, 0, 4, 5, 6, 6, 6, 6]; + let expected = vec![ + Some(0), + Some(0), + Some(0), + Some(0), + Some(4), + Some(5), + Some(6), + Some(6), + Some(6), + Some(6), + ]; assert_eq!(uf.parents, ids(expected)); } + + #[test] + fn delete() { + let mut union_find = UnionFind::default(); + for _ in 0..10 { + union_find.make_set(); + } + + union_find.union(Id::from(0), Id::from(1)); + union_find.union(Id::from(0), Id::from(2)); + union_find.union(Id::from(0), Id::from(3)); + + union_find.union(Id::from(6), Id::from(7)); + union_find.union(Id::from(7), Id::from(8)); + union_find.union(Id::from(8), Id::from(9)); + + assert_eq!( + union_find.parents, + ids(vec![ + Some(0), + Some(0), + Some(0), + Some(0), + Some(4), + Some(5), + Some(6), + Some(6), + Some(7), + Some(8) + ]) + ); + + // Deletion leaves vacant nodes to avoid changing IDs (which correspond with indices) + // Since 0 is a root node, its children are assigned a new root (1) + union_find.delete(Id::from(0)); + assert_eq!( + union_find.parents, + ids(vec![ + None, + Some(1), + Some(1), + Some(1), + Some(4), + Some(5), + Some(6), + Some(6), + Some(7), + Some(8) + ]) + ); + + // + union_find.delete(Id::from(4)); + assert_eq!( + union_find.parents, + ids(vec![ + None, + Some(1), + Some(1), + Some(1), + None, + Some(5), + Some(6), + Some(6), + Some(7), + Some(8) + ]) + ); + + union_find.delete(Id::from(6)); + assert_eq!( + union_find.parents, + ids(vec![ + None, + Some(1), + Some(1), + Some(1), + None, + Some(5), + None, + Some(7), + Some(7), + Some(8) + ]) + ); + } } diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..621bb657 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -96,7 +96,7 @@ impl Analysis for ConstantFold { egraph.union(id, added); } // to not prune, comment this out - egraph[id].nodes.retain(|n| n.is_leaf()); + // egraph[id].nodes.retain(|n| n.is_leaf()); #[cfg(debug_assertions)] egraph[id].assert_unique_leaves(); diff --git a/tests/simple.rs b/tests/simple.rs index 9261a4c0..49a1fb2e 100644 --- a/tests/simple.rs +++ b/tests/simple.rs @@ -40,6 +40,7 @@ fn simplify(s: &str) -> String { #[test] fn simple_tests() { + let _ = env_logger::builder().is_test(true).try_init(); assert_eq!(simplify("(* 0 42)"), "0"); assert_eq!(simplify("(+ 0 (* 1 foo))"), "foo"); }