From b6846e1ca84ac5ad1ccd19aa4b0cfe5a0f247daf Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 20:57:34 -0500 Subject: [PATCH 01/23] Add data subdirectory to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index c80629d..4833498 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ bin/ coverage-html .DS_Store flake.lock +data/ From 02a926c4558f075eff10d928f0b0b18512337bff Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 20:57:34 -0500 Subject: [PATCH 02/23] Add flake.nix for dev environment setup --- flake.nix | 78 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 flake.nix diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..08af1e1 --- /dev/null +++ b/flake.nix @@ -0,0 +1,78 @@ +{ + description = "DataFrame symbolic regression"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-parts.url = "github:hercules-ci/flake-parts"; + }; + + outputs = + inputs@{ flake-parts, ... }: + flake-parts.lib.mkFlake { inherit inputs; } { + systems = [ + "x86_64-linux" + "aarch64-linux" + "aarch64-darwin" + "x86_64-darwin" + ]; + + perSystem = + { + config, + pkgs, + system, + ... 
+ }: + let + pkgs' = import inputs.nixpkgs { + inherit system; + config.allowBroken = true; + }; + + haskellPackages = pkgs'.haskellPackages.extend ( + final: prev: { + # Use the correct hash that Nix gave us + random = pkgs'.haskell.lib.dontCheck ( + prev.callHackageDirect { + pkg = "random"; + ver = "1.3.0"; + sha256 = "sha256-AzXPz8oe+9lAS60nRNFiRTEDrFDB0FPLtbvJ9PB7brM="; + } { } + ); + + srtree = pkgs'.haskell.lib.dontCheck ( + prev.callHackageDirect { + pkg = "srtree"; + ver = "2.0.1.6"; + sha256 = "sha256-D561wnKoFRr/HSviacsbtF5VnJHehG73LOmWM4TxlNs="; # Nix will tell you + } { } + ); + + # Jailbreak packages that depend on old random + time-compat = pkgs'.haskell.lib.doJailbreak prev.time-compat; + splitmix = pkgs'.haskell.lib.doJailbreak prev.splitmix; + + # Mark dataframe as unbroken + dataframe = pkgs'.haskell.lib.markUnbroken prev.dataframe; + } + ); + + in + { + packages = { + default = config.packages.symbolic-regression; + symbolic-regression = haskellPackages.callCabal2nix "symbolic-regression" ./. 
{ }; + }; + + devShells.default = haskellPackages.shellFor { + packages = p: [ config.packages.symbolic-regression ]; + buildInputs = with pkgs'; [ + cabal-install + haskell-language-server + git + fourmolu + ]; + }; + }; + }; +} From 753b8277badbd1f57c9986f2547a4951fc3d8b98 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 20:57:34 -0500 Subject: [PATCH 03/23] Add Example.hs for manual testing --- src/Example.hs | 31 +++++++++++++++++++++++++++++++ symbolic-regression.cabal | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 src/Example.hs diff --git a/src/Example.hs b/src/Example.hs new file mode 100644 index 0000000..ebbecfc --- /dev/null +++ b/src/Example.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Example where + +import qualified DataFrame as D +import qualified DataFrame.Functions as F +import Symbolic.Regression + +main :: IO () +main = do + df <- D.readParquet "./data/mtcars.parquet" + + -- Define mpg as a column reference + let mpg = F.col "mpg" + + exprs <- fit defaultRegressionConfig mpg df + + -- Print each expression (use show or custom formatting) + mapM_ (putStrLn . 
show) exprs + + -- Create named expressions for different complexity levels + let levels = zipWith (F..=) ["level_1", "level_2", "level_3"] exprs + + -- deriveMany returns a DataFrame, so use 'let' not '<-' + let df' = D.deriveMany levels df + + -- Derive prediction using the best (last) expression + let df'' = D.derive "prediction" (last exprs) df' + + -- Do something with the result + pure () \ No newline at end of file diff --git a/symbolic-regression.cabal b/symbolic-regression.cabal index ffa4a8e..2c94f07 100644 --- a/symbolic-regression.cabal +++ b/symbolic-regression.cabal @@ -19,7 +19,7 @@ common warnings library import: warnings - exposed-modules: Symbolic.Regression + exposed-modules: Symbolic.Regression, Example build-depends: base >= 4 && <5 , dataframe ^>= 0.4 , attoparsec >=0.14.4 && <0.15 From 359f5d3c1473c61567a91e3b8fbbe858535e7a02 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 19:03:39 -0500 Subject: [PATCH 04/23] Refactor: Use GADTs for tagged operations in symbolic regression Replace fragile pattern matching in toNonTerminal with UnaryFunc/BinaryFunc GADTs that pair operation names with implementations. --- src/Example.hs | 6 ++-- src/Symbolic/Regression.hs | 66 +++++++++++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 14 deletions(-) diff --git a/src/Example.hs b/src/Example.hs index ebbecfc..fa4bee4 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -15,8 +15,10 @@ main = do exprs <- fit defaultRegressionConfig mpg df - -- Print each expression (use show or custom formatting) - mapM_ (putStrLn . show) exprs + -- Print each expression + mapM_ + (\(i, e) -> putStrLn $ "Model " ++ show i ++ ": " ++ show e) + (zip [1 ..] 
exprs) -- Create named expressions for different complexity levels let levels = zipWith (F..=) ["level_1", "level_2", "level_3"] exprs diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 1bf5c8e..c470872 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -113,6 +113,42 @@ import Algorithm.EqSat.SearchSR import Data.Time.Clock.POSIX import Text.ParseSR +{- | Tagged unary operation for symbolic regression. + +Pairs an operation name with its implementation, ensuring the operation's +identity is preserved throughout the search process. The name is used to +identify the operation in the internal representation. + +@ +UnaryFunc "square" (\\x -> x \`F.pow\` 2) +UnaryFunc "log" log +@ +-} +data UnaryFunc where + UnaryFunc :: String -> (D.Expr Double -> D.Expr Double) -> UnaryFunc + +-- | Extract the name of a unary operation. +getUnaryName :: UnaryFunc -> String +getUnaryName (UnaryFunc name _) = name + +{- | Tagged binary operation for symbolic regression. + +Pairs an operation name with its implementation, ensuring the operation's +identity is preserved throughout the search process. The name is used to +identify the operation in the internal representation. + +@ +BinaryFunc "add" (+) +BinaryFunc "mul" (*) +@ +-} +data BinaryFunc where + BinaryFunc :: String -> (D.Expr Double -> D.Expr Double -> D.Expr Double) -> BinaryFunc + +-- | Extract the name of a binary operation. +getBinaryName :: BinaryFunc -> String +getBinaryName (BinaryFunc name _) = name + {- | Configuration for the symbolic regression algorithm. 
Use 'defaultRegressionConfig' as a starting point and modify fields as needed: @@ -149,9 +185,9 @@ data RegressionConfig = RegressionConfig -- ^ Probability of crossover between expressions (default: 0.95) , mutationProbability :: Double -- ^ Probability of mutation (default: 0.3) - , unaryFunctions :: [D.Expr Double -> D.Expr Double] + , unaryFunctions :: [UnaryFunc] -- ^ Unary operations to include in the search space (default: @[]@) - , binaryFunctions :: [D.Expr Double -> D.Expr Double -> D.Expr Double] + , binaryFunctions :: [BinaryFunc] {- ^ Binary operations to include in the search space (default: @[(+), (-), (*), (\/)]@) -} @@ -200,8 +236,18 @@ defaultRegressionConfig = , tournamentSize = 3 , crossoverProbability = 0.95 , mutationProbability = 0.3 - , unaryFunctions = [(`F.pow` 2), (`F.pow` 3), log, (1 /)] - , binaryFunctions = [(+), (-), (*), (/)] + , unaryFunctions = + [ UnaryFunc "square" (\x -> x `F.pow` 2) + , UnaryFunc "cube" (\x -> x `F.pow` 3) + , UnaryFunc "log" log + , UnaryFunc "recip" (1 /) + ] + , binaryFunctions = + [ BinaryFunc "add" (+) + , BinaryFunc "sub" (-) + , BinaryFunc "mul" (*) + , BinaryFunc "div" (/) + ] , numParams = -1 , generational = False , simplifyExpressions = True @@ -270,13 +316,9 @@ fit cfg targetColumn df = do Array S Ix2 Double toTarget d = fromLists' Seq (D.columnAsList targetColumn d) :: Array S Ix1 Double nonterminals = - intercalate - "," - ( Prelude.map - (toNonTerminal . (\f -> f (F.col "fake1") (F.col "fake2"))) - (binaryFunctions cfg) - ++ Prelude.map (toNonTerminal . 
(\f -> f (F.col "fake1"))) (unaryFunctions cfg) - ) + intercalate "," $ + Prelude.map getBinaryName (binaryFunctions cfg) + ++ Prelude.map getUnaryName (unaryFunctions cfg) varnames = intercalate "," @@ -748,4 +790,4 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do ts <- go (n + 1) (max f f') pure (t : ts) else go (n + 1) (max f f') - else go (n + 1) f + else go (n + 1) f \ No newline at end of file From 872da2a215880386babe515aef7e12b64d11cc59 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 19:30:54 -0500 Subject: [PATCH 05/23] Bump DataFrame to 0.4.1 to get prettyprint for Example.hs and clean up example code --- flake.nix | 13 +++++++++---- src/Example.hs | 9 +++++---- symbolic-regression.cabal | 2 +- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/flake.nix b/flake.nix index 08af1e1..a6ceb33 100644 --- a/flake.nix +++ b/flake.nix @@ -44,16 +44,21 @@ prev.callHackageDirect { pkg = "srtree"; ver = "2.0.1.6"; - sha256 = "sha256-D561wnKoFRr/HSviacsbtF5VnJHehG73LOmWM4TxlNs="; # Nix will tell you + sha256 = "sha256-D561wnKoFRr/HSviacsbtF5VnJHehG73LOmWM4TxlNs="; + } { } + ); + + dataframe = pkgs'.haskell.lib.dontCheck ( + prev.callHackageDirect { + pkg = "dataframe"; + ver = "0.4.1.0"; + sha256 = "sha256-u4FrhD+oOnQiGazq/TWuL0PzUvLKKAkz679tZSkiMaY="; } { } ); # Jailbreak packages that depend on old random time-compat = pkgs'.haskell.lib.doJailbreak prev.time-compat; splitmix = pkgs'.haskell.lib.doJailbreak prev.splitmix; - - # Mark dataframe as unbroken - dataframe = pkgs'.haskell.lib.markUnbroken prev.dataframe; } ); diff --git a/src/Example.hs b/src/Example.hs index fa4bee4..e4985b4 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -2,6 +2,7 @@ module Example where +import qualified Data.Text as T import qualified DataFrame as D import qualified DataFrame.Functions as F import Symbolic.Regression @@ -13,15 +14,16 @@ main = do -- Define mpg as a column reference let mpg = F.col "mpg" + let config = 
defaultRegressionConfig{maxExpressionSize = 7} exprs <- fit defaultRegressionConfig mpg df -- Print each expression mapM_ - (\(i, e) -> putStrLn $ "Model " ++ show i ++ ": " ++ show e) + (\(i, e) -> putStrLn $ "Model " ++ show i ++ ": " ++ D.prettyPrint e) (zip [1 ..] exprs) -- Create named expressions for different complexity levels - let levels = zipWith (F..=) ["level_1", "level_2", "level_3"] exprs + let levels = zipWith (F..=) (map (T.pack . ("level_" ++) . show) [1 ..]) exprs -- deriveMany returns a DataFrame, so use 'let' not '<-' let df' = D.deriveMany levels df @@ -29,5 +31,4 @@ main = do -- Derive prediction using the best (last) expression let df'' = D.derive "prediction" (last exprs) df' - -- Do something with the result - pure () \ No newline at end of file + D.display (D.DisplayOptions 5) df'' diff --git a/symbolic-regression.cabal b/symbolic-regression.cabal index 2c94f07..e7d8407 100644 --- a/symbolic-regression.cabal +++ b/symbolic-regression.cabal @@ -21,7 +21,7 @@ library import: warnings exposed-modules: Symbolic.Regression, Example build-depends: base >= 4 && <5 - , dataframe ^>= 0.4 + , dataframe ^>= 0.4.1 , attoparsec >=0.14.4 && <0.15 , attoparsec-expr >=0.1.1.2 && <0.2 , binary >=0.8.9.1 && <0.9 From 80f756bed1764644329279e5f765ad0f56d08f28 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 19:38:06 -0500 Subject: [PATCH 06/23] Add optional seed to RegressionConfig for testability --- src/Symbolic/Regression.hs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index c470872..643c5a5 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -203,6 +203,8 @@ data RegressionConfig = RegressionConfig -- ^ File path to save e-graph state for later resumption (default: @\"\"@) , loadFrom :: String -- ^ File path to load e-graph state from a previous run (default: @\"\"@) + , seed :: Maybe Int + -- ^ Random seed for reproducibility. 
'Nothing' uses system random (default: 'Nothing') } data ValidationConfig = ValidationConfig @@ -254,6 +256,7 @@ defaultRegressionConfig = , maxTime = -1 , dumpTo = "" , loadFrom = "" + , seed = Nothing } {- | Run symbolic regression to discover mathematical expressions that fit the data. @@ -293,7 +296,9 @@ fit :: -- | Pareto front of expressions, ordered simplest to most complex IO [D.Expr Double] fit cfg targetColumn df = do - g <- getStdGen + g <- case seed cfg of + Just s -> pure $ mkStdGen s + Nothing -> getStdGen let (train, validation) = case validationConfig cfg of Nothing -> (df, df) From 6227dee49114a2403baef718481a55284cc19f9b Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 21:45:03 -0500 Subject: [PATCH 07/23] Refactored eGraphGP for readability --- src/Symbolic/Regression.hs | 179 +++++++++++++++++++++++++------------ 1 file changed, 123 insertions(+), 56 deletions(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 643c5a5..1ee1826 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -113,6 +113,20 @@ import Algorithm.EqSat.SearchSR import Data.Time.Clock.POSIX import Text.ParseSR +{- | Lift a random generator action into the e-graph monad. + +Wraps an action in the `Rng IO` monad (`StateT StdGen IO`) so it can be executed +within the `StateT EGraph (StateT StdGen IO)` monad. The current random generator +state is read from the inner `StateT StdGen`, updated after the action runs, and +the result is returned in the e-graph monad. +-} +liftRng :: Rng IO a -> StateT EGraph (StateT StdGen IO) a +liftRng rngAction = do + gen <- lift get -- get StdGen from inner StateT + (result, gen') <- liftIO $ runStateT rngAction gen + lift $ put gen' -- update StdGen + return result + {- | Tagged unary operation for symbolic regression. 
Pairs an operation name with its implementation, ensuring the operation's @@ -383,26 +397,34 @@ toNonTerminal (BinaryOp "pow" _ _ (e :: Expr b)) = case testEquality (typeRep @b toNonTerminal (UnaryOp "log" _ _) = "log" toNonTerminal e = error ("Unsupported operation: " ++ show e) -egraphGP :: +-- | Initialize the e-graph by loading from file (if specified) and inserting initial terms +initializeEGraph :: RegressionConfig -> - String -> -- nonterminals - String -> -- varnames - [(DataSet, DataSet)] -> - [DataSet] -> - StateT EGraph (StateT StdGen IO) [Fix SRTree] -egraphGP cfg nonterminals varnames dataTrainVals dataTests = do + (Fix SRTree -> RndEGraph (Double, [Array S Ix1 Double])) -> + RndEGraph EClassId -> + RndEGraph (SRTree ()) -> + RndEGraph [EClassId] -> + RndEGraph () +initializeEGraph cfg fitFun _rndTerm _rndNonTerm insertTerms = do unless (null (loadFrom cfg)) $ io (BS.readFile (loadFrom cfg)) >>= \eg -> put (decode eg) - _ <- insertTerms evaluateUnevaluated fitFun - t0 <- io getPOSIXTime - +-- | Create the initial population of expressions +createInitialPopulation :: + RegressionConfig -> + (Fix SRTree -> StateT EGraph (StateT StdGen IO) (Double, [Array S Ix1 Double])) -> + Rng IO (Fix SRTree) -> -- rndTerm + Rng IO (SRTree ()) -> -- rndNonTerm + (Int -> EClassId -> StateT EGraph (StateT StdGen IO) [String]) -> + StateT EGraph (StateT StdGen IO) ([EClassId], [[String]]) +createInitialPopulation cfg fitFun rndTerm rndNonTerm printExpr' = do pop <- replicateM (populationSize cfg) $ do ec <- insertRndExpr (maxExpressionSize cfg) rndTerm rndNonTerm >>= canonical _ <- updateIfNothing fitFun ec pure ec + pop' <- Prelude.mapM canonical pop output <- @@ -410,38 +432,42 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do then forM (Prelude.zip [0 ..] 
pop') $ uncurry printExpr' else pure [] - let mTime = - if maxTime cfg < 0 then Nothing else Just (fromIntegral $ maxTime cfg - 5) - (_, _, _) <- iterateFor (generations cfg) t0 mTime (pop', output, populationSize cfg) $ \_ (ps', out, curIx) -> do - newPop' <- replicateM (populationSize cfg) (evolve ps') - - out' <- - if showTrace cfg - then forM (Prelude.zip [curIx ..] newPop') $ uncurry printExpr' - else pure [] - - totSz <- gets (Map.size . _eNodeToEClass) - let full = totSz > max maxMem (populationSize cfg) - when full (cleanEGraph >> cleanDB) - - newPop <- - if generational cfg - then Prelude.mapM canonical newPop' - else do - pareto <- - concat <$> forM [1 .. maxExpressionSize cfg] (`getTopFitEClassWithSize` 2) - let remainder = populationSize cfg - length pareto - lft <- - if full - then getTopFitEClassThat remainder (const True) - else pure $ Prelude.take remainder newPop' - Prelude.mapM canonical (pareto <> lft) - pure (newPop, out <> out', curIx + populationSize cfg) + pure (pop', output) +-- | Finalize by saving e-graph to file if specified +finalizeEGraph :: RegressionConfig -> RndEGraph () +finalizeEGraph cfg = unless (null (dumpTo cfg)) $ get >>= (io . BS.writeFile (dumpTo cfg) . 
encode) + +egraphGP :: + RegressionConfig -> + String -> -- nonterminals + String -> -- varnames + [(DataSet, DataSet)] -> + [DataSet] -> + StateT EGraph (StateT StdGen IO) [Fix SRTree] +egraphGP cfg nonterminals varnames dataTrainVals dataTests = do + -- generate a random EClassId for the terminal expressions + -- rndTerm' :: RndEGraph EClassId + let rndTerm' = insertRndExpr (maxExpressionSize cfg) rndTerm rndNonTerm + + -- generate a random SRTree for non-terminal expressions + -- rndNonTerm' :: RndEGraph (SRTree ()) + let rndNonTerm' = liftRng rndNonTerm + + initializeEGraph cfg fitFun rndTerm' rndNonTerm' insertTerms + + t0 <- io getPOSIXTime + (pop', output) <- createInitialPopulation cfg fitFun rndTerm rndNonTerm printExpr' + + let mTime = if maxTime cfg < 0 then Nothing else Just (fromIntegral $ maxTime cfg - 5) + _ <- runGenerations t0 mTime (pop', output, populationSize cfg) + + finalizeEGraph cfg paretoFront' fitFun (maxExpressionSize cfg) where + -- Configuration and setup maxMem = 2000000 fitFun = fitnessMV @@ -473,6 +499,59 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do isBin (Bin{}) = True isBin _ = False + -- Random generators + rndTerm = do + coin <- toss + if coin || numParams cfg == 0 then randomFrom terms else randomFrom params + + rndNonTerm = randomFrom nonTerms + + -- Main evolution loop + runGenerations t0 mTime initial = + iterateFor (generations cfg) t0 mTime initial generationStep + where + generationStep _ (ps', out, curIx) = do + newPop' <- replicateM (populationSize cfg) (evolve ps') + + out' <- + if showTrace cfg + then forM (Prelude.zip [curIx ..] newPop') $ uncurry printExpr' + else pure [] + + totSz <- gets (Map.size . 
_eNodeToEClass) + let full = totSz > max maxMem (populationSize cfg) + when full (cleanEGraph >> cleanDB) + + newPop <- + if generational cfg + then Prelude.mapM canonical newPop' + else selectNextPopulation full newPop' + + pure (newPop, out <> out', curIx + populationSize cfg) + + selectNextPopulation full newPop' = do + pareto <- concat <$> forM [1 .. maxExpressionSize cfg] (`getTopFitEClassWithSize` 2) + let remainder = populationSize cfg - length pareto + lft <- + if full + then getTopFitEClassThat remainder (const True) + else pure $ Prelude.take remainder newPop' + Prelude.mapM canonical (pareto <> lft) + + iterateFor 0 _ _ xs _ = pure xs + iterateFor n t0' maxT xs f = do + xs' <- f n xs + t1 <- io getPOSIXTime + let delta = t1 - t0' + maxT' = subtract delta <$> maxT + case maxT' of + Nothing -> iterateFor (n - 1) t1 maxT' xs' f + Just mt -> + if mt <= 0 + then pure xs + else iterateFor (n - 1) t1 maxT' xs' f + + -- E-graph management cleanEGraph = do let nParetos = 10 io . putStrLn $ "cleaning" @@ -488,12 +567,6 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do Nothing -> pure () Just i'' -> insertFitness eId (fromJust $ _fitness i'') (_theta i'') - rndTerm = do - coin <- toss - if coin || numParams cfg == 0 then randomFrom terms else randomFrom params - - rndNonTerm = randomFrom nonTerms - refitChanged = do ids <- (gets (_refits . _eDB) >>= Prelude.mapM canonical . 
Set.toList) @@ -504,19 +577,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do (f, p) <- fitFun t insertFitness ec f p - iterateFor 0 _ _ xs _ = pure xs - iterateFor n t0' maxT xs f = do - xs' <- f n xs - t1 <- io getPOSIXTime - let delta = t1 - t0' - maxT' = subtract delta <$> maxT - case maxT' of - Nothing -> iterateFor (n - 1) t1 maxT' xs' f - Just mt -> - if mt <= 0 - then pure xs - else iterateFor (n - 1) t1 maxT' xs' f - + -- Evolution operators evolve xs' = do xs <- Prelude.mapM canonical xs' parents' <- tournament xs @@ -541,6 +602,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do combine (p1, p2) = crossover p1 p2 >>= mutate >>= canonical + -- Crossover operator crossover p1 p2 = do sz <- getSize p1 coin <- rnd $ tossBiased (crossoverProbability cfg) @@ -552,6 +614,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do tree <- getSubtree pos 0 Nothing [] cands p1 fromTree myCost (relabel tree) >>= canonical + -- Crossover and mutation helpers getSubtree :: Int -> Int -> @@ -623,6 +686,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do Uni _ t -> (p :) <$> getAllSubClasses t _ -> pure [p] + -- Mutation operator mutate p = do sz <- getSize p coin <- rnd $ tossBiased (mutationProbability cfg) @@ -693,6 +757,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do r' <- getBestExpr r pure . 
Fix $ Bin op l' r' + -- Output and reporting printExpr' :: Int -> EClassId -> RndEGraph [String] printExpr' ix ec' = do ec <- canonical ec' @@ -766,8 +831,10 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do <> "," <> vals + -- Initialization helpers insertTerms = forM terms (fromTree myCost >=> canonical) + -- Pareto front extraction paretoFront' _ maxSize' = go 1 (-(1.0 / 0.0)) where go :: Int -> Double -> RndEGraph [Fix SRTree] From c3177635f9af1d0827f95a54dbacbfef6ca8fe4c Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 22:01:38 -0500 Subject: [PATCH 08/23] Updated example code name --- src/Example.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Example.hs b/src/Example.hs index e4985b4..c8245c9 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -7,8 +7,8 @@ import qualified DataFrame as D import qualified DataFrame.Functions as F import Symbolic.Regression -main :: IO () -main = do +example_predictMPG :: IO () +example_predictMPG = do df <- D.readParquet "./data/mtcars.parquet" -- Define mpg as a column reference From 61046e57a4e44f425d6bced604c176e8d191fb7b Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 22:05:27 -0500 Subject: [PATCH 09/23] Cleanup --- src/Example.hs | 1 - src/Symbolic/Regression.hs | 21 --------------------- 2 files changed, 22 deletions(-) diff --git a/src/Example.hs b/src/Example.hs index c8245c9..70e8b77 100644 --- a/src/Example.hs +++ b/src/Example.hs @@ -14,7 +14,6 @@ example_predictMPG = do -- Define mpg as a column reference let mpg = F.col "mpg" - let config = defaultRegressionConfig{maxExpressionSize = 7} exprs <- fit defaultRegressionConfig mpg df -- Print each expression diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 1ee1826..8f967a5 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -105,8 +105,6 @@ import Data.SRTree.Datasets import qualified Data.SRTree.Internal as SI import Data.SRTree.Print 
import Data.SRTree.Random -import Data.Type.Equality (TestEquality (testEquality), type (:~:) (Refl)) -import Type.Reflection (typeRep) import Algorithm.EqSat (runEqSat) import Algorithm.EqSat.SearchSR @@ -378,25 +376,6 @@ toExpr cols (Fix (Bin op left right)) = case op of treeOp -> error ("UNIMPLEMENTED OPERATION: " ++ show treeOp) toExpr _ _ = error "UNIMPLEMENTED" -toNonTerminal :: D.Expr Double -> String -toNonTerminal (BinaryOp "add" _ _ _) = "add" -toNonTerminal (BinaryOp "sub" _ _ _) = "sub" -toNonTerminal (BinaryOp "mult" _ _ _) = "mul" -toNonTerminal (BinaryOp "divide" _ (Lit n :: Expr b) _) = case testEquality (typeRep @b) (typeRep @Double) of - Nothing -> error "[Internal Error] - Reciprocal of non-double" - Just Refl -> case n of - 1 -> "recip" - _ -> error "Unknown reciprocal" -toNonTerminal (BinaryOp "divide" _ _ _) = "div" -toNonTerminal (BinaryOp "pow" _ _ (e :: Expr b)) = case testEquality (typeRep @b) (typeRep @Int) of - Nothing -> error "Impossible: Raised to non-int power" - Just Refl -> case e of - (Lit 2) -> "square" - (Lit 3) -> "cube" - _ -> error "Unknown power" -toNonTerminal (UnaryOp "log" _ _) = "log" -toNonTerminal e = error ("Unsupported operation: " ++ show e) - -- | Initialize the e-graph by loading from file (if specified) and inserting initial terms initializeEGraph :: RegressionConfig -> From cbdd3b97a92cb86108db1d088f4c1599091020a7 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 22:16:28 -0500 Subject: [PATCH 10/23] Removed unnecessary GADTs for UnaryFunc and BinaryFunc --- src/Symbolic/Regression.hs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 8f967a5..f0dde53 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -136,8 +136,10 @@ UnaryFunc "square" (\\x -> x \`F.pow\` 2) UnaryFunc "log" log @ -} -data UnaryFunc where - UnaryFunc :: String -> (D.Expr Double -> D.Expr Double) -> UnaryFunc +data 
UnaryFunc = UnaryFunc String (D.Expr Double -> D.Expr Double) + +-- data UnaryFunc where +-- UnaryFunc :: String -> (D.Expr Double -> D.Expr Double) -> UnaryFunc -- | Extract the name of a unary operation. getUnaryName :: UnaryFunc -> String @@ -154,8 +156,7 @@ BinaryFunc "add" (+) BinaryFunc "mul" (*) @ -} -data BinaryFunc where - BinaryFunc :: String -> (D.Expr Double -> D.Expr Double -> D.Expr Double) -> BinaryFunc +data BinaryFunc = BinaryFunc String (D.Expr Double -> D.Expr Double -> D.Expr Double) -- | Extract the name of a binary operation. getBinaryName :: BinaryFunc -> String From 07aad9be1e776a707ad08f2b1e86ecd25917fcda Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Tue, 10 Feb 2026 22:26:30 -0500 Subject: [PATCH 11/23] Defined types for unary and binary operation names --- src/Symbolic/Regression.hs | 61 +++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index f0dde53..637b9c0 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -132,18 +132,29 @@ identity is preserved throughout the search process. The name is used to identify the operation in the internal representation. @ -UnaryFunc "square" (\\x -> x \`F.pow\` 2) -UnaryFunc "log" log +UnaryFunc USquare (\\x -> x \`F.pow\` 2) +UnaryFunc ULog log @ -} -data UnaryFunc = UnaryFunc String (D.Expr Double -> D.Expr Double) - --- data UnaryFunc where --- UnaryFunc :: String -> (D.Expr Double -> D.Expr Double) -> UnaryFunc +data UnaryOpName + = ULog + | USquare + | UCube + | URecip + deriving (Eq) + +instance Show UnaryOpName where + show ULog = "log" + show USquare = "square" + show UCube = "cube" + show URecip = "recip" + +data UnaryFunc + = UnaryFunc UnaryOpName (D.Expr Double -> D.Expr Double) -- | Extract the name of a unary operation. 
getUnaryName :: UnaryFunc -> String -getUnaryName (UnaryFunc name _) = name +getUnaryName (UnaryFunc name _) = show name {- | Tagged binary operation for symbolic regression. @@ -156,11 +167,27 @@ BinaryFunc "add" (+) BinaryFunc "mul" (*) @ -} -data BinaryFunc = BinaryFunc String (D.Expr Double -> D.Expr Double -> D.Expr Double) +data BinaryOpName + = BAdd + | BSub + | BMul + | BDiv + | BPow + deriving (Eq) + +instance Show BinaryOpName where + show BAdd = "add" + show BSub = "sub" + show BMul = "mul" + show BDiv = "div" + show BPow = "pow" + +data BinaryFunc + = BinaryFunc BinaryOpName (D.Expr Double -> D.Expr Double -> D.Expr Double) -- | Extract the name of a binary operation. getBinaryName :: BinaryFunc -> String -getBinaryName (BinaryFunc name _) = name +getBinaryName (BinaryFunc name _) = show name {- | Configuration for the symbolic regression algorithm. @@ -252,16 +279,16 @@ defaultRegressionConfig = , crossoverProbability = 0.95 , mutationProbability = 0.3 , unaryFunctions = - [ UnaryFunc "square" (\x -> x `F.pow` 2) - , UnaryFunc "cube" (\x -> x `F.pow` 3) - , UnaryFunc "log" log - , UnaryFunc "recip" (1 /) + [ UnaryFunc USquare (\x -> x `F.pow` 2) + , UnaryFunc UCube (\x -> x `F.pow` 3) + , UnaryFunc ULog log + , UnaryFunc URecip (1 /) ] , binaryFunctions = - [ BinaryFunc "add" (+) - , BinaryFunc "sub" (-) - , BinaryFunc "mul" (*) - , BinaryFunc "div" (/) + [ BinaryFunc BAdd (+) + , BinaryFunc BSub (-) + , BinaryFunc BMul (*) + , BinaryFunc BDiv (/) ] , numParams = -1 , generational = False From b347631d5da1d8e03f6d5043361594dd06a7f45b Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Wed, 11 Feb 2026 22:54:33 -0500 Subject: [PATCH 12/23] Update haddock documentation for BinaryFunc, add convenience functions for UnaryFunc and BinaryFunc, use convenience functions for UnaryFunc and BinaryFunc --- src/Symbolic/Regression.hs | 59 +++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/src/Symbolic/Regression.hs 
b/src/Symbolic/Regression.hs index 637b9c0..9421637 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -136,6 +136,21 @@ UnaryFunc USquare (\\x -> x \`F.pow\` 2) UnaryFunc ULog log @ -} +data UnaryFunc = UnaryFunc UnaryOpName (D.Expr Double -> D.Expr Double) + +-- | Convenience constructors for UnaryFunc +uLog :: (Expr Double -> Expr Double) -> UnaryFunc +uLog = UnaryFunc ULog + +uSquare :: (Expr Double -> Expr Double) -> UnaryFunc +uSquare = UnaryFunc USquare + +uCube :: (Expr Double -> Expr Double) -> UnaryFunc +uCube = UnaryFunc UCube + +uRecip :: (Expr Double -> Expr Double) -> UnaryFunc +uRecip = UnaryFunc URecip + data UnaryOpName = ULog | USquare @@ -149,9 +164,6 @@ instance Show UnaryOpName where show UCube = "cube" show URecip = "recip" -data UnaryFunc - = UnaryFunc UnaryOpName (D.Expr Double -> D.Expr Double) - -- | Extract the name of a unary operation. getUnaryName :: UnaryFunc -> String getUnaryName (UnaryFunc name _) = show name @@ -163,8 +175,8 @@ identity is preserved throughout the search process. The name is used to identify the operation in the internal representation. @ -BinaryFunc "add" (+) -BinaryFunc "mul" (*) +BinaryFunc BAdd (+) +BinaryFunc BMul (*) @ -} data BinaryOpName @@ -182,13 +194,28 @@ instance Show BinaryOpName where show BDiv = "div" show BPow = "pow" -data BinaryFunc - = BinaryFunc BinaryOpName (D.Expr Double -> D.Expr Double -> D.Expr Double) +data BinaryFunc = BinaryFunc BinaryOpName (D.Expr Double -> D.Expr Double -> D.Expr Double) -- | Extract the name of a binary operation. 
getBinaryName :: BinaryFunc -> String getBinaryName (BinaryFunc name _) = show name +-- | Convenience constructors for BinaryFunc +bAdd :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bAdd = BinaryFunc BAdd + +bSub :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bSub = BinaryFunc BSub + +bMul :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bMul = BinaryFunc BMul + +bDiv :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bDiv = BinaryFunc BDiv + +bPow :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bPow = BinaryFunc BPow + {- | Configuration for the symbolic regression algorithm. Use 'defaultRegressionConfig' as a starting point and modify fields as needed: @@ -279,16 +306,16 @@ defaultRegressionConfig = , crossoverProbability = 0.95 , mutationProbability = 0.3 , unaryFunctions = - [ UnaryFunc USquare (\x -> x `F.pow` 2) - , UnaryFunc UCube (\x -> x `F.pow` 3) - , UnaryFunc ULog log - , UnaryFunc URecip (1 /) + [ uSquare (`F.pow` 2) + , uCube (`F.pow` 3) + , uLog log + , uRecip (1 /) ] , binaryFunctions = - [ BinaryFunc BAdd (+) - , BinaryFunc BSub (-) - , BinaryFunc BMul (*) - , BinaryFunc BDiv (/) + [ bAdd (+) + , bSub (-) + , bMul (*) + , bDiv (/) ] , numParams = -1 , generational = False @@ -869,4 +896,4 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do ts <- go (n + 1) (max f f') pure (t : ts) else go (n + 1) (max f f') - else go (n + 1) f \ No newline at end of file + else go (n + 1) f From cd25a2f7cc120d6476075e87953ef23dd465cf9a Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Wed, 11 Feb 2026 23:30:40 -0500 Subject: [PATCH 13/23] Update README.md with fixed example and optional nix instructions --- README.md | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e2c04bb..504b2f8 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,15 @@ symbolic-regression integrates symbolic regression 
capabilities into a DataFrame ```haskell ghci> import qualified DataFrame as D -ghci> import DataFrame.Functions ((.=)) +ghci> import qualified DataFrame.Functions as F ghci> import Symbolic.Regression -- Load your data ghci> df <- D.readParquet "./data/mtcars.parquet" +-- Define mpg as a column reference +ghci> let mpg = F.col "mpg" + -- Run symbolic regression to predict 'mpg' -- NOTE: ALL COLUMNS MUST BE CONVERTED TO DOUBLE FIRST -- e.g df' = D.derive "some_column" (F.toDouble (F.col @Int "some_column")) df @@ -24,19 +27,20 @@ ghci> df <- D.readParquet "./data/mtcars.parquet" ghci> exprs <- fit defaultRegressionConfig mpg df -- View discovered expressions (Pareto front from simplest to most complex) -ghci> map D.prettyPrint exprs --- [ qsec, --- , 57.33 / wt --- , 10.75 + (1557.67 / disp)] +ghci> mapM_ (\(i, e) -> putStrLn $ "Model " ++ show i ++ ": " ++ D.prettyPrint e) (zip [1..] exprs) + +-- Create named expressions for different complexity levels +ghci> import qualified Data.Text as T +ghci> let levels = zipWith (F..=) (map (T.pack . ("level_" ++) . show) [1..]) exprs --- Create named expressions that we'll use in a dataframe. -ghci> levels = zipWith (.=) ["level_1", "level_2", "level_3"] exprs +-- Show the various predictions in our dataframe +ghci> let df' = D.deriveMany levels df --- Show the various predictions in our dataframe. 
-ghci> D.deriveMany levels df +-- Or pick the best one for prediction +ghci> let df'' = D.derive "prediction" (last exprs) df' --- Or pick the best one -ghci> D.derive "prediction" (last exprs) df +-- Display the results +ghci> D.display (D.DisplayOptions 5) df'' ``` ## Configuration @@ -102,3 +106,14 @@ To install symbolic-regression you'll need: * libz: `sudo apt install libz-dev` * libnlopt: `sudo apt install libnlopt-dev` * libgmp: `sudo apt install libgmp-dev` + +### Nix Development Environment +For Nix users with flakes enabled: + +```bash +git clone <repository-url> +cd symbolic-regression +nix develop -c cabal repl +``` + +Then follow the Quick Start example above. From bc952a86bde9f816896ee8f15135cafe65023c11 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Wed, 11 Feb 2026 23:30:57 -0500 Subject: [PATCH 14/23] Remove unrunnable Example.hs --- src/Example.hs | 33 --------------------------------- 1 file changed, 33 deletions(-) delete mode 100644 src/Example.hs diff --git a/src/Example.hs b/src/Example.hs deleted file mode 100644 index 70e8b77..0000000 --- a/src/Example.hs +++ /dev/null @@ -1,33 +0,0 @@ -{-# LANGUAGE OverloadedStrings #-} - -module Example where - -import qualified Data.Text as T -import qualified DataFrame as D -import qualified DataFrame.Functions as F -import Symbolic.Regression - -example_predictMPG :: IO () -example_predictMPG = do - df <- D.readParquet "./data/mtcars.parquet" - - -- Define mpg as a column reference - let mpg = F.col "mpg" - - exprs <- fit defaultRegressionConfig mpg df - - -- Print each expression - mapM_ - (\(i, e) -> putStrLn $ "Model " ++ show i ++ ": " ++ D.prettyPrint e) - (zip [1 ..] exprs) - - -- Create named expressions for different complexity levels - let levels = zipWith (F..=) (map (T.pack . ("level_" ++) . 
show) [1 ..]) exprs - - -- deriveMany returns a DataFrame, so use 'let' not '<-' - let df' = D.deriveMany levels df - - -- Derive prediction using the best (last) expression - let df'' = D.derive "prediction" (last exprs) df' - - D.display (D.DisplayOptions 5) df'' From f9aa0ac3e01d1618098cb639bc2591b7e6425001 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Wed, 11 Feb 2026 23:32:51 -0500 Subject: [PATCH 15/23] Update .cabal file: Add dependencies to test-suite stanza and remove 'Example' module from library stanza's exposed modules --- symbolic-regression.cabal | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/symbolic-regression.cabal b/symbolic-regression.cabal index e7d8407..5716eb3 100644 --- a/symbolic-regression.cabal +++ b/symbolic-regression.cabal @@ -19,7 +19,7 @@ common warnings library import: warnings - exposed-modules: Symbolic.Regression, Example + exposed-modules: Symbolic.Regression build-depends: base >= 4 && <5 , dataframe ^>= 0.4.1 , attoparsec >=0.14.4 && <0.15 @@ -62,4 +62,9 @@ test-suite symbolic-regression-test main-is: Main.hs build-depends: base >= 4 && <5, - symbolic-regression + symbolic-regression, + HUnit >= 1.6 && < 1.7, + dataframe ^>= 0.4.1, + random >= 1.2 && < 1.4, + text >= 2.0 && < 3, + time >= 1.12 && < 2 From cdda61a00ada6d55215f37a70a7bd58b72b41b77 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Wed, 11 Feb 2026 23:35:41 -0500 Subject: [PATCH 16/23] Export functions for UnaryFunc and BinaryFunc from Symbolic.Regression --- src/Symbolic/Regression.hs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 9421637..a4f9e7e 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -54,6 +54,21 @@ module Symbolic.Regression ( RegressionConfig (..), ValidationConfig (..), defaultRegressionConfig, + + -- * Function constructors and utilities + UnaryFunc, + BinaryFunc, + uLog, + uSquare, + uCube, + uRecip, + bAdd, 
+ bSub, + bMul, + bDiv, + bPow, + getUnaryName, + getBinaryName, ) where import Control.Exception (throw) From 421bfde0509a0943fb8c2ae95be13fae7b5963ae Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Wed, 11 Feb 2026 23:59:41 -0500 Subject: [PATCH 17/23] Add basic tests to validate discovery of constant, linear, quadratic, and multi-variable models --- test/Main.hs | 146 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 145 insertions(+), 1 deletion(-) diff --git a/test/Main.hs b/test/Main.hs index 3e2059e..1f10c45 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,4 +1,148 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeApplications #-} + module Main (main) where +import qualified DataFrame as D +import qualified DataFrame.Functions as F +import Symbolic.Regression +import Test.HUnit + +-- | Test linear formula regression (y = 2x + 1) +testLinearFormula :: Test +testLinearFormula = TestCase $ do + let xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] :: [Double] + ys = map (\x -> 2 * x + 1) xs -- Known formula: y = 2x + 1 + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultRegressionConfig + { generations = 30 + , populationSize = 80 + , showTrace = False + , maxExpressionSize = 4 + } + + exprs <- fit cfg target df + + assertBool "Should find at least one expression" (not $ null exprs) + + + -- Test accuracy: the best expression should closely match y = 2x + 1 + let bestExpr = last exprs -- Most complex/accurate should be last + predDf = D.derive "predicted" bestExpr df + actualVals = D.columnAsList (F.col @Double "y") predDf :: [Double] + predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] + errorValues = zipWith (\a p -> abs (a - p)) actualVals predVals + maxError = maximum errorValues + + -- For perfect linear data, should have reasonably low error + assertBool ("Max error should be < 1.0, got: " ++ show maxError) (maxError < 1.0) + +-- | 
Test constant formula regression (y = 42) +testConstantFormula :: Test +testConstantFormula = TestCase $ do + let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] + ys = replicate 5 42.0 :: [Double] -- Constant target: y = 42 + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultRegressionConfig + { generations = 3 + , populationSize = 10 + , showTrace = False + , maxExpressionSize = 1 + } + + exprs <- fit cfg target df + assertBool "Should handle constant target without crashing" (not $ null exprs) + + + -- Test that it can predict the constant correctly + let bestExpr = last exprs + predDf = D.derive "predicted" bestExpr df + predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] + errorValues = map (\p -> abs (42.0 - p)) predVals + maxError = maximum errorValues + + assertBool ("Should predict constant accurately, max error: " ++ show maxError) (maxError < 1.0) + +-- | Test quadratic formula regression (y = x² + 2x + 3) +testQuadraticFormula :: Test +testQuadraticFormula = TestCase $ do + let xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] :: [Double] + ys = map (\x -> x * x + 2 * x + 3) xs -- Known formula: y = x² + 2x + 3 + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultRegressionConfig + { generations = 50 + , populationSize = 150 + , showTrace = False + , maxExpressionSize = 8 -- Need more complexity for quadratic + } + + exprs <- fit cfg target df + + assertBool "Should find at least one expression" (not $ null exprs) + + + -- Test accuracy: the best expression should closely match y = x² + 2x + 3 + let bestExpr = last exprs -- Most complex/accurate should be last + predDf = D.derive "predicted" bestExpr df + actualVals = D.columnAsList (F.col @Double "y") predDf :: [Double] + predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] + errorValues = zipWith (\a p -> abs (a - p)) actualVals predVals + maxError = 
maximum errorValues + + -- Quadratic might need slightly higher tolerance than linear + assertBool ("Max error should be < 0.5, got: " ++ show maxError) (maxError < 0.5) + +-- | Test two-variable formula regression (z = x + y) +testTwoVariableFormula :: Test +testTwoVariableFormula = TestCase $ do + let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] + ys = [2.0, 3.0, 1.0, 4.0, 2.0] :: [Double] + zs = zipWith (+) xs ys -- Known formula: z = x + y + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys), ("z", D.fromList zs)] + target = F.col @Double "z" + cfg = + defaultRegressionConfig + { generations = 20 + , populationSize = 60 + , showTrace = False + , maxExpressionSize = 3 + } + + exprs <- fit cfg target df + + assertBool "Should find at least one expression" (not $ null exprs) + + + -- Test accuracy: the best expression should closely match z = x + y + let bestExpr = last exprs -- Most complex/accurate should be last + predDf = D.derive "predicted" bestExpr df + actualVals = D.columnAsList (F.col @Double "z") predDf :: [Double] + predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] + errorValues = zipWith (\a p -> abs (a - p)) actualVals predVals + maxError = maximum errorValues + + -- Two-variable linear should be quite accurate + assertBool ("Max error should be < 0.1, got: " ++ show maxError) (maxError < 0.1) + +allTests :: Test +allTests = + TestList + [ "Linear Formula" ~: testLinearFormula + , "Constant Formula" ~: testConstantFormula + , "Quadratic Formula" ~: testQuadraticFormula + , "Two-Variable Formula" ~: testTwoVariableFormula + ] + main :: IO () -main = putStrLn "Test suite not yet implemented." +main = do + putStrLn "Running symbolic regression tests..." 
+ testCounts <- runTestTT allTests + putStrLn $ "Tests run: " ++ show (cases testCounts) + putStrLn $ "Failures: " ++ show (failures testCounts) + putStrLn $ "Errors: " ++ show (errors testCounts) From 015fb6bc7916520842c282d45d53266d8fcd8dea Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Thu, 12 Feb 2026 00:13:30 -0500 Subject: [PATCH 18/23] Update UnaryFunc/BinaryFunc haddocks to use convenience constructors --- src/Symbolic/Regression.hs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index a4f9e7e..1e42c9c 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -147,8 +147,10 @@ identity is preserved throughout the search process. The name is used to identify the operation in the internal representation. @ -UnaryFunc USquare (\\x -> x \`F.pow\` 2) -UnaryFunc ULog log +myUnaryFunctions = [ uSquare (\`F.pow\` 2) + , uLog log + , uRecip (1 /) + ] @ -} data UnaryFunc = UnaryFunc UnaryOpName (D.Expr Double -> D.Expr Double) @@ -190,8 +192,11 @@ identity is preserved throughout the search process. The name is used to identify the operation in the internal representation. 
@ -BinaryFunc BAdd (+) -BinaryFunc BMul (*) +myBinaryFunctions = [ bAdd (+) + , bSub (-) + , bMul (*) + , bDiv (/) + ] @ -} data BinaryOpName From 741bae0cd4b66c5b942a14eabd59c741ef46cf3e Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Thu, 12 Feb 2026 00:19:01 -0500 Subject: [PATCH 19/23] Add nixfmt to flake devShell --- flake.nix | 1 + 1 file changed, 1 insertion(+) diff --git a/flake.nix b/flake.nix index a6ceb33..705e083 100644 --- a/flake.nix +++ b/flake.nix @@ -76,6 +76,7 @@ haskell-language-server git fourmolu + nixfmt ]; }; }; From 206b3300262201164caa567a0b2cac296f24fda3 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Thu, 12 Feb 2026 00:28:20 -0500 Subject: [PATCH 20/23] Fix haddock example in Symbolic.Regression --- src/Symbolic/Regression.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 1e42c9c..a98f8ad 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -19,12 +19,14 @@ off complexity and accuracy. 
@ import qualified DataFrame as D -import DataFrame.Functions ((.=)) +import qualified DataFrame.Functions as F import Symbolic.Regression -- Load your data df <- D.readParquet "./data/mtcars.parquet" +let mpg = F.col "mpg" + -- Run symbolic regression to predict 'mpg' exprs <- fit defaultRegressionConfig mpg df From 7ecdcf2615f87e2dcbde84925c6312a3fe7289cd Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Thu, 12 Feb 2026 02:37:55 -0500 Subject: [PATCH 21/23] Define default regression config for tests with disabled tracing and defined seed for deterministic rng --- test/Main.hs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/Main.hs b/test/Main.hs index 1f10c45..61643a8 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -8,6 +8,14 @@ import qualified DataFrame.Functions as F import Symbolic.Regression import Test.HUnit +-- | Default regression config for testing: disabled tracing and defined seed for deterministic rng +defaultTestConfig :: RegressionConfig +defaultTestConfig = + defaultRegressionConfig + { showTrace = False + , seed = Just 0 + } + -- | Test linear formula regression (y = 2x + 1) testLinearFormula :: Test testLinearFormula = TestCase $ do @@ -16,17 +24,15 @@ testLinearFormula = TestCase $ do df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] target = F.col @Double "y" cfg = - defaultRegressionConfig + defaultTestConfig { generations = 30 , populationSize = 80 - , showTrace = False , maxExpressionSize = 4 } exprs <- fit cfg target df assertBool "Should find at least one expression" (not $ null exprs) - -- Test accuracy: the best expression should closely match y = 2x + 1 let bestExpr = last exprs -- Most complex/accurate should be last @@ -47,16 +53,14 @@ testConstantFormula = TestCase $ do df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] target = F.col @Double "y" cfg = - defaultRegressionConfig + defaultTestConfig { generations = 3 , populationSize = 10 - , 
showTrace = False , maxExpressionSize = 1 } exprs <- fit cfg target df assertBool "Should handle constant target without crashing" (not $ null exprs) - -- Test that it can predict the constant correctly let bestExpr = last exprs @@ -75,17 +79,15 @@ testQuadraticFormula = TestCase $ do df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] target = F.col @Double "y" cfg = - defaultRegressionConfig + defaultTestConfig { generations = 50 , populationSize = 150 - , showTrace = False , maxExpressionSize = 8 -- Need more complexity for quadratic } exprs <- fit cfg target df assertBool "Should find at least one expression" (not $ null exprs) - -- Test accuracy: the best expression should closely match y = x² + 2x + 3 let bestExpr = last exprs -- Most complex/accurate should be last @@ -107,17 +109,15 @@ testTwoVariableFormula = TestCase $ do df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys), ("z", D.fromList zs)] target = F.col @Double "z" cfg = - defaultRegressionConfig + defaultTestConfig { generations = 20 , populationSize = 60 - , showTrace = False , maxExpressionSize = 3 } exprs <- fit cfg target df assertBool "Should find at least one expression" (not $ null exprs) - -- Test accuracy: the best expression should closely match z = x + y let bestExpr = last exprs -- Most complex/accurate should be last From 6a46c1c179cac5ad1fd0799a4415d9bcc7ec8698 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Thu, 12 Feb 2026 03:07:34 -0500 Subject: [PATCH 22/23] Refactor tests, use deriveWithExpr when calculating maxAbsError between target and predicted expression --- test/Main.hs | 218 ++++++++++++++++++++++++++------------------------- 1 file changed, 111 insertions(+), 107 deletions(-) diff --git a/test/Main.hs b/test/Main.hs index 61643a8..3913861 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -8,7 +8,7 @@ import qualified DataFrame.Functions as F import Symbolic.Regression import Test.HUnit --- | Default regression config for testing: disabled 
tracing and defined seed for deterministic rng +-- Default regression config for deterministic testing defaultTestConfig :: RegressionConfig defaultTestConfig = defaultRegressionConfig @@ -16,119 +16,123 @@ defaultTestConfig = , seed = Just 0 } --- | Test linear formula regression (y = 2x + 1) -testLinearFormula :: Test -testLinearFormula = TestCase $ do - let xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] :: [Double] - ys = map (\x -> 2 * x + 1) xs -- Known formula: y = 2x + 1 - df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] - target = F.col @Double "y" - cfg = - defaultTestConfig - { generations = 30 - , populationSize = 80 - , maxExpressionSize = 4 - } - - exprs <- fit cfg target df +-- Compute maximum absolute error between target and predicted expression +maxAbsError :: D.Expr Double -> D.Expr Double -> D.DataFrame -> Double +maxAbsError target predicted df = + let (predColExpr, df') = D.deriveWithExpr "__test_predicted" predicted df + actualVals = D.columnAsList target df' :: [Double] + predVals = D.columnAsList predColExpr df' :: [Double] + in maximum (zipWith (\a p -> abs (a - p)) actualVals predVals) - assertBool "Should find at least one expression" (not $ null exprs) +{- - -- Test accuracy: the best expression should closely match y = 2x + 1 - let bestExpr = last exprs -- Most complex/accurate should be last - predDf = D.derive "predicted" bestExpr df - actualVals = D.columnAsList (F.col @Double "y") predDf :: [Double] - predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] - errorValues = zipWith (\a p -> abs (a - p)) actualVals predVals - maxError = maximum errorValues +(predicted, predDf) = D.deriveWithExpr "predicted" bestExpr +predValues = D.columnAsList predicted predDf - -- For perfect linear data, should have reasonably low error - assertBool ("Max error should be < 1.0, got: " ++ show maxError) (maxError < 1.0) - --- | Test constant formula regression (y = 42) -testConstantFormula :: Test 
-testConstantFormula = TestCase $ do - let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] - ys = replicate 5 42.0 :: [Double] -- Constant target: y = 42 - df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] - target = F.col @Double "y" - cfg = - defaultTestConfig - { generations = 3 - , populationSize = 10 - , maxExpressionSize = 1 - } +-} +-- Extract best evolved expression (last in population history) +bestExpression :: RegressionConfig -> D.Expr Double -> D.DataFrame -> IO (Maybe (D.Expr Double)) +bestExpression cfg target df = do exprs <- fit cfg target df - assertBool "Should handle constant target without crashing" (not $ null exprs) - - -- Test that it can predict the constant correctly - let bestExpr = last exprs - predDf = D.derive "predicted" bestExpr df - predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] - errorValues = map (\p -> abs (42.0 - p)) predVals - maxError = maximum errorValues - - assertBool ("Should predict constant accurately, max error: " ++ show maxError) (maxError < 1.0) - --- | Test quadratic formula regression (y = x² + 2x + 3) + pure $ if null exprs then Nothing else Just (last exprs) + +-- High-level regression verification used by tests +verifyRegressionAccuracy :: + String -> + RegressionConfig -> + D.Expr Double -> + D.DataFrame -> + Double -> + Assertion +verifyRegressionAccuracy label cfg target df tolerance = do + mBest <- bestExpression cfg target df + case mBest of + Nothing -> + assertFailure noResultMsg + Just bestExpr -> do + let err = maxAbsError target bestExpr df + assertBool (errorMsg err) (err < tolerance) + where + noResultMsg = + label ++ ": regression returned no expressions" + + errorMsg err = + label + ++ ": expected max abs error < " + ++ show tolerance + ++ ", got " + ++ show err + +-- Linear: y = 2x + 1 +testLinearFormula :: Test +testLinearFormula = + TestCase $ + let xs = [1.0 .. 
10.0] :: [Double] + ys = map (\x -> 2 * x + 1) xs + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultTestConfig + { generations = 30 + , populationSize = 80 + , maxExpressionSize = 4 + } + in verifyRegressionAccuracy "Linear Formula" cfg target df 1.0 + +-- Constant: y = 42 +testConstantFormula :: Test +testConstantFormula = + TestCase $ + let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] + ys = replicate 5 (42.0 :: Double) + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultTestConfig + { generations = 3 + , populationSize = 10 + , maxExpressionSize = 1 + } + in verifyRegressionAccuracy "Constant Formula" cfg target df 1.0 + +-- Quadratic: y = x² + 2x + 3 testQuadraticFormula :: Test -testQuadraticFormula = TestCase $ do - let xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] :: [Double] - ys = map (\x -> x * x + 2 * x + 3) xs -- Known formula: y = x² + 2x + 3 - df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] - target = F.col @Double "y" - cfg = - defaultTestConfig - { generations = 50 - , populationSize = 150 - , maxExpressionSize = 8 -- Need more complexity for quadratic - } - - exprs <- fit cfg target df - - assertBool "Should find at least one expression" (not $ null exprs) - - -- Test accuracy: the best expression should closely match y = x² + 2x + 3 - let bestExpr = last exprs -- Most complex/accurate should be last - predDf = D.derive "predicted" bestExpr df - actualVals = D.columnAsList (F.col @Double "y") predDf :: [Double] - predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] - errorValues = zipWith (\a p -> abs (a - p)) actualVals predVals - maxError = maximum errorValues - - -- Quadratic might need slightly higher tolerance than linear - assertBool ("Max error should be < 0.5, got: " ++ show maxError) (maxError < 0.5) - --- | Test two-variable formula regression (z = x + y) +testQuadraticFormula = + 
TestCase $ + let xs = [1.0 .. 6.0] :: [Double] + ys = map (\x -> x * x + 2 * x + 3) xs + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultTestConfig + { generations = 50 + , populationSize = 150 + , maxExpressionSize = 8 + } + in verifyRegressionAccuracy "Quadratic Formula" cfg target df 0.5 + +-- Two variable: z = x + y testTwoVariableFormula :: Test -testTwoVariableFormula = TestCase $ do - let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] - ys = [2.0, 3.0, 1.0, 4.0, 2.0] :: [Double] - zs = zipWith (+) xs ys -- Known formula: z = x + y - df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys), ("z", D.fromList zs)] - target = F.col @Double "z" - cfg = - defaultTestConfig - { generations = 20 - , populationSize = 60 - , maxExpressionSize = 3 - } - - exprs <- fit cfg target df - - assertBool "Should find at least one expression" (not $ null exprs) - - -- Test accuracy: the best expression should closely match z = x + y - let bestExpr = last exprs -- Most complex/accurate should be last - predDf = D.derive "predicted" bestExpr df - actualVals = D.columnAsList (F.col @Double "z") predDf :: [Double] - predVals = D.columnAsList (F.col @Double "predicted") predDf :: [Double] - errorValues = zipWith (\a p -> abs (a - p)) actualVals predVals - maxError = maximum errorValues - - -- Two-variable linear should be quite accurate - assertBool ("Max error should be < 0.1, got: " ++ show maxError) (maxError < 0.1) +testTwoVariableFormula = + TestCase $ + let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] + ys = [2.0, 3.0, 1.0, 4.0, 2.0] :: [Double] + zs = zipWith (+) xs ys + df = + D.fromNamedColumns + [ ("x", D.fromList xs) + , ("y", D.fromList ys) + , ("z", D.fromList zs) + ] + target = F.col @Double "z" + cfg = + defaultTestConfig + { generations = 20 + , populationSize = 60 + , maxExpressionSize = 3 + } + in verifyRegressionAccuracy "Two-Variable Formula" cfg target df 0.1 allTests :: Test allTests = From 
d9b64692d9e8c5d95d0ce922b2dfcc87d60e6a90 Mon Sep 17 00:00:00 2001 From: Ashwin Mathi Date: Thu, 12 Feb 2026 03:14:14 -0500 Subject: [PATCH 23/23] Refactor egraphGP function by extracting crossover, core evolution loop, and report generation Reduces egraphGP from ~420 lines to ~230 lines. --- src/Symbolic/Regression.hs | 431 ++++++++++++++++++++----------------- 1 file changed, 233 insertions(+), 198 deletions(-) diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index a98f8ad..0263775 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -125,6 +125,7 @@ import Data.SRTree.Random import Algorithm.EqSat (runEqSat) import Algorithm.EqSat.SearchSR +import Data.Time.Clock (NominalDiffTime) import Data.Time.Clock.POSIX import Text.ParseSR @@ -150,7 +151,7 @@ identify the operation in the internal representation. @ myUnaryFunctions = [ uSquare (\`F.pow\` 2) - , uLog log + , uLog log , uRecip (1 /) ] @ @@ -496,6 +497,234 @@ finalizeEGraph cfg = unless (null (dumpTo cfg)) $ get >>= (io . BS.writeFile (dumpTo cfg) . 
encode) +-- | Crossover operation between two expressions +crossoverExpressions :: + RegressionConfig -> + Int -> -- maxExpressionSize + (Fix SRTree -> Fix SRTree) -> -- relabel function + EClassId -> + EClassId -> + RndEGraph EClassId +crossoverExpressions cfg maxSize relabel p1 p2 = do + sz <- getSize p1 + coin <- rnd $ tossBiased (crossoverProbability cfg) + if sz == 1 || not coin + then rnd (randomFrom [p1, p2]) + else do + pos <- rnd $ randomRange (1, sz - 1) + cands <- getAllSubClasses p2 + tree <- getSubtree pos 0 Nothing [] cands p1 + fromTree myCost (relabel tree) >>= canonical + where + getAllSubClasses p' = do + p <- canonical p' + en <- getBestENode p + case en of + Bin _ l r -> do + ls <- getAllSubClasses l + rs <- getAllSubClasses r + pure (p : ls <> rs) + Uni _ t -> (p :) <$> getAllSubClasses t + _ -> pure [p] + + getSubtree :: + Int -> + Int -> + Maybe (EClassId -> ENode) -> + [Maybe (EClassId -> ENode)] -> + [EClassId] -> + EClassId -> + RndEGraph (Fix SRTree) + getSubtree 0 sz (Just parent) mGrandParents cands p' = do + p <- canonical p' + candidates' <- + filterM (fmap (< maxSize - sz) . getSize) cands + candidates <- + filterM (doesNotExistGens mGrandParents . parent) candidates' + >>= traverse canonical + if null candidates + then getBestExpr p + else do + subtree <- rnd (randomFrom candidates) + getBestExpr subtree + getSubtree pos sz parent mGrandParents cands p' = do + p <- canonical p' + root <- getBestENode p >>= canonize + case root of + Param ix -> pure . Fix $ Param ix + Const x -> pure . Fix $ Const x + Var ix -> pure . Fix $ Var ix + Uni f t' -> do + t <- canonical t' + Fix . 
Uni f + <$> getSubtree (pos - 1) (sz + 1) (Just $ Uni f) (parent : mGrandParents) cands t + Bin op l'' r'' -> do + l <- canonical l'' + r <- canonical r'' + szLft <- getSize l + szRgt <- getSize r + if szLft < pos + then do + l' <- getBestExpr l + r' <- + getSubtree + (pos - szLft - 1) + (sz + szLft + 1) + (Just $ Bin op l) + (parent : mGrandParents) + cands + r + pure . Fix $ Bin op l' r' + else do + l' <- + getSubtree + (pos - 1) + (sz + szRgt + 1) + (Just (\t -> Bin op t r)) + (parent : mGrandParents) + cands + l + r' <- getBestExpr r + pure . Fix $ Bin op l' r' + +-- | Generate detailed expression report with multiple metrics and formats +generateExpressionReport :: + RegressionConfig -> + String -> -- varnames + [(DataSet, DataSet)] -> -- dataTrainVals + [DataSet] -> -- dataTests + Bool -> -- shouldReparam + (Fix SRTree -> StateT EGraph (StateT StdGen IO) (Double, [Array S Ix1 Double])) -> -- fitFun + Int -> -- expression index + EClassId -> + RndEGraph [String] +generateExpressionReport cfg varnames dataTrainVals dataTests shouldReparam fitFun ix ec' = do + ec <- canonical ec' + thetas' <- gets (fmap (_theta . _info) . (IM.!? ec) . _eClass) + bestExpr <- + (if simplifyExpressions cfg then simplifyEqSatDefault else id) + <$> getBestExpr ec + + let best' = + if shouldReparam then relabelParams bestExpr else relabelParamsOrder bestExpr + nParams' = countParamsUniq best' + fromSz (MA.Sz x) = x + nThetas = fmap (Prelude.map (fromSz . MA.size)) thetas' + (_, thetas) <- + if maybe False (Prelude.any (/= nParams')) nThetas + then fitFun best' + else pure (1.0, fromMaybe [] thetas') + + maxLoss <- negate . fromJust <$> getFitness ec + forM (Data.List.zip4 [(0 :: Int) ..] 
dataTrainVals dataTests thetas) $ \(view, (dataTrain, dataVal), dataTest, theta') -> do + let (x, y, mYErr) = dataTrain + (x_val, y_val, mYErr_val) = dataVal + (x_te, y_te, mYErr_te) = dataTest + distribution = lossFunction cfg + + expr = paramsToConst (MA.toList theta') best' + showNA z = if isNaN z then "" else show z + r2_train = r2 x y best' theta' + r2_val = r2 x_val y_val best' theta' + r2_te = r2 x_te y_te best' theta' + nll_train = nll distribution mYErr x y best' theta' + nll_val = nll distribution mYErr_val x_val y_val best' theta' + nll_te = nll distribution mYErr_te x_te y_te best' theta' + mdl_train = fractionalBayesFactor distribution mYErr x y theta' best' + mdl_val = fractionalBayesFactor distribution mYErr_val x_val y_val theta' best' + mdl_te = fractionalBayesFactor distribution mYErr_te x_te y_te theta' best' + vals = + intercalate "," $ + Prelude.map + showNA + [ nll_train + , nll_val + , nll_te + , maxLoss + , r2_train + , r2_val + , r2_te + , mdl_train + , mdl_val + , mdl_te + ] + thetaStr = intercalate ";" $ Prelude.map show (MA.toList theta') + showExprFun = if null varnames then showExpr else showExprWithVars (splitOn "," varnames) + showLatexFun = if null varnames then showLatex else showLatexWithVars (splitOn "," varnames) + pure $ + show ix + <> "," + <> show view + <> "," + <> showExprFun expr + <> "," + <> "\"" + <> showPython best' + <> "\"," + <> "\"$$" + <> showLatexFun best' + <> "$$\"," + <> thetaStr + <> "," + <> show @Int (countNodes $ convertProtectedOps expr) + <> "," + <> vals + +-- | Run the main evolutionary loop with timing and memory management +runEvolutionEngine :: + RegressionConfig -> + Int -> -- maxMem + ([EClassId] -> RndEGraph EClassId) -> -- evolve function + (Int -> EClassId -> RndEGraph [String]) -> -- printExpr function + RndEGraph () -> -- cleanEGraph function + POSIXTime -> + Maybe NominalDiffTime -> + ([EClassId], [[String]], Int) -> + RndEGraph ([EClassId], [[String]], Int) +runEvolutionEngine cfg maxMem 
evolve printExpr' cleanEGraph t0 mTime initial = + iterateFor (generations cfg) t0 mTime initial generationStep + where + generationStep _ (ps', out, curIx) = do + newPop' <- replicateM (populationSize cfg) (evolve ps') + + out' <- + if showTrace cfg + then forM (Prelude.zip [curIx ..] newPop') $ uncurry printExpr' + else pure [] + + totSz <- gets (Map.size . _eNodeToEClass) + let full = totSz > max maxMem (populationSize cfg) + when full (cleanEGraph >> cleanDB) + + newPop <- + if generational cfg + then Prelude.mapM canonical newPop' + else selectNextPopulation full newPop' + + pure (newPop, out <> out', curIx + populationSize cfg) + + selectNextPopulation full newPop' = do + pareto <- concat <$> forM [1 .. maxExpressionSize cfg] (`getTopFitEClassWithSize` 2) + let remainder = populationSize cfg - length pareto + lft <- + if full + then getTopFitEClassThat remainder (const True) + else pure $ Prelude.take remainder newPop' + Prelude.mapM canonical (pareto <> lft) + + iterateFor 0 _ _ xs _ = pure xs + iterateFor n t0' maxT xs f = do + xs' <- f n xs + t1 <- io getPOSIXTime + let delta = t1 - t0' + maxT' = subtract delta <$> maxT + case maxT' of + Nothing -> iterateFor (n - 1) t1 maxT' xs' f + Just mt -> + if mt <= 0 + then pure xs + else iterateFor (n - 1) t1 maxT' xs' f + egraphGP :: RegressionConfig -> String -> -- nonterminals @@ -563,49 +792,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do rndNonTerm = randomFrom nonTerms -- Main evolution loop - runGenerations t0 mTime initial = - iterateFor (generations cfg) t0 mTime initial generationStep - where - generationStep _ (ps', out, curIx) = do - newPop' <- replicateM (populationSize cfg) (evolve ps') - - out' <- - if showTrace cfg - then forM (Prelude.zip [curIx ..] newPop') $ uncurry printExpr' - else pure [] - - totSz <- gets (Map.size . 
_eNodeToEClass) - let full = totSz > max maxMem (populationSize cfg) - when full (cleanEGraph >> cleanDB) - - newPop <- - if generational cfg - then Prelude.mapM canonical newPop' - else selectNextPopulation full newPop' - - pure (newPop, out <> out', curIx + populationSize cfg) - - selectNextPopulation full newPop' = do - pareto <- concat <$> forM [1 .. maxExpressionSize cfg] (`getTopFitEClassWithSize` 2) - let remainder = populationSize cfg - length pareto - lft <- - if full - then getTopFitEClassThat remainder (const True) - else pure $ Prelude.take remainder newPop' - Prelude.mapM canonical (pareto <> lft) - - iterateFor 0 _ _ xs _ = pure xs - iterateFor n t0' maxT xs f = do - xs' <- f n xs - t1 <- io getPOSIXTime - let delta = t1 - t0' - maxT' = subtract delta <$> maxT - case maxT' of - Nothing -> iterateFor (n - 1) t1 maxT' xs' f - Just mt -> - if mt <= 0 - then pure xs - else iterateFor (n - 1) t1 maxT' xs' f + runGenerations = runEvolutionEngine cfg maxMem evolve printExpr' cleanEGraph -- E-graph management cleanEGraph = do @@ -659,88 +846,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do combine (p1, p2) = crossover p1 p2 >>= mutate >>= canonical -- Crossover operator - crossover p1 p2 = do - sz <- getSize p1 - coin <- rnd $ tossBiased (crossoverProbability cfg) - if sz == 1 || not coin - then rnd (randomFrom [p1, p2]) - else do - pos <- rnd $ randomRange (1, sz - 1) - cands <- getAllSubClasses p2 - tree <- getSubtree pos 0 Nothing [] cands p1 - fromTree myCost (relabel tree) >>= canonical - - -- Crossover and mutation helpers - getSubtree :: - Int -> - Int -> - Maybe (EClassId -> ENode) -> - [Maybe (EClassId -> ENode)] -> - [EClassId] -> - EClassId -> - RndEGraph (Fix SRTree) - getSubtree 0 sz (Just parent) mGrandParents cands p' = do - p <- canonical p' - candidates' <- - filterM (fmap (< maxExpressionSize cfg - sz) . getSize) cands - candidates <- - filterM (doesNotExistGens mGrandParents . 
parent) candidates' - >>= traverse canonical - if null candidates - then getBestExpr p - else do - subtree <- rnd (randomFrom candidates) - getBestExpr subtree - getSubtree pos sz parent mGrandParents cands p' = do - p <- canonical p' - root <- getBestENode p >>= canonize - case root of - Param ix -> pure . Fix $ Param ix - Const x -> pure . Fix $ Const x - Var ix -> pure . Fix $ Var ix - Uni f t' -> do - t <- canonical t' - Fix . Uni f - <$> getSubtree (pos - 1) (sz + 1) (Just $ Uni f) (parent : mGrandParents) cands t - Bin op l'' r'' -> do - l <- canonical l'' - r <- canonical r'' - szLft <- getSize l - szRgt <- getSize r - if szLft < pos - then do - l' <- getBestExpr l - r' <- - getSubtree - (pos - szLft - 1) - (sz + szLft + 1) - (Just $ Bin op l) - (parent : mGrandParents) - cands - r - pure . Fix $ Bin op l' r' - else do - l' <- - getSubtree - (pos - 1) - (sz + szRgt + 1) - (Just (\t -> Bin op t r)) - (parent : mGrandParents) - cands - l - r' <- getBestExpr r - pure . Fix $ Bin op l' r' - - getAllSubClasses p' = do - p <- canonical p' - en <- getBestENode p - case en of - Bin _ l r -> do - ls <- getAllSubClasses l - rs <- getAllSubClasses r - pure (p : ls <> rs) - Uni _ t -> (p :) <$> getAllSubClasses t - _ -> pure [p] + crossover = crossoverExpressions cfg (maxExpressionSize cfg) relabel -- Mutation operator mutate p = do @@ -814,78 +920,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do pure . Fix $ Bin op l' r' -- Output and reporting - printExpr' :: Int -> EClassId -> RndEGraph [String] - printExpr' ix ec' = do - ec <- canonical ec' - thetas' <- gets (fmap (_theta . _info) . (IM.!? ec) . _eClass) - bestExpr <- - (if simplifyExpressions cfg then simplifyEqSatDefault else id) - <$> getBestExpr ec - - let best' = - if shouldReparam then relabelParams bestExpr else relabelParamsOrder bestExpr - nParams' = countParamsUniq best' - fromSz (MA.Sz x) = x - nThetas = fmap (Prelude.map (fromSz . 
MA.size)) thetas' - (_, thetas) <- - if maybe False (Prelude.any (/= nParams')) nThetas - then fitFun best' - else pure (1.0, fromMaybe [] thetas') - - maxLoss <- negate . fromJust <$> getFitness ec - forM (Data.List.zip4 [(0 :: Int) ..] dataTrainVals dataTests thetas) $ \(view, (dataTrain, dataVal), dataTest, theta') -> do - let (x, y, mYErr) = dataTrain - (x_val, y_val, mYErr_val) = dataVal - (x_te, y_te, mYErr_te) = dataTest - distribution = lossFunction cfg - - expr = paramsToConst (MA.toList theta') best' - showNA z = if isNaN z then "" else show z - r2_train = r2 x y best' theta' - r2_val = r2 x_val y_val best' theta' - r2_te = r2 x_te y_te best' theta' - nll_train = nll distribution mYErr x y best' theta' - nll_val = nll distribution mYErr_val x_val y_val best' theta' - nll_te = nll distribution mYErr_te x_te y_te best' theta' - mdl_train = fractionalBayesFactor distribution mYErr x y theta' best' - mdl_val = fractionalBayesFactor distribution mYErr_val x_val y_val theta' best' - mdl_te = fractionalBayesFactor distribution mYErr_te x_te y_te theta' best' - vals = - intercalate "," $ - Prelude.map - showNA - [ nll_train - , nll_val - , nll_te - , maxLoss - , r2_train - , r2_val - , r2_te - , mdl_train - , mdl_val - , mdl_te - ] - thetaStr = intercalate ";" $ Prelude.map show (MA.toList theta') - showExprFun = if null varnames then showExpr else showExprWithVars (splitOn "," varnames) - showLatexFun = if null varnames then showLatex else showLatexWithVars (splitOn "," varnames) - pure $ - show ix - <> "," - <> show view - <> "," - <> showExprFun expr - <> "," - <> "\"" - <> showPython best' - <> "\"," - <> "\"$$" - <> showLatexFun best' - <> "$$\"," - <> thetaStr - <> "," - <> show @Int (countNodes $ convertProtectedOps expr) - <> "," - <> vals + printExpr' = generateExpressionReport cfg varnames dataTrainVals dataTests shouldReparam fitFun -- Initialization helpers insertTerms = forM terms (fromTree myCost >=> canonical)