diff --git a/.gitignore b/.gitignore index c80629d..4833498 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ bin/ coverage-html .DS_Store flake.lock +data/ diff --git a/README.md b/README.md index e2c04bb..504b2f8 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,15 @@ symbolic-regression integrates symbolic regression capabilities into a DataFrame ```haskell ghci> import qualified DataFrame as D -ghci> import DataFrame.Functions ((.=)) +ghci> import qualified DataFrame.Functions as F ghci> import Symbolic.Regression -- Load your data ghci> df <- D.readParquet "./data/mtcars.parquet" +-- Define mpg as a column reference +ghci> let mpg = F.col "mpg" + -- Run symbolic regression to predict 'mpg' -- NOTE: ALL COLUMNS MUST BE CONVERTED TO DOUBLE FIRST -- e.g df' = D.derive "some_column" (F.toDouble (F.col @Int "some_column")) df @@ -24,19 +27,20 @@ ghci> df <- D.readParquet "./data/mtcars.parquet" ghci> exprs <- fit defaultRegressionConfig mpg df -- View discovered expressions (Pareto front from simplest to most complex) -ghci> map D.prettyPrint exprs --- [ qsec, --- , 57.33 / wt --- , 10.75 + (1557.67 / disp)] +ghci> mapM_ (\(i, e) -> putStrLn $ "Model " ++ show i ++ ": " ++ D.prettyPrint e) (zip [1..] exprs) + +-- Create named expressions for different complexity levels +ghci> import qualified Data.Text as T +ghci> let levels = zipWith (F..=) (map (T.pack . ("level_" ++) . show) [1..]) exprs --- Create named expressions that we'll use in a dataframe. -ghci> levels = zipWith (.=) ["level_1", "level_2", "level_3"] exprs +-- Show the various predictions in our dataframe +ghci> let df' = D.deriveMany levels df --- Show the various predictions in our dataframe. 
-ghci> D.deriveMany levels df +-- Or pick the best one for prediction +ghci> let df'' = D.derive "prediction" (last exprs) df' --- Or pick the best one -ghci> D.derive "prediction" (last exprs) df +-- Display the results +ghci> D.display (D.DisplayOptions 5) df'' ``` ## Configuration @@ -102,3 +106,14 @@ To install symbolic-regression you'll need: * libz: `sudo apt install libz-dev` * libnlopt: `sudo apt install libnlopt-dev` * libgmp: `sudo apt install libgmp-dev` + +### Nix Development Environment +For Nix users with flakes enabled: + +```bash +git clone REPOSITORY_URL symbolic-regression +cd symbolic-regression +nix develop -c cabal repl +``` + +Then follow the Quick Start example above. diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..705e083 --- /dev/null +++ b/flake.nix @@ -0,0 +1,84 @@ +{ + description = "DataFrame symbolic regression"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-parts.url = "github:hercules-ci/flake-parts"; + }; + + outputs = + inputs@{ flake-parts, ... }: + flake-parts.lib.mkFlake { inherit inputs; } { + systems = [ + "x86_64-linux" + "aarch64-linux" + "aarch64-darwin" + "x86_64-darwin" + ]; + + perSystem = + { + config, + pkgs, + system, + ... 
+ }: + let + pkgs' = import inputs.nixpkgs { + inherit system; + config.allowBroken = true; + }; + + haskellPackages = pkgs'.haskellPackages.extend ( + final: prev: { + # Use the correct hash that Nix gave us + random = pkgs'.haskell.lib.dontCheck ( + prev.callHackageDirect { + pkg = "random"; + ver = "1.3.0"; + sha256 = "sha256-AzXPz8oe+9lAS60nRNFiRTEDrFDB0FPLtbvJ9PB7brM="; + } { } + ); + + srtree = pkgs'.haskell.lib.dontCheck ( + prev.callHackageDirect { + pkg = "srtree"; + ver = "2.0.1.6"; + sha256 = "sha256-D561wnKoFRr/HSviacsbtF5VnJHehG73LOmWM4TxlNs="; + } { } + ); + + dataframe = pkgs'.haskell.lib.dontCheck ( + prev.callHackageDirect { + pkg = "dataframe"; + ver = "0.4.1.0"; + sha256 = "sha256-u4FrhD+oOnQiGazq/TWuL0PzUvLKKAkz679tZSkiMaY="; + } { } + ); + + # Jailbreak packages that depend on old random + time-compat = pkgs'.haskell.lib.doJailbreak prev.time-compat; + splitmix = pkgs'.haskell.lib.doJailbreak prev.splitmix; + } + ); + + in + { + packages = { + default = config.packages.symbolic-regression; + symbolic-regression = haskellPackages.callCabal2nix "symbolic-regression" ./. { }; + }; + + devShells.default = haskellPackages.shellFor { + packages = p: [ config.packages.symbolic-regression ]; + buildInputs = with pkgs'; [ + cabal-install + haskell-language-server + git + fourmolu + nixfmt + ]; + }; + }; + }; +} diff --git a/src/Symbolic/Regression.hs b/src/Symbolic/Regression.hs index 1bf5c8e..0263775 100644 --- a/src/Symbolic/Regression.hs +++ b/src/Symbolic/Regression.hs @@ -19,12 +19,14 @@ off complexity and accuracy. 
@ import qualified DataFrame as D -import DataFrame.Functions ((.=)) +import qualified DataFrame.Functions as F import Symbolic.Regression -- Load your data df <- D.readParquet "./data/mtcars.parquet" +let mpg = F.col "mpg" + -- Run symbolic regression to predict 'mpg' exprs <- fit defaultRegressionConfig mpg df @@ -54,6 +56,21 @@ module Symbolic.Regression ( RegressionConfig (..), ValidationConfig (..), defaultRegressionConfig, + + -- * Function constructors and utilities + UnaryFunc, + BinaryFunc, + uLog, + uSquare, + uCube, + uRecip, + bAdd, + bSub, + bMul, + bDiv, + bPow, + getUnaryName, + getBinaryName, ) where import Control.Exception (throw) @@ -105,14 +122,123 @@ import Data.SRTree.Datasets import qualified Data.SRTree.Internal as SI import Data.SRTree.Print import Data.SRTree.Random -import Data.Type.Equality (TestEquality (testEquality), type (:~:) (Refl)) -import Type.Reflection (typeRep) import Algorithm.EqSat (runEqSat) import Algorithm.EqSat.SearchSR +import Data.Time.Clock (NominalDiffTime) import Data.Time.Clock.POSIX import Text.ParseSR +{- | Lift a random generator action into the e-graph monad. + +Wraps an action in the `Rng IO` monad (`StateT StdGen IO`) so it can be executed +within the `StateT EGraph (StateT StdGen IO)` monad. The current random generator +state is read from the inner `StateT StdGen`, updated after the action runs, and +the result is returned in the e-graph monad. +-} +liftRng :: Rng IO a -> StateT EGraph (StateT StdGen IO) a +liftRng rngAction = do + gen <- lift get -- get StdGen from inner StateT + (result, gen') <- liftIO $ runStateT rngAction gen + lift $ put gen' -- update StdGen + return result + +{- | Tagged unary operation for symbolic regression. + +Pairs an operation name with its implementation, ensuring the operation's +identity is preserved throughout the search process. The name is used to +identify the operation in the internal representation. 
+ +@ +myUnaryFunctions = [ uSquare (\`F.pow\` 2) + , uLog log + , uRecip (1 /) + ] +@ +-} +data UnaryFunc = UnaryFunc UnaryOpName (D.Expr Double -> D.Expr Double) + +-- | Convenience constructors for UnaryFunc +uLog :: (Expr Double -> Expr Double) -> UnaryFunc +uLog = UnaryFunc ULog + +uSquare :: (Expr Double -> Expr Double) -> UnaryFunc +uSquare = UnaryFunc USquare + +uCube :: (Expr Double -> Expr Double) -> UnaryFunc +uCube = UnaryFunc UCube + +uRecip :: (Expr Double -> Expr Double) -> UnaryFunc +uRecip = UnaryFunc URecip + +data UnaryOpName + = ULog + | USquare + | UCube + | URecip + deriving (Eq) + +instance Show UnaryOpName where + show ULog = "log" + show USquare = "square" + show UCube = "cube" + show URecip = "recip" + +-- | Extract the name of a unary operation. +getUnaryName :: UnaryFunc -> String +getUnaryName (UnaryFunc name _) = show name + +{- | Tagged binary operation for symbolic regression. + +Pairs an operation name with its implementation, ensuring the operation's +identity is preserved throughout the search process. The name is used to +identify the operation in the internal representation. + +@ +myBinaryFunctions = [ bAdd (+) + , bSub (-) + , bMul (*) + , bDiv (/) + ] +@ +-} +data BinaryOpName + = BAdd + | BSub + | BMul + | BDiv + | BPow + deriving (Eq) + +instance Show BinaryOpName where + show BAdd = "add" + show BSub = "sub" + show BMul = "mul" + show BDiv = "div" + show BPow = "pow" + +data BinaryFunc = BinaryFunc BinaryOpName (D.Expr Double -> D.Expr Double -> D.Expr Double) + +-- | Extract the name of a binary operation. 
+getBinaryName :: BinaryFunc -> String +getBinaryName (BinaryFunc name _) = show name + +-- | Convenience constructors for BinaryFunc +bAdd :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bAdd = BinaryFunc BAdd + +bSub :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bSub = BinaryFunc BSub + +bMul :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bMul = BinaryFunc BMul + +bDiv :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bDiv = BinaryFunc BDiv + +bPow :: (Expr Double -> Expr Double -> Expr Double) -> BinaryFunc +bPow = BinaryFunc BPow + {- | Configuration for the symbolic regression algorithm. Use 'defaultRegressionConfig' as a starting point and modify fields as needed: @@ -149,9 +275,9 @@ data RegressionConfig = RegressionConfig -- ^ Probability of crossover between expressions (default: 0.95) , mutationProbability :: Double -- ^ Probability of mutation (default: 0.3) - , unaryFunctions :: [D.Expr Double -> D.Expr Double] + , unaryFunctions :: [UnaryFunc] -- ^ Unary operations to include in the search space (default: @[]@) - , binaryFunctions :: [D.Expr Double -> D.Expr Double -> D.Expr Double] + , binaryFunctions :: [BinaryFunc] {- ^ Binary operations to include in the search space (default: @[(+), (-), (*), (\/)]@) -} @@ -167,6 +293,8 @@ data RegressionConfig = RegressionConfig -- ^ File path to save e-graph state for later resumption (default: @\"\"@) , loadFrom :: String -- ^ File path to load e-graph state from a previous run (default: @\"\"@) + , seed :: Maybe Int + -- ^ Random seed for reproducibility. 
'Nothing' uses system random (default: 'Nothing') } data ValidationConfig = ValidationConfig @@ -200,14 +328,25 @@ defaultRegressionConfig = , tournamentSize = 3 , crossoverProbability = 0.95 , mutationProbability = 0.3 - , unaryFunctions = [(`F.pow` 2), (`F.pow` 3), log, (1 /)] - , binaryFunctions = [(+), (-), (*), (/)] + , unaryFunctions = + [ uSquare (`F.pow` 2) + , uCube (`F.pow` 3) + , uLog log + , uRecip (1 /) + ] + , binaryFunctions = + [ bAdd (+) + , bSub (-) + , bMul (*) + , bDiv (/) + ] , numParams = -1 , generational = False , simplifyExpressions = True , maxTime = -1 , dumpTo = "" , loadFrom = "" + , seed = Nothing } {- | Run symbolic regression to discover mathematical expressions that fit the data. @@ -247,7 +386,9 @@ fit :: -- | Pareto front of expressions, ordered simplest to most complex IO [D.Expr Double] fit cfg targetColumn df = do - g <- getStdGen + g <- case seed cfg of + Just s -> pure $ mkStdGen s + Nothing -> getStdGen let (train, validation) = case validationConfig cfg of Nothing -> (df, df) @@ -270,13 +411,9 @@ fit cfg targetColumn df = do Array S Ix2 Double toTarget d = fromLists' Seq (D.columnAsList targetColumn d) :: Array S Ix1 Double nonterminals = - intercalate - "," - ( Prelude.map - (toNonTerminal . (\f -> f (F.col "fake1") (F.col "fake2"))) - (binaryFunctions cfg) - ++ Prelude.map (toNonTerminal . 
(\f -> f (F.col "fake1"))) (unaryFunctions cfg) - ) + intercalate "," $ + Prelude.map getBinaryName (binaryFunctions cfg) + ++ Prelude.map getUnaryName (unaryFunctions cfg) varnames = intercalate "," @@ -317,45 +454,34 @@ toExpr cols (Fix (Bin op left right)) = case op of treeOp -> error ("UNIMPLEMENTED OPERATION: " ++ show treeOp) toExpr _ _ = error "UNIMPLEMENTED" -toNonTerminal :: D.Expr Double -> String -toNonTerminal (BinaryOp "add" _ _ _) = "add" -toNonTerminal (BinaryOp "sub" _ _ _) = "sub" -toNonTerminal (BinaryOp "mult" _ _ _) = "mul" -toNonTerminal (BinaryOp "divide" _ (Lit n :: Expr b) _) = case testEquality (typeRep @b) (typeRep @Double) of - Nothing -> error "[Internal Error] - Reciprocal of non-double" - Just Refl -> case n of - 1 -> "recip" - _ -> error "Unknown reciprocal" -toNonTerminal (BinaryOp "divide" _ _ _) = "div" -toNonTerminal (BinaryOp "pow" _ _ (e :: Expr b)) = case testEquality (typeRep @b) (typeRep @Int) of - Nothing -> error "Impossible: Raised to non-int power" - Just Refl -> case e of - (Lit 2) -> "square" - (Lit 3) -> "cube" - _ -> error "Unknown power" -toNonTerminal (UnaryOp "log" _ _) = "log" -toNonTerminal e = error ("Unsupported operation: " ++ show e) - -egraphGP :: +-- | Initialize the e-graph by loading from file (if specified) and inserting initial terms +initializeEGraph :: RegressionConfig -> - String -> -- nonterminals - String -> -- varnames - [(DataSet, DataSet)] -> - [DataSet] -> - StateT EGraph (StateT StdGen IO) [Fix SRTree] -egraphGP cfg nonterminals varnames dataTrainVals dataTests = do + (Fix SRTree -> RndEGraph (Double, [Array S Ix1 Double])) -> + RndEGraph EClassId -> + RndEGraph (SRTree ()) -> + RndEGraph [EClassId] -> + RndEGraph () +initializeEGraph cfg fitFun _rndTerm _rndNonTerm insertTerms = do unless (null (loadFrom cfg)) $ io (BS.readFile (loadFrom cfg)) >>= \eg -> put (decode eg) - _ <- insertTerms evaluateUnevaluated fitFun - t0 <- io getPOSIXTime - +-- | Create the initial population of expressions 
+createInitialPopulation :: + RegressionConfig -> + (Fix SRTree -> StateT EGraph (StateT StdGen IO) (Double, [Array S Ix1 Double])) -> + Rng IO (Fix SRTree) -> -- rndTerm + Rng IO (SRTree ()) -> -- rndNonTerm + (Int -> EClassId -> StateT EGraph (StateT StdGen IO) [String]) -> + StateT EGraph (StateT StdGen IO) ([EClassId], [[String]]) +createInitialPopulation cfg fitFun rndTerm rndNonTerm printExpr' = do pop <- replicateM (populationSize cfg) $ do ec <- insertRndExpr (maxExpressionSize cfg) rndTerm rndNonTerm >>= canonical _ <- updateIfNothing fitFun ec pure ec + pop' <- Prelude.mapM canonical pop output <- @@ -363,9 +489,202 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do then forM (Prelude.zip [0 ..] pop') $ uncurry printExpr' else pure [] - let mTime = - if maxTime cfg < 0 then Nothing else Just (fromIntegral $ maxTime cfg - 5) - (_, _, _) <- iterateFor (generations cfg) t0 mTime (pop', output, populationSize cfg) $ \_ (ps', out, curIx) -> do + pure (pop', output) + +-- | Finalize by saving e-graph to file if specified +finalizeEGraph :: RegressionConfig -> RndEGraph () +finalizeEGraph cfg = + unless (null (dumpTo cfg)) $ + get >>= (io . BS.writeFile (dumpTo cfg) . 
encode) + +-- | Crossover operation between two expressions +crossoverExpressions :: + RegressionConfig -> + Int -> -- maxExpressionSize + (Fix SRTree -> Fix SRTree) -> -- relabel function + EClassId -> + EClassId -> + RndEGraph EClassId +crossoverExpressions cfg maxSize relabel p1 p2 = do + sz <- getSize p1 + coin <- rnd $ tossBiased (crossoverProbability cfg) + if sz == 1 || not coin + then rnd (randomFrom [p1, p2]) + else do + pos <- rnd $ randomRange (1, sz - 1) + cands <- getAllSubClasses p2 + tree <- getSubtree pos 0 Nothing [] cands p1 + fromTree myCost (relabel tree) >>= canonical + where + getAllSubClasses p' = do + p <- canonical p' + en <- getBestENode p + case en of + Bin _ l r -> do + ls <- getAllSubClasses l + rs <- getAllSubClasses r + pure (p : ls <> rs) + Uni _ t -> (p :) <$> getAllSubClasses t + _ -> pure [p] + + getSubtree :: + Int -> + Int -> + Maybe (EClassId -> ENode) -> + [Maybe (EClassId -> ENode)] -> + [EClassId] -> + EClassId -> + RndEGraph (Fix SRTree) + getSubtree 0 sz (Just parent) mGrandParents cands p' = do + p <- canonical p' + candidates' <- + filterM (fmap (< maxSize - sz) . getSize) cands + candidates <- + filterM (doesNotExistGens mGrandParents . parent) candidates' + >>= traverse canonical + if null candidates + then getBestExpr p + else do + subtree <- rnd (randomFrom candidates) + getBestExpr subtree + getSubtree pos sz parent mGrandParents cands p' = do + p <- canonical p' + root <- getBestENode p >>= canonize + case root of + Param ix -> pure . Fix $ Param ix + Const x -> pure . Fix $ Const x + Var ix -> pure . Fix $ Var ix + Uni f t' -> do + t <- canonical t' + Fix . 
Uni f + <$> getSubtree (pos - 1) (sz + 1) (Just $ Uni f) (parent : mGrandParents) cands t + Bin op l'' r'' -> do + l <- canonical l'' + r <- canonical r'' + szLft <- getSize l + szRgt <- getSize r + if szLft < pos + then do + l' <- getBestExpr l + r' <- + getSubtree + (pos - szLft - 1) + (sz + szLft + 1) + (Just $ Bin op l) + (parent : mGrandParents) + cands + r + pure . Fix $ Bin op l' r' + else do + l' <- + getSubtree + (pos - 1) + (sz + szRgt + 1) + (Just (\t -> Bin op t r)) + (parent : mGrandParents) + cands + l + r' <- getBestExpr r + pure . Fix $ Bin op l' r' + +-- | Generate detailed expression report with multiple metrics and formats +generateExpressionReport :: + RegressionConfig -> + String -> -- varnames + [(DataSet, DataSet)] -> -- dataTrainVals + [DataSet] -> -- dataTests + Bool -> -- shouldReparam + (Fix SRTree -> StateT EGraph (StateT StdGen IO) (Double, [Array S Ix1 Double])) -> -- fitFun + Int -> -- expression index + EClassId -> + RndEGraph [String] +generateExpressionReport cfg varnames dataTrainVals dataTests shouldReparam fitFun ix ec' = do + ec <- canonical ec' + thetas' <- gets (fmap (_theta . _info) . (IM.!? ec) . _eClass) + bestExpr <- + (if simplifyExpressions cfg then simplifyEqSatDefault else id) + <$> getBestExpr ec + + let best' = + if shouldReparam then relabelParams bestExpr else relabelParamsOrder bestExpr + nParams' = countParamsUniq best' + fromSz (MA.Sz x) = x + nThetas = fmap (Prelude.map (fromSz . MA.size)) thetas' + (_, thetas) <- + if maybe False (Prelude.any (/= nParams')) nThetas + then fitFun best' + else pure (1.0, fromMaybe [] thetas') + + maxLoss <- negate . fromJust <$> getFitness ec + forM (Data.List.zip4 [(0 :: Int) ..] 
dataTrainVals dataTests thetas) $ \(view, (dataTrain, dataVal), dataTest, theta') -> do + let (x, y, mYErr) = dataTrain + (x_val, y_val, mYErr_val) = dataVal + (x_te, y_te, mYErr_te) = dataTest + distribution = lossFunction cfg + + expr = paramsToConst (MA.toList theta') best' + showNA z = if isNaN z then "" else show z + r2_train = r2 x y best' theta' + r2_val = r2 x_val y_val best' theta' + r2_te = r2 x_te y_te best' theta' + nll_train = nll distribution mYErr x y best' theta' + nll_val = nll distribution mYErr_val x_val y_val best' theta' + nll_te = nll distribution mYErr_te x_te y_te best' theta' + mdl_train = fractionalBayesFactor distribution mYErr x y theta' best' + mdl_val = fractionalBayesFactor distribution mYErr_val x_val y_val theta' best' + mdl_te = fractionalBayesFactor distribution mYErr_te x_te y_te theta' best' + vals = + intercalate "," $ + Prelude.map + showNA + [ nll_train + , nll_val + , nll_te + , maxLoss + , r2_train + , r2_val + , r2_te + , mdl_train + , mdl_val + , mdl_te + ] + thetaStr = intercalate ";" $ Prelude.map show (MA.toList theta') + showExprFun = if null varnames then showExpr else showExprWithVars (splitOn "," varnames) + showLatexFun = if null varnames then showLatex else showLatexWithVars (splitOn "," varnames) + pure $ + show ix + <> "," + <> show view + <> "," + <> showExprFun expr + <> "," + <> "\"" + <> showPython best' + <> "\"," + <> "\"$$" + <> showLatexFun best' + <> "$$\"," + <> thetaStr + <> "," + <> show @Int (countNodes $ convertProtectedOps expr) + <> "," + <> vals + +-- | Run the main evolutionary loop with timing and memory management +runEvolutionEngine :: + RegressionConfig -> + Int -> -- maxMem + ([EClassId] -> RndEGraph EClassId) -> -- evolve function + (Int -> EClassId -> RndEGraph [String]) -> -- printExpr function + RndEGraph () -> -- cleanEGraph function + POSIXTime -> + Maybe NominalDiffTime -> + ([EClassId], [[String]], Int) -> + RndEGraph ([EClassId], [[String]], Int) +runEvolutionEngine cfg maxMem 
evolve printExpr' cleanEGraph t0 mTime initial = + iterateFor (generations cfg) t0 mTime initial generationStep + where + generationStep _ (ps', out, curIx) = do newPop' <- replicateM (populationSize cfg) (evolve ps') out' <- @@ -380,21 +699,60 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do newPop <- if generational cfg then Prelude.mapM canonical newPop' - else do - pareto <- - concat <$> forM [1 .. maxExpressionSize cfg] (`getTopFitEClassWithSize` 2) - let remainder = populationSize cfg - length pareto - lft <- - if full - then getTopFitEClassThat remainder (const True) - else pure $ Prelude.take remainder newPop' - Prelude.mapM canonical (pareto <> lft) + else selectNextPopulation full newPop' + pure (newPop, out <> out', curIx + populationSize cfg) - unless (null (dumpTo cfg)) $ - get >>= (io . BS.writeFile (dumpTo cfg) . encode) + selectNextPopulation full newPop' = do + pareto <- concat <$> forM [1 .. maxExpressionSize cfg] (`getTopFitEClassWithSize` 2) + let remainder = populationSize cfg - length pareto + lft <- + if full + then getTopFitEClassThat remainder (const True) + else pure $ Prelude.take remainder newPop' + Prelude.mapM canonical (pareto <> lft) + + iterateFor 0 _ _ xs _ = pure xs + iterateFor n t0' maxT xs f = do + xs' <- f n xs + t1 <- io getPOSIXTime + let delta = t1 - t0' + maxT' = subtract delta <$> maxT + case maxT' of + Nothing -> iterateFor (n - 1) t1 maxT' xs' f + Just mt -> + if mt <= 0 + then pure xs + else iterateFor (n - 1) t1 maxT' xs' f + +egraphGP :: + RegressionConfig -> + String -> -- nonterminals + String -> -- varnames + [(DataSet, DataSet)] -> + [DataSet] -> + StateT EGraph (StateT StdGen IO) [Fix SRTree] +egraphGP cfg nonterminals varnames dataTrainVals dataTests = do + -- generate a random EClassId for the terminal expressions + -- rndTerm' :: RndEGraph EClassId + let rndTerm' = insertRndExpr (maxExpressionSize cfg) rndTerm rndNonTerm + + -- generate a random SRTree for non-terminal expressions + -- 
rndNonTerm' :: RndEGraph (SRTree ()) + let rndNonTerm' = liftRng rndNonTerm + + initializeEGraph cfg fitFun rndTerm' rndNonTerm' insertTerms + + t0 <- io getPOSIXTime + (pop', output) <- createInitialPopulation cfg fitFun rndTerm rndNonTerm printExpr' + + let mTime = if maxTime cfg < 0 then Nothing else Just (fromIntegral $ maxTime cfg - 5) + _ <- runGenerations t0 mTime (pop', output, populationSize cfg) + + finalizeEGraph cfg paretoFront' fitFun (maxExpressionSize cfg) where + -- Configuration and setup maxMem = 2000000 fitFun = fitnessMV @@ -426,6 +784,17 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do isBin (Bin{}) = True isBin _ = False + -- Random generators + rndTerm = do + coin <- toss + if coin || numParams cfg == 0 then randomFrom terms else randomFrom params + + rndNonTerm = randomFrom nonTerms + + -- Main evolution loop + runGenerations = runEvolutionEngine cfg maxMem evolve printExpr' cleanEGraph + + -- E-graph management cleanEGraph = do let nParetos = 10 io . putStrLn $ "cleaning" @@ -441,12 +810,6 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do Nothing -> pure () Just i'' -> insertFitness eId (fromJust $ _fitness i'') (_theta i'') - rndTerm = do - coin <- toss - if coin || numParams cfg == 0 then randomFrom terms else randomFrom params - - rndNonTerm = randomFrom nonTerms - refitChanged = do ids <- (gets (_refits . _eDB) >>= Prelude.mapM canonical . 
Set.toList) @@ -457,19 +820,7 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do (f, p) <- fitFun t insertFitness ec f p - iterateFor 0 _ _ xs _ = pure xs - iterateFor n t0' maxT xs f = do - xs' <- f n xs - t1 <- io getPOSIXTime - let delta = t1 - t0' - maxT' = subtract delta <$> maxT - case maxT' of - Nothing -> iterateFor (n - 1) t1 maxT' xs' f - Just mt -> - if mt <= 0 - then pure xs - else iterateFor (n - 1) t1 maxT' xs' f - + -- Evolution operators evolve xs' = do xs <- Prelude.mapM canonical xs' parents' <- tournament xs @@ -494,88 +845,10 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do combine (p1, p2) = crossover p1 p2 >>= mutate >>= canonical - crossover p1 p2 = do - sz <- getSize p1 - coin <- rnd $ tossBiased (crossoverProbability cfg) - if sz == 1 || not coin - then rnd (randomFrom [p1, p2]) - else do - pos <- rnd $ randomRange (1, sz - 1) - cands <- getAllSubClasses p2 - tree <- getSubtree pos 0 Nothing [] cands p1 - fromTree myCost (relabel tree) >>= canonical - - getSubtree :: - Int -> - Int -> - Maybe (EClassId -> ENode) -> - [Maybe (EClassId -> ENode)] -> - [EClassId] -> - EClassId -> - RndEGraph (Fix SRTree) - getSubtree 0 sz (Just parent) mGrandParents cands p' = do - p <- canonical p' - candidates' <- - filterM (fmap (< maxExpressionSize cfg - sz) . getSize) cands - candidates <- - filterM (doesNotExistGens mGrandParents . parent) candidates' - >>= traverse canonical - if null candidates - then getBestExpr p - else do - subtree <- rnd (randomFrom candidates) - getBestExpr subtree - getSubtree pos sz parent mGrandParents cands p' = do - p <- canonical p' - root <- getBestENode p >>= canonize - case root of - Param ix -> pure . Fix $ Param ix - Const x -> pure . Fix $ Const x - Var ix -> pure . Fix $ Var ix - Uni f t' -> do - t <- canonical t' - Fix . 
Uni f - <$> getSubtree (pos - 1) (sz + 1) (Just $ Uni f) (parent : mGrandParents) cands t - Bin op l'' r'' -> do - l <- canonical l'' - r <- canonical r'' - szLft <- getSize l - szRgt <- getSize r - if szLft < pos - then do - l' <- getBestExpr l - r' <- - getSubtree - (pos - szLft - 1) - (sz + szLft + 1) - (Just $ Bin op l) - (parent : mGrandParents) - cands - r - pure . Fix $ Bin op l' r' - else do - l' <- - getSubtree - (pos - 1) - (sz + szRgt + 1) - (Just (\t -> Bin op t r)) - (parent : mGrandParents) - cands - l - r' <- getBestExpr r - pure . Fix $ Bin op l' r' - - getAllSubClasses p' = do - p <- canonical p' - en <- getBestENode p - case en of - Bin _ l r -> do - ls <- getAllSubClasses l - rs <- getAllSubClasses r - pure (p : ls <> rs) - Uni _ t -> (p :) <$> getAllSubClasses t - _ -> pure [p] + -- Crossover operator + crossover = crossoverExpressions cfg (maxExpressionSize cfg) relabel + -- Mutation operator mutate p = do sz <- getSize p coin <- rnd $ tossBiased (mutationProbability cfg) @@ -646,81 +919,13 @@ egraphGP cfg nonterminals varnames dataTrainVals dataTests = do r' <- getBestExpr r pure . Fix $ Bin op l' r' - printExpr' :: Int -> EClassId -> RndEGraph [String] - printExpr' ix ec' = do - ec <- canonical ec' - thetas' <- gets (fmap (_theta . _info) . (IM.!? ec) . _eClass) - bestExpr <- - (if simplifyExpressions cfg then simplifyEqSatDefault else id) - <$> getBestExpr ec - - let best' = - if shouldReparam then relabelParams bestExpr else relabelParamsOrder bestExpr - nParams' = countParamsUniq best' - fromSz (MA.Sz x) = x - nThetas = fmap (Prelude.map (fromSz . MA.size)) thetas' - (_, thetas) <- - if maybe False (Prelude.any (/= nParams')) nThetas - then fitFun best' - else pure (1.0, fromMaybe [] thetas') - - maxLoss <- negate . fromJust <$> getFitness ec - forM (Data.List.zip4 [(0 :: Int) ..] 
dataTrainVals dataTests thetas) $ \(view, (dataTrain, dataVal), dataTest, theta') -> do - let (x, y, mYErr) = dataTrain - (x_val, y_val, mYErr_val) = dataVal - (x_te, y_te, mYErr_te) = dataTest - distribution = lossFunction cfg - - expr = paramsToConst (MA.toList theta') best' - showNA z = if isNaN z then "" else show z - r2_train = r2 x y best' theta' - r2_val = r2 x_val y_val best' theta' - r2_te = r2 x_te y_te best' theta' - nll_train = nll distribution mYErr x y best' theta' - nll_val = nll distribution mYErr_val x_val y_val best' theta' - nll_te = nll distribution mYErr_te x_te y_te best' theta' - mdl_train = fractionalBayesFactor distribution mYErr x y theta' best' - mdl_val = fractionalBayesFactor distribution mYErr_val x_val y_val theta' best' - mdl_te = fractionalBayesFactor distribution mYErr_te x_te y_te theta' best' - vals = - intercalate "," $ - Prelude.map - showNA - [ nll_train - , nll_val - , nll_te - , maxLoss - , r2_train - , r2_val - , r2_te - , mdl_train - , mdl_val - , mdl_te - ] - thetaStr = intercalate ";" $ Prelude.map show (MA.toList theta') - showExprFun = if null varnames then showExpr else showExprWithVars (splitOn "," varnames) - showLatexFun = if null varnames then showLatex else showLatexWithVars (splitOn "," varnames) - pure $ - show ix - <> "," - <> show view - <> "," - <> showExprFun expr - <> "," - <> "\"" - <> showPython best' - <> "\"," - <> "\"$$" - <> showLatexFun best' - <> "$$\"," - <> thetaStr - <> "," - <> show @Int (countNodes $ convertProtectedOps expr) - <> "," - <> vals + -- Output and reporting + printExpr' = generateExpressionReport cfg varnames dataTrainVals dataTests shouldReparam fitFun + -- Initialization helpers insertTerms = forM terms (fromTree myCost >=> canonical) + -- Pareto front extraction paretoFront' _ maxSize' = go 1 (-(1.0 / 0.0)) where go :: Int -> Double -> RndEGraph [Fix SRTree] diff --git a/symbolic-regression.cabal b/symbolic-regression.cabal index ffa4a8e..5716eb3 100644 --- 
a/symbolic-regression.cabal +++ b/symbolic-regression.cabal @@ -21,7 +21,7 @@ library import: warnings exposed-modules: Symbolic.Regression build-depends: base >= 4 && <5 - , dataframe ^>= 0.4 + , dataframe ^>= 0.4.1 , attoparsec >=0.14.4 && <0.15 , attoparsec-expr >=0.1.1.2 && <0.2 , binary >=0.8.9.1 && <0.9 @@ -62,4 +62,9 @@ test-suite symbolic-regression-test main-is: Main.hs build-depends: base >= 4 && <5, - symbolic-regression + symbolic-regression, + HUnit >= 1.6 && < 1.7, + dataframe ^>= 0.4.1, + random >= 1.2 && < 1.4, + text >= 2.0 && < 3, + time >= 1.12 && < 2 diff --git a/test/Main.hs b/test/Main.hs index 3e2059e..3913861 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -1,4 +1,152 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeApplications #-} + module Main (main) where +import qualified DataFrame as D +import qualified DataFrame.Functions as F +import Symbolic.Regression +import Test.HUnit + +-- Default regression config for deterministic testing +defaultTestConfig :: RegressionConfig +defaultTestConfig = + defaultRegressionConfig + { showTrace = False + , seed = Just 0 + } + +-- Compute maximum absolute error between target and predicted expression +maxAbsError :: D.Expr Double -> D.Expr Double -> D.DataFrame -> Double +maxAbsError target predicted df = + let (predColExpr, df') = D.deriveWithExpr "__test_predicted" predicted df + actualVals = D.columnAsList target df' :: [Double] + predVals = D.columnAsList predColExpr df' :: [Double] + in maximum (zipWith (\a p -> abs (a - p)) actualVals predVals) + +{- + +(predicted, predDf) = D.deriveWithExpr "predicted" bestExpr +predValues = D.columnAsList predicted predDf + +-} + +-- Extract the last (most complex) expression of the Pareto front returned by 'fit', if any +bestExpression :: RegressionConfig -> D.Expr Double -> D.DataFrame -> IO (Maybe (D.Expr Double)) +bestExpression cfg target df = do + exprs <- fit cfg target df + pure $ if null exprs then Nothing else Just (last exprs) + +-- High-level regression verification 
used by tests +verifyRegressionAccuracy :: + String -> + RegressionConfig -> + D.Expr Double -> + D.DataFrame -> + Double -> + Assertion +verifyRegressionAccuracy label cfg target df tolerance = do + mBest <- bestExpression cfg target df + case mBest of + Nothing -> + assertFailure noResultMsg + Just bestExpr -> do + let err = maxAbsError target bestExpr df + assertBool (errorMsg err) (err < tolerance) + where + noResultMsg = + label ++ ": regression returned no expressions" + + errorMsg err = + label + ++ ": expected max abs error < " + ++ show tolerance + ++ ", got " + ++ show err + +-- Linear: y = 2x + 1 +testLinearFormula :: Test +testLinearFormula = + TestCase $ + let xs = [1.0 .. 10.0] :: [Double] + ys = map (\x -> 2 * x + 1) xs + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultTestConfig + { generations = 30 + , populationSize = 80 + , maxExpressionSize = 4 + } + in verifyRegressionAccuracy "Linear Formula" cfg target df 1.0 + +-- Constant: y = 42 +testConstantFormula :: Test +testConstantFormula = + TestCase $ + let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] + ys = replicate 5 (42.0 :: Double) + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultTestConfig + { generations = 3 + , populationSize = 10 + , maxExpressionSize = 1 + } + in verifyRegressionAccuracy "Constant Formula" cfg target df 1.0 + +-- Quadratic: y = x² + 2x + 3 +testQuadraticFormula :: Test +testQuadraticFormula = + TestCase $ + let xs = [1.0 .. 
6.0] :: [Double] + ys = map (\x -> x * x + 2 * x + 3) xs + df = D.fromNamedColumns [("x", D.fromList xs), ("y", D.fromList ys)] + target = F.col @Double "y" + cfg = + defaultTestConfig + { generations = 50 + , populationSize = 150 + , maxExpressionSize = 8 + } + in verifyRegressionAccuracy "Quadratic Formula" cfg target df 0.5 + +-- Two variable: z = x + y +testTwoVariableFormula :: Test +testTwoVariableFormula = + TestCase $ + let xs = [1.0, 2.0, 3.0, 4.0, 5.0] :: [Double] + ys = [2.0, 3.0, 1.0, 4.0, 2.0] :: [Double] + zs = zipWith (+) xs ys + df = + D.fromNamedColumns + [ ("x", D.fromList xs) + , ("y", D.fromList ys) + , ("z", D.fromList zs) + ] + target = F.col @Double "z" + cfg = + defaultTestConfig + { generations = 20 + , populationSize = 60 + , maxExpressionSize = 3 + } + in verifyRegressionAccuracy "Two-Variable Formula" cfg target df 0.1 + +allTests :: Test +allTests = + TestList + [ "Linear Formula" ~: testLinearFormula + , "Constant Formula" ~: testConstantFormula + , "Quadratic Formula" ~: testQuadraticFormula + , "Two-Variable Formula" ~: testTwoVariableFormula + ] + main :: IO () -main = putStrLn "Test suite not yet implemented." +main = do + putStrLn "Running symbolic regression tests..." + testCounts <- runTestTT allTests + putStrLn $ "Tests run: " ++ show (cases testCounts) + putStrLn $ "Failures: " ++ show (failures testCounts) + putStrLn $ "Errors: " ++ show (errors testCounts)