
add decision tree and random forest first draft#4

Open
j143 wants to merge 3 commits into main from decision-tree

Conversation


@j143 j143 commented May 3, 2025

decisionTree

# This script implements decision trees for recoded and binned categorical and
# numerical input features. We train a single CART (classification and
# regression tree) decision tree; depending on the provided labels y, it performs
# either classification (majority vote per leaf) or regression (average per leaf).
#
# .. code-block::
#
#   For example, given a feature matrix with features [a,b,c,d]
#   and the following tree, M would look as follows:
#
#   (L1)               |d<5|
#                     /     \
#   (L2)           P1:2    |a<7|
#                          /   \
#   (L3)                 P2:2 P3:1
#
#   --> M :=
#   [[4, 5, 0, 2, 1, 7, 0, 0, 0, 0, 0, 2, 0, 1]]
#    |(L1)| |  (L2)   | |        (L3)         |
#
#
#
# INPUT:
# ------------------------------------------------------------------------------
# X               Feature matrix in recoded/binned representation
# y               Label matrix in recoded/binned representation
# ctypes          Row-Vector of column types [1 scale/ordinal, 2 categorical]
#                 of shape 1-by-(ncol(X)+1), where the last entry is the y type
# max_depth       Maximum depth of the learned tree (stopping criterion)
# min_leaf        Minimum number of samples in leaf nodes (stopping criterion),
#                 odd number recommended to avoid 50/50 leaf label decisions
# min_split       Minimum number of samples in leaf for attempting a split
# max_features    Parameter controlling the number of features used as split
#                 candidates at tree nodes: m = ceil(num_features^max_features)
# max_values      Parameter controlling the number of values per feature used
#                 as split candidates: nb = ceil(num_values^max_values)
# max_dataratio   Parameter in [0,1] controlling when to materialize data
#                 subsets of X and y on node splits. When set to 0, we always
#                 scan the original X and y, which has the benefit of avoiding
#                 the allocation and maintenance of data for all active nodes.
#                 When set to 0.01, we rematerialize whenever the sub-tree data
#                 would be less than 1% of the parent's last materialized data size.
# impurity        Impurity measure: entropy, gini (default), rss (regression)
# seed            Fixed seed for randomization of samples and split candidates
# verbose         Flag indicating verbose debug output
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# M              Matrix M containing the learned trees, in linearized form
# ------------------------------------------------------------------------------

m_decisionTree = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] ctypes,
    Int max_depth = 10, Int min_leaf = 20, Int min_split = 50,
    Double max_features = 0.5, Double max_values = 1.0, Double max_dataratio = 0.25,
    String impurity = "gini", Int seed = -1, Boolean verbose = FALSE)
  return(Matrix[Double] M)
{
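A hypothetical pure-Python decoder (a sketch of the assumed semantics, not SystemDS code) makes the linearized representation in the header comment concrete: M stores a complete binary tree in level order, two values per node — (feature, split value) for inner nodes, (0, label) for leaves, and (0, 0) for absent positions below a leaf.

```python
# Sketch of the linearized tree encoding from the example above
# (assumed semantics, not the SystemDS implementation).
M = [4, 5, 0, 2, 1, 7, 0, 0, 0, 0, 0, 2, 0, 1]

def decode(M):
    """List every node of the implicit complete binary tree in level order."""
    nodes = [(int(M[2 * i]), M[2 * i + 1]) for i in range(len(M) // 2)]
    out = []
    for f, v in nodes:
        if f > 0:
            out.append(f"split: x{f} < {v}")    # inner node: test feature f
        elif v > 0:
            out.append(f"leaf: label {int(v)}")
        else:
            out.append("absent")                # padding below a leaf
    return out

def predict(M, x):
    """Traverse the linearized tree for one sample x (1-based feature ids)."""
    i = 0  # root; children of node i sit at 2*i+1 and 2*i+2
    while True:
        f, v = int(M[2 * i]), M[2 * i + 1]
        if f == 0:
            return int(v)  # leaf reached: return its label
        i = 2 * i + 1 if x[f - 1] < v else 2 * i + 2
```

On the example, `decode(M)` lists the seven slots of the three levels drawn in the ASCII tree, and `predict(M, x)` reproduces the P1/P2/P3 leaf labels.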

and

decisionTreePredict

# This script implements random forest prediction for recoded and binned
# categorical and numerical input features.
# The prediction strategies (GEMM, TT, PTT) follow the
# Hummingbird paper (https://www.usenix.org/system/files/osdi20-nakandala.pdf).
#
# INPUT:
# ------------------------------------------------------------------------------
# X               Feature matrix in recoded/binned representation
# y               Label matrix in recoded/binned representation,
#                 optional for accuracy evaluation
# ctypes          Row-Vector of column types [1 scale/ordinal, 2 categorical]
# M               Matrix M holding the learned tree in linearized form
#                 see decisionTree() for the detailed tree representation.
# strategy        Prediction strategy, can be one of ["GEMM", "TT", "PTT"],
#                 referring to "Generic matrix multiplication",
#                 "Tree traversal", and "Perfect tree traversal", respectively
# verbose         Flag indicating verbose debug output
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# yhat            Label vector of predictions
# ------------------------------------------------------------------------------

m_decisionTreePredict = function(Matrix[Double] X, Matrix[Double] y = matrix(0,0,0),
    Matrix[Double] ctypes, Matrix[Double] M, String strategy="TT", Boolean verbose = FALSE)
  return (Matrix[Double] yhat)
{
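The GEMM strategy reformulates tree inference as matrix multiplications, per the Hummingbird paper. A minimal pure-Python sketch of the idea, hand-encoding the example tree from the decisionTree() header (hypothetical encoding, not the SystemDS implementation):

```python
# GEMM-style tree inference for the example tree: d<5 -> P1:2,
# else a<7 -> P2:2, else P3:1. Features [a,b,c,d] are 0-indexed here.
# (Illustrative sketch of the Hummingbird formulation, not SystemDS code.)
A = [[0, 1],   # a feeds internal node n1 (a<7)
     [0, 0],   # b unused
     [0, 0],   # c unused
     [1, 0]]   # d feeds internal node n0 (d<5)
B = [5.0, 7.0]         # thresholds per internal node
C = [[1, -1, -1],      # n0: +1 if leaf in its left subtree, -1 if right
     [0,  1, -1]]      # n1
D = [1, 1, 0]          # number of +1 (left) edges on each leaf's path
labels = [2, 2, 1]     # leaf labels P1, P2, P3

def gemm_predict(x):
    n_int = len(B)
    # Step 1: evaluate all node tests at once: T = (x @ A < B)
    xa = [sum(x[f] * A[f][i] for f in range(len(x))) for i in range(n_int)]
    T = [1 if xa[i] < B[i] else 0 for i in range(n_int)]
    # Step 2: leaf l is reached iff (T @ C)[l] == D[l]
    tc = [sum(T[i] * C[i][l] for i in range(n_int)) for l in range(len(D))]
    leaf = next(l for l in range(len(D)) if tc[l] == D[l])
    return labels[leaf]
```

Both matrix products are dense and data-independent in shape, which is what lets GEMM-based inference run on hardware tuned for matrix multiplication instead of branchy tree traversal.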

y_test_sds = sds.from_numpy(y_test_full + 1.0)
ctypes_sds = sds.from_numpy(ctypes)
dt_model = decisionTree(X_train_sds, y_train_sds, ctypes_sds, icpt=1, max_depth=5)
dt_pred = decisionTreePredict(X_test_sds, dt_model).compute()
[ERROR] Could not find 'None' in 'dropped' column for logistic regression accuracy. Using 0.0 as fallback.
25/05/03 05:13:05 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Traceback (most recent call last):
File "/workspaces/heart-attack-analysis/heart_attack_systemds.py", line 234, in
dt_pred = decisionTreePredict(X_test_sds, dt_model).compute()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: decisionTreePredict() missing 1 required positional argument: 'M'

ctypes_sds = sds.from_numpy(ctypes)
dt_model = decisionTree(X_train_sds, y_train_sds, ctypes_sds, icpt=1, max_depth=5)
dt_pred = decisionTreePredict(X_test_sds, dt_model).compute()
dt_pred = decisionTreePredict(X_test_sds, dt_model, X_train_sds).compute()
[ERROR] Could not find 'None' in 'dropped' column for logistic regression accuracy. Using 0.0 as fallback.
25/05/03 05:15:50 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Traceback (most recent call last):
  File "/workspaces/heart-attack-analysis/heart_attack_systemds.py", line 234, in <module>
    dt_pred = decisionTreePredict(X_test_sds, dt_model, X_train_sds).compute()
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/systemds/operator/nodes/matrix.py", line 108, in compute
    return super().compute(verbose, lineage)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/systemds/operator/operation_node.py", line 108, in compute
    result_variables = self._script.execute()
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/systemds/script_building/script.py", line 100, in execute
    self.sds_context.exception_and_close(exception_str, trace_back_limit)
  File "/usr/local/python/3.12.1/lib/python3.12/site-packages/systemds/context/systemds_context.py", line 143, in exception_and_close
    raise RuntimeError(message)
RuntimeError: 

An error occurred while calling o1.prepareScript.
: org.apache.sysds.api.DMLException: org.apache.sysds.parser.LanguageException: ERROR: [line 5:0] -> V4=decisionTree(X=V1,y=V2,ctypes=V3,icpt=1,max_depth=5); -- Named function call parameter 'icpt' does not exist in signature of function 'm_decisionTree'. Function signature: [X, y, ctypes, max_depth, min_leaf, min_split, max_features, max_values, max_dataratio, impurity, seed, verbose]
        at org.apache.sysds.api.jmlc.Connection.prepareScript(Connection.java:284)
        at org.apache.sysds.api.jmlc.Connection.prepareScript(Connection.java:223)
        at org.apache.sysds.api.jmlc.Connection.prepareScript(Connection.java:210)
        at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
        at java.base/java.lang.reflect.Method.invoke(Method.java:580)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:238)
        at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: org.apache.sysds.parser.LanguageException: ERROR: [line 5:0] -> V4=decisionTree(X=V1,y=V2,ctypes=V3,icpt=1,max_depth=5); -- Named function call parameter 'icpt' does not exist in signature of function 'm_decisionTree'. Function signature: [X, y, ctypes, max_depth, min_leaf, min_split, max_features, max_values, max_dataratio, impurity, seed, verbose]
        at org.apache.sysds.parser.Expression.raiseValidateError(Expression.java:395)
        at org.apache.sysds.parser.Expression.raiseValidateError(Expression.java:362)
        at org.apache.sysds.parser.FunctionCallIdentifier.validateExpression(FunctionCallIdentifier.java:162)
        at org.apache.sysds.parser.StatementBlock.validateAssignmentStatement(StatementBlock.java:933)
        at org.apache.sysds.parser.StatementBlock.validate(StatementBlock.java:881)
        at org.apache.sysds.parser.DMLTranslator.validateParseTree(DMLTranslator.java:135)
        at org.apache.sysds.parser.DMLTranslator.validateParseTree(DMLTranslator.java:110)
        at org.apache.sysds.api.jmlc.Connection.prepareScript(Connection.java:261)
        ... 11 more
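Both failures above are argument-binding errors against the DML signatures quoted earlier: decisionTree() has no icpt parameter, and decisionTreePredict() needs the trained model M bound to the right parameter, not passed where ctypes is expected. Pure-Python stand-ins (hypothetical names and ordering mirroring the DML signatures, not the real SystemDS Python wrappers) make the binding rules concrete:

```python
# Stand-ins mirroring the DML signatures above (hypothetical, for
# illustration only -- not the actual SystemDS Python API).
def decisionTree(X, y, ctypes, max_depth=10, min_leaf=20, min_split=50,
                 max_features=0.5, max_values=1.0, max_dataratio=0.25,
                 impurity="gini", seed=-1, verbose=False):
    return "M"  # placeholder for the linearized tree matrix

def decisionTreePredict(X, ctypes, M, y=None, strategy="TT", verbose=False):
    return "yhat"  # placeholder for the prediction vector

# Passing icpt fails, analogous to the DML-level LanguageException:
try:
    decisionTree("X", "y", "ctypes", icpt=1, max_depth=5)
except TypeError as e:
    print("rejected:", e)

# Corrected calls: drop icpt, and bind ctypes and M explicitly.
model = decisionTree("X", "y", "ctypes", max_depth=5)
yhat = decisionTreePredict("X_test", "ctypes", model)
```

In the actual run, the TypeError "missing 1 required positional argument: 'M'" suggests the generated wrapper also expects M explicitly, so passing ctypes_sds and dt_model in the correct positions (or by name) should resolve both errors.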

