diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 921b396e799e7..2cd515c89d3d2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
  * {{{
  * ./bin/run-example ml.DecisionTreeExample [options]
  * }}}
+ * Note that Decision Trees can take a large amount of memory. If the run-example command above
+ * fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
+ *   [examples JAR path] [options]
+ * }}}
  * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
  */
 object DecisionTreeExample {
@@ -70,7 +77,7 @@ object DecisionTreeExample {
     val parser = new OptionParser[Params]("DecisionTreeExample") {
       head("DecisionTreeExample: an example decision tree app.")
       opt[String]("algo")
-        .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+        .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
         .action((x, c) => c.copy(algo = x))
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@@ -222,18 +229,23 @@ object DecisionTreeExample {
     // (1) For classification, re-index classes.
     val labelColName = if (algo == "classification") "indexedLabel" else "label"
     if (algo == "classification") {
-      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+      val labelIndexer = new StringIndexer()
+        .setInputCol("labelString")
+        .setOutputCol(labelColName)
       stages += labelIndexer
     }
     // (2) Identify categorical features using VectorIndexer.
     //     Features with more than maxCategories values will be treated as continuous.
-    val featuresIndexer = new VectorIndexer().setInputCol("features")
-      .setOutputCol("indexedFeatures").setMaxCategories(10)
+    val featuresIndexer = new VectorIndexer()
+      .setInputCol("features")
+      .setOutputCol("indexedFeatures")
+      .setMaxCategories(10)
     stages += featuresIndexer
     // (3) Learn DecisionTree
     val dt = algo match {
       case "classification" =>
-        new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+        new DecisionTreeClassifier()
+          .setFeaturesCol("indexedFeatures")
           .setLabelCol(labelColName)
           .setMaxDepth(params.maxDepth)
           .setMaxBins(params.maxBins)
@@ -242,7 +254,8 @@ object DecisionTreeExample {
           .setCacheNodeIds(params.cacheNodeIds)
           .setCheckpointInterval(params.checkpointInterval)
       case "regression" =>
-        new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+        new DecisionTreeRegressor()
+          .setFeaturesCol("indexedFeatures")
           .setLabelCol(labelColName)
           .setMaxDepth(params.maxDepth)
           .setMaxBins(params.maxBins)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
new file mode 100644
index 0000000000000..66046d3fd0d50
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Params for [[IDF]] and [[IDFModel]].
+ */
+private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * The minimum number of documents in which a term should appear.
+   * @group param
+   */
+  final val minDocFreq = new IntParam(
+    this, "minDocFreq", "minimum number of documents in which a term should appear for filtering")
+
+  setDefault(minDocFreq -> 0)
+
+  /** @group getParam */
+  def getMinDocFreq: Int = getOrDefault(minDocFreq)
+
+  /** @group setParam */
+  def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
+
+  /**
+   * Validate and transform the input schema.
+   */
+  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    val map = extractParamMap(paramMap)
+    SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Compute the Inverse Document Frequency (IDF) given a collection of documents.
+ */
+@AlphaComponent
+final class IDF extends Estimator[IDFModel] with IDFBase {
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = extractParamMap(paramMap)
+    val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
+    val idf = new feature.IDF(getMinDocFreq).fit(input)
+    val model = new IDFModel(this, map, idf)
+    Params.inheritValues(map, this, model)
+    model
+  }
+
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap)
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[IDF]].
+ */
+@AlphaComponent
+class IDFModel private[ml] (
+    override val parent: IDF,
+    override val fittingParamMap: ParamMap,
+    idfModel: feature.IDFModel)
+  extends Model[IDFModel] with IDFBase {
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = extractParamMap(paramMap)
+    val idf = udf { vec: Vector => idfModel.transform(vec) }
+    dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
+  }
+
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index 6f4509f03d033..eb2609faef05a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   def setMaxDepth(value: Int): this.type = {
     require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
     set(maxDepth, value)
-    this.asInstanceOf[this.type]
+    this
   }
 
   /** @group getParam */
@@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
   def getImpurity: String = getOrDefault(impurity)
 
   /** Convert new impurity to old impurity. */
-  protected def getOldImpurity: OldImpurity = {
+  private[ml] def getOldImpurity: OldImpurity = {
     getImpurity match {
       case "variance" => OldVariance
       case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index cb940f62990ed..708c769087dd0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
   private[tree] def toOld: OldSplit
 }
 
-private[ml] object Split {
+private[tree] object Split {
 
   def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
     oldSplit.featureType match {
@@ -58,7 +58,7 @@ private[ml] object Split {
  *                        left. Otherwise, it goes right.
  * @param numCategories  Number of categories for this feature.
  */
-final class CategoricalSplit(
+final class CategoricalSplit private[ml] (
     override val featureIndex: Int,
     leftCategories: Array[Double],
     private val numCategories: Int)
@@ -130,7 +130,8 @@ final class CategoricalSplit(
  * @param threshold  If the feature value is <= this threshold, then the split goes left.
  *                   Otherwise, it goes right.
  */
-final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+  extends Split {
 
   override private[ml] def shouldGoLeft(features: Vector): Boolean = {
     features(featureIndex) <= threshold
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
new file mode 100644
index 0000000000000..4f7e062bb8e95
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+class IDFSuite extends FunSuite with MLlibTestSparkContext {
+
+  @transient var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  def getResultFromDF(result: DataFrame): Array[Vector] = {
+    result.select("idf_value").collect().map {
+      case Row(features: Vector) => features
+    }
+  }
+
+  def assertValues(lhs: Array[Vector], rhs: Array[Vector]): Unit = {
+    assert((lhs, rhs).zipped.forall { (vector1, vector2) =>
+      vector1 ~== vector2 absTol 1E-5
+    }, "The vector value is not correct after IDF.")
+  }
+
+  def getResultFromVector(dataSet: Array[Vector], model: Vector): Array[Vector] = {
+    dataSet.map {
+      case data: DenseVector =>
+        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
+        Vectors.dense(res)
+      case data: SparseVector =>
+        val res = data.indices.zip(data.values).map { case (id, value) =>
+          (id, value * model(id))
+        }
+        Vectors.sparse(data.size, res)
+    }
+  }
+
+  test("compute IDF with default parameter") {
+    val numOfFeatures = 4
+    val data = Array(
+      Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
+      Vectors.dense(0.0, 1.0, 2.0, 3.0),
+      Vectors.sparse(numOfFeatures, Array(1), Array(1.0)))
+    val numOfData = data.length
+    val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+      math.log((numOfData + 1.0) / (x + 1.0))
+    })
+    val expected = getResultFromVector(data, idf)
+
+    val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+
+    val idfModel = new IDF()
+      .setInputCol("features")
+      .setOutputCol("idf_value")
+      .fit(df)
+
+    idfModel.transform(df).select("idf_value", "expected").collect().foreach {
+      case Row(x: Vector, y: Vector) =>
+        assert(x ~== y absTol 1e-5)
+    }
+  }
+
+  test("compute IDF with setter") {
+    val numOfFeatures = 4
+    val data = Array(
+      Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
+      Vectors.dense(0.0, 1.0, 2.0, 3.0),
+      Vectors.sparse(numOfFeatures, Array(1), Array(1.0)))
+    val numOfData = data.length
+
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+    val dataFrame = sc.parallelize(data, 2).map(Tuple1.apply).toDF("features")
+
+    val idfModel = new IDF()
+      .setInputCol("features")
+      .setOutputCol("idf_value")
+      .setMinDocFreq(1)
+      .fit(dataFrame)
+
+    val expectedModel = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+      if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
+    })
+
+    assertValues(
+      getResultFromDF(idfModel.transform(dataFrame)),
+      getResultFromVector(data, expectedModel))
+  }
+}
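
For readers who want to try the new estimator outside the test harness, here is a minimal usage sketch of the ml.feature.IDF API introduced by this patch, mirroring what IDFSuite exercises. It is not part of the diff itself; it assumes a running SparkContext named `sc` (for example, in spark-shell), and the column names "features" and "tfidf" are purely illustrative.

import org.apache.spark.ml.feature.IDF
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

// Assumes an existing SparkContext `sc`.
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

// A tiny corpus of pre-computed term-frequency vectors over a 4-term vocabulary.
val tf = sc.parallelize(Seq(
    Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)),
    Vectors.dense(0.0, 1.0, 2.0, 3.0),
    Vectors.sparse(4, Array(1), Array(1.0))), 2)
  .map(Tuple1.apply).toDF("features")

// Fit the IDF estimator, then rescale each vector by the learned inverse document
// frequencies; minDocFreq zeroes out terms that appear in too few documents.
val idfModel = new IDF()
  .setInputCol("features")
  .setOutputCol("tfidf")
  .setMinDocFreq(1)
  .fit(tf)

idfModel.transform(tf).select("tfidf").show()

Because the fitted IDFModel applies the learned weights through a UDF on a DataFrame column, it should compose with the other spark.ml stages (e.g. as part of a Pipeline) in the same way as the existing transformers.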