diff --git a/docs/content/docs/operators/classification/knn.md b/docs/content/docs/operators/classification/knn.md index 0724f2daf..f62af4de7 100644 --- a/docs/content/docs/operators/classification/knn.md +++ b/docs/content/docs/operators/classification/knn.md @@ -67,7 +67,7 @@ Below are the parameters required by `KnnModel`. ```java import org.apache.flink.ml.classification.knn.Knn; import org.apache.flink.ml.classification.knn.KnnModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/classification/linearsvc.md b/docs/content/docs/operators/classification/linearsvc.md index 5b134a995..02e5f912a 100644 --- a/docs/content/docs/operators/classification/linearsvc.md +++ b/docs/content/docs/operators/classification/linearsvc.md @@ -77,7 +77,7 @@ Below are the parameters required by `LinearSVCModel`. ```java import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/classification/logisticregression.md b/docs/content/docs/operators/classification/logisticregression.md index b45b8480a..c68f9ec2b 100644 --- a/docs/content/docs/operators/classification/logisticregression.md +++ b/docs/content/docs/operators/classification/logisticregression.md @@ -74,7 +74,7 @@ Below are the parameters required by `LogisticRegressionModel`. ```java import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -251,9 +251,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; import org.apache.flink.ml.examples.util.PeriodicSourceFunction; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -323,7 +323,7 @@ public class OnlineLogisticRegressionExample { // Creates an online LogisticRegression object and initializes its parameters and initial // model data. 
- Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L); + Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L); Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData)); OnlineLogisticRegression olr = new OnlineLogisticRegression() diff --git a/docs/content/docs/operators/classification/naivebayes.md b/docs/content/docs/operators/classification/naivebayes.md index 3fe9beb8e..b6900bafc 100644 --- a/docs/content/docs/operators/classification/naivebayes.md +++ b/docs/content/docs/operators/classification/naivebayes.md @@ -66,7 +66,7 @@ Below are parameters required by `NaiveBayesModel`. ```java import org.apache.flink.ml.classification.naivebayes.NaiveBayes; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/clustering/agglomerativeclustering.md b/docs/content/docs/operators/clustering/agglomerativeclustering.md index 9ded65cca..254f92bfb 100644 --- a/docs/content/docs/operators/clustering/agglomerativeclustering.md +++ b/docs/content/docs/operators/clustering/agglomerativeclustering.md @@ -69,7 +69,7 @@ format of the merging information is import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering; import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/clustering/kmeans.md b/docs/content/docs/operators/clustering/kmeans.md index eeeff603a..cb6dae84f 100644 --- a/docs/content/docs/operators/clustering/kmeans.md +++ b/docs/content/docs/operators/clustering/kmeans.md @@ -67,7 +67,7 @@ Below are the parameters required by `KMeansModel`. 
```java import org.apache.flink.ml.clustering.kmeans.KMeans; import org.apache.flink.ml.clustering.kmeans.KMeansModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -228,9 +228,9 @@ import org.apache.flink.ml.clustering.kmeans.KMeansModelData; import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; import org.apache.flink.ml.examples.util.PeriodicSourceFunction; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; diff --git a/docs/content/docs/operators/feature/countvectorizer.md b/docs/content/docs/operators/feature/countvectorizer.md index b6658c06c..68905ac02 100644 --- a/docs/content/docs/operators/feature/countvectorizer.md +++ b/docs/content/docs/operators/feature/countvectorizer.md @@ -72,7 +72,7 @@ Below are the parameters required by `CountVectorizerModel`. ```java import org.apache.flink.ml.feature.countvectorizer.CountVectorizer; import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/dct.md b/docs/content/docs/operators/feature/dct.md index 356260be5..fde4b1e9b 100644 --- a/docs/content/docs/operators/feature/dct.md +++ b/docs/content/docs/operators/feature/dct.md @@ -60,7 +60,7 @@ that the transform matrix is unitary (aka scaled DCT-II). ```java import org.apache.flink.ml.feature.dct.DCT; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/elementwiseproduct.md b/docs/content/docs/operators/feature/elementwiseproduct.md index 0021c15b4..4461262de 100644 --- a/docs/content/docs/operators/feature/elementwiseproduct.md +++ b/docs/content/docs/operators/feature/elementwiseproduct.md @@ -58,7 +58,7 @@ scaling vector, the transformer will throw an IllegalArgumentException. 
```java import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/featurehasher.md b/docs/content/docs/operators/feature/featurehasher.md index e804d9ae6..af93911fb 100644 --- a/docs/content/docs/operators/feature/featurehasher.md +++ b/docs/content/docs/operators/feature/featurehasher.md @@ -69,7 +69,7 @@ For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for det ```java import org.apache.flink.ml.feature.featurehasher.FeatureHasher; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/hashingtf.md b/docs/content/docs/operators/feature/hashingtf.md index d340d9096..45a27af59 100644 --- a/docs/content/docs/operators/feature/hashingtf.md +++ b/docs/content/docs/operators/feature/hashingtf.md @@ -63,7 +63,7 @@ the output values are accumulated by default. ```java import org.apache.flink.ml.feature.hashingtf.HashingTF; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/idf.md b/docs/content/docs/operators/feature/idf.md index ca26c286a..2b478184f 100644 --- a/docs/content/docs/operators/feature/idf.md +++ b/docs/content/docs/operators/feature/idf.md @@ -73,7 +73,7 @@ Below are the parameters required by `IDFModel`. ```java import org.apache.flink.ml.feature.idf.IDF; import org.apache.flink.ml.feature.idf.IDFModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/interaction.md b/docs/content/docs/operators/feature/interaction.md index 1b9bca1bc..a313b1246 100644 --- a/docs/content/docs/operators/feature/interaction.md +++ b/docs/content/docs/operators/feature/interaction.md @@ -62,7 +62,7 @@ be Vector(3, 6, 4, 8). ```java import org.apache.flink.ml.feature.interaction.Interaction; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/kbinsdiscretizer.md b/docs/content/docs/operators/feature/kbinsdiscretizer.md index b26e7804c..047851742 100644 --- a/docs/content/docs/operators/feature/kbinsdiscretizer.md +++ b/docs/content/docs/operators/feature/kbinsdiscretizer.md @@ -70,7 +70,7 @@ Below are the parameters required by `KBinsDiscretizerModel`. 
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/maxabsscaler.md b/docs/content/docs/operators/feature/maxabsscaler.md index 1c3d4fd91..c5ba4c81b 100644 --- a/docs/content/docs/operators/feature/maxabsscaler.md +++ b/docs/content/docs/operators/feature/maxabsscaler.md @@ -59,7 +59,7 @@ It does not shift/center the data and thus does not destroy any sparsity. ```java import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler; import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/minhashlsh.md b/docs/content/docs/operators/feature/minhashlsh.md index 5c3aee92a..0c7494cd7 100644 --- a/docs/content/docs/operators/feature/minhashlsh.md +++ b/docs/content/docs/operators/feature/minhashlsh.md @@ -75,9 +75,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.ml.feature.lsh.MinHashLSH; import org.apache.flink.ml.feature.lsh.MinHashLSHModel; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/minmaxscaler.md b/docs/content/docs/operators/feature/minmaxscaler.md index b4427b629..ed0e953eb 100644 --- a/docs/content/docs/operators/feature/minmaxscaler.md +++ b/docs/content/docs/operators/feature/minmaxscaler.md @@ -59,7 +59,7 @@ MinMaxScaler is an algorithm that rescales feature values to a common range ```java import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/normalizer.md b/docs/content/docs/operators/feature/normalizer.md index 65e3760a4..7562d33fa 100644 --- a/docs/content/docs/operators/feature/normalizer.md +++ b/docs/content/docs/operators/feature/normalizer.md @@ -57,7 +57,7 @@ A Transformer that normalizes a vector to have unit norm using the given p-norm. 
```java import org.apache.flink.ml.feature.normalizer.Normalizer; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/onehotencoder.md b/docs/content/docs/operators/feature/onehotencoder.md index 32ff2ffa0..3b706445b 100644 --- a/docs/content/docs/operators/feature/onehotencoder.md +++ b/docs/content/docs/operators/feature/onehotencoder.md @@ -64,7 +64,7 @@ vector column for each input column. ```java import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/onlinestandardscaler.md b/docs/content/docs/operators/feature/onlinestandardscaler.md index ef998f1f8..d158d16e5 100644 --- a/docs/content/docs/operators/feature/onlinestandardscaler.md +++ b/docs/content/docs/operators/feature/onlinestandardscaler.md @@ -91,9 +91,9 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.ml.common.window.EventTimeTumblingWindows; import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.DataTypes; @@ -202,13 +202,13 @@ t_env = StreamTableEnvironment.create(env) # Generates input data. dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType( - get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(), - get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer() + get_gateway().jvm.org.apache.flink.ml.linalg.DenseIntDoubleVector(0).getClass(), + get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer() ).getSerializerString() schema = Schema.new_builder() .column("ts", "TIMESTAMP_LTZ(3)") - .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')" + .column("input", "RAW('org.apache.flink.ml.linalg.DenseIntDoubleVector', '{serializer}')" .format(serializer=dense_vector_serializer)) .watermark("ts", "ts - INTERVAL '1' SECOND") .build() diff --git a/docs/content/docs/operators/feature/polynomialexpansion.md b/docs/content/docs/operators/feature/polynomialexpansion.md index a7f5ee899..ab5cc89a9 100644 --- a/docs/content/docs/operators/feature/polynomialexpansion.md +++ b/docs/content/docs/operators/feature/polynomialexpansion.md @@ -63,7 +63,7 @@ http://en.wikipedia.org/wiki/Polynomial_expansion. 
```java import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/robustscaler.md b/docs/content/docs/operators/feature/robustscaler.md index a37085dd6..69c50740f 100644 --- a/docs/content/docs/operators/feature/robustscaler.md +++ b/docs/content/docs/operators/feature/robustscaler.md @@ -87,7 +87,7 @@ Below are the parameters required by `RobustScalerModel`. ```java import org.apache.flink.ml.feature.robustscaler.RobustScaler; import org.apache.flink.ml.feature.robustscaler.RobustScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/standardscaler.md b/docs/content/docs/operators/feature/standardscaler.md index 55879d35c..c313a400d 100644 --- a/docs/content/docs/operators/feature/standardscaler.md +++ b/docs/content/docs/operators/feature/standardscaler.md @@ -59,7 +59,7 @@ the mean and scaling each dimension to unit variance. ```java import org.apache.flink.ml.feature.standardscaler.StandardScaler; import org.apache.flink.ml.feature.standardscaler.StandardScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/univariatefeatureselector.md b/docs/content/docs/operators/feature/univariatefeatureselector.md index 873919e21..577fd58a8 100644 --- a/docs/content/docs/operators/feature/univariatefeatureselector.md +++ b/docs/content/docs/operators/feature/univariatefeatureselector.md @@ -105,7 +105,7 @@ Below are the parameters required by `UnivariateFeatureSelectorModel`. ```java import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector; import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/variancethresholdselector.md b/docs/content/docs/operators/feature/variancethresholdselector.md index 5cd9a2f14..474276a1f 100644 --- a/docs/content/docs/operators/feature/variancethresholdselector.md +++ b/docs/content/docs/operators/feature/variancethresholdselector.md @@ -69,7 +69,7 @@ Below are the parameters required by `VarianceThresholdSelectorModel`. 
```java import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector; import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/vectorassembler.md b/docs/content/docs/operators/feature/vectorassembler.md index 2877e419c..df8ea5216 100644 --- a/docs/content/docs/operators/feature/vectorassembler.md +++ b/docs/content/docs/operators/feature/vectorassembler.md @@ -69,7 +69,7 @@ the strategy specified by the {@link HasHandleInvalid} parameter as follows: ```java import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/vectorslicer.md b/docs/content/docs/operators/feature/vectorslicer.md index c7965e0c0..8b6f0c11d 100644 --- a/docs/content/docs/operators/feature/vectorslicer.md +++ b/docs/content/docs/operators/feature/vectorslicer.md @@ -61,7 +61,7 @@ it throws an IllegalArgumentException. ```java import org.apache.flink.ml.feature.vectorslicer.VectorSlicer; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/functions.md b/docs/content/docs/operators/functions.md index 96968eae8..32bfd71fa 100644 --- a/docs/content/docs/operators/functions.md +++ b/docs/content/docs/operators/functions.md @@ -38,7 +38,7 @@ of double arrays. {{< tab "Java">}} ```java -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -145,7 +145,7 @@ DenseVector instances. {{< tab "Java">}} ```java -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; diff --git a/docs/content/docs/operators/regression/linearregression.md b/docs/content/docs/operators/regression/linearregression.md index 3b9717be5..1996ffdfe 100644 --- a/docs/content/docs/operators/regression/linearregression.md +++ b/docs/content/docs/operators/regression/linearregression.md @@ -72,7 +72,7 @@ Below are the parameters required by `LinearRegressionModel`. 
{{< tab "Java">}} ```java -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.regression.linearregression.LinearRegression; import org.apache.flink.ml.regression.linearregression.LinearRegressionModel; diff --git a/docs/content/docs/try-flink-ml/java/build-your-own-project.md b/docs/content/docs/try-flink-ml/java/build-your-own-project.md index 713bb9241..78a1752ce 100644 --- a/docs/content/docs/try-flink-ml/java/build-your-own-project.md +++ b/docs/content/docs/try-flink-ml/java/build-your-own-project.md @@ -176,7 +176,7 @@ package myflinkml; import org.apache.flink.ml.clustering.kmeans.KMeans; import org.apache.flink.ml.clustering.kmeans.KMeansModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java index 4a272e54f..6ca9f1027 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java @@ -23,8 +23,8 @@ import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator; import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.table.api.DataTypes; @@ -77,8 +77,8 @@ public Table[] getData(StreamTableEnvironment tEnv) { * information. 
*/ public static class GenerateWeightsFunction extends ScalarFunction { - public DenseVector eval(DenseVector[] centroids) { - return new DenseVector(centroids.length); + public DenseIntDoubleVector eval(DenseIntDoubleVector[] centroids) { + return new DenseIntDoubleVector(centroids.length); } @Override @@ -87,7 +87,7 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) { .outputTypeStrategy( callContext -> Optional.of( - DataTypes.of(DenseVectorTypeInfo.INSTANCE) + DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE) .toDataType(typeFactory))) .build(); } diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java index 0ab32cbb9..ef9c15baf 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.types.Row; import org.apache.flink.util.Preconditions; @@ -43,9 +43,9 @@ protected RowGenerator[] getRowGenerators() { new RowGenerator(getNumValues(), getSeed()) { @Override protected Row getRow() { - DenseVector[] result = new DenseVector[arraySize]; + DenseIntDoubleVector[] result = new DenseIntDoubleVector[arraySize]; for (int i = 0; i < arraySize; i++) { - result[i] = new DenseVector(vectorDim); + result[i] = new DenseIntDoubleVector(vectorDim); for (int j = 0; j < vectorDim; j++) { result[i].values[j] = random.nextDouble(); } @@ -58,7 +58,9 @@ protected Row getRow() { @Override protected RowTypeInfo getRowTypeInfo() { return new RowTypeInfo( - new TypeInformation[] {TypeInformation.of(DenseVector[].class)}, + new TypeInformation[] { + TypeInformation.of(DenseIntDoubleVector[].class) + }, columnNames[0]); } } diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java index 2923129f2..fcba2c51e 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.types.Row; import org.apache.flink.util.Preconditions; @@ -52,7 +52,8 @@ protected Row getRow() { @Override protected RowTypeInfo getRowTypeInfo() { return new RowTypeInfo( - new TypeInformation[] {DenseVectorTypeInfo.INSTANCE}, columnNames[0]); + new TypeInformation[] {DenseIntDoubleVectorTypeInfo.INSTANCE}, + columnNames[0]); } } }; diff --git 
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java index dfc7b157d..4f4cdf383 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java @@ -23,7 +23,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.IntParam; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; @@ -113,7 +113,7 @@ protected Row getRow() { protected RowTypeInfo getRowTypeInfo() { return new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, colNames[0]); } diff --git a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java index edb53197f..6c5fe64e0 100644 --- a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java +++ b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java @@ -25,7 +25,7 @@ import org.apache.flink.ml.benchmark.datagenerator.common.LabeledPointWithWeightGenerator; import org.apache.flink.ml.benchmark.datagenerator.common.RandomStringArrayGenerator; import org.apache.flink.ml.benchmark.datagenerator.common.RandomStringGenerator; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; @@ -68,9 +68,10 @@ public void testDenseVectorGenerator() { it.hasNext(); ) { Row row = it.next(); assertEquals(1, row.getArity()); - DenseVector vector = (DenseVector) row.getField(generator.getColNames()[0][0]); + DenseIntDoubleVector vector = + (DenseIntDoubleVector) row.getField(generator.getColNames()[0][0]); assertNotNull(vector); - assertEquals(vector.size(), generator.getVectorDim()); + assertEquals(vector.size().intValue(), generator.getVectorDim()); count++; } assertEquals(generator.getNumValues(), count); @@ -90,11 +91,12 @@ public void testDenseVectorArrayGenerator() { it.hasNext(); ) { Row row = it.next(); assertEquals(1, row.getArity()); - DenseVector[] vectors = (DenseVector[]) row.getField(generator.getColNames()[0][0]); + DenseIntDoubleVector[] vectors = + (DenseIntDoubleVector[]) row.getField(generator.getColNames()[0][0]); assertNotNull(vectors); assertEquals(generator.getArraySize(), vectors.length); - for (DenseVector vector : vectors) { - assertEquals(vector.size(), generator.getVectorDim()); + for (DenseIntDoubleVector vector : vectors) { + assertEquals(vector.size().intValue(), generator.getVectorDim()); } count++; } @@ -119,7 +121,7 @@ public void testLabeledPointWithWeightGenerator() { it.hasNext(); ) { Row row = it.next(); count++; - DenseVector features = 
(DenseVector) row.getField(featuresCol); + DenseIntDoubleVector features = (DenseIntDoubleVector) row.getField(featuresCol); assertNotNull(features); for (double value : features.values) { assertTrue(value >= 0); diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java index 8d278502a..8af464241 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java @@ -60,6 +60,7 @@ public class TableUtils { LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.MAP); LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.MULTISET); LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.ROW); + LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.STRUCTURED_TYPE); } // Constructs a RowTypeInfo from the given schema. Currently, this function does not support diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java index f812f8b20..acbff8c50 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java @@ -26,7 +26,7 @@ import org.apache.flink.ml.common.window.ProcessingTimeSessionWindows; import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows; import org.apache.flink.ml.common.window.Windows; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.BooleanParam; import org.apache.flink.ml.param.DoubleArrayArrayParam; @@ -138,7 +138,7 @@ private interface MyParams extends WithParams { "Description", new Double[][] {new Double[] {14.0, 15.0}, new Double[] {16.0, 17.0}}); - Param<Vector> VECTOR_PARAM = + Param<IntDoubleVector> VECTOR_PARAM = new VectorParam("vectorParam", "Description", Vectors.dense(1.0, 2.0, 3.0)); Param<Windows> WINDOWS_PARAM = diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java index b420ea4d8..5a28997ce 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java @@ -20,13 +20,14 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.DenseMatrix; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfo; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -113,12
+114,18 @@ public void testGetRowTypeInfo() { dataFields.add(Collections.singletonMap(0.1, 1)); preDefinedDataTypes.add(DataTypes.ROW(DataTypes.INT(), DataTypes.BIGINT())); dataFields.add(Row.of(1, 2L)); - preDefinedDataTypes.add(DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)); - dataFields.add(new DenseVector(new double[] {0.1, 0.2})); - preDefinedDataTypes.add(DataTypes.RAW(SparseVectorTypeInfo.INSTANCE)); - dataFields.add(new SparseVector(2, new int[] {0}, new double[] {0.1})); + preDefinedDataTypes.add(DataTypes.RAW(DenseIntDoubleVectorTypeInfo.INSTANCE)); + dataFields.add(new DenseIntDoubleVector(new double[] {0.1, 0.2})); + preDefinedDataTypes.add(DataTypes.RAW(SparseIntDoubleVectorTypeInfo.INSTANCE)); + dataFields.add(new SparseIntDoubleVector(2, new int[] {0}, new double[] {0.1})); preDefinedDataTypes.add(DataTypes.RAW(DenseMatrixTypeInfo.INSTANCE)); dataFields.add(new DenseMatrix(2, 2)); + preDefinedDataTypes.add( + DataTypes.STRUCTURED( + Tuple2.class, + DataTypes.FIELD("f0", DataTypes.BIGINT()), + DataTypes.FIELD("f1", DataTypes.BIGINT()))); + dataFields.add(Tuple2.of(1L, 2L)); Schema.Builder builder = Schema.newBuilder(); for (int i = 0; i < preDefinedDataTypes.size(); i++) { diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java index ec97b48c6..683c19e46 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -34,9 +34,9 @@ import org.apache.flink.ml.api.Stage; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.api.TransformerServable; import org.apache.flink.streaming.api.datastream.DataStream; @@ -253,8 +253,8 @@ public static Table convertDataTypesToSparseInt(StreamTableEnvironment tEnv, Tab RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(table.getResolvedSchema()); TypeInformation[] fieldTypes = inputTypeInfo.getFieldTypes(); for (int i = 0; i < fieldTypes.length; i++) { - if (fieldTypes[i].getTypeClass().equals(DenseVector.class)) { - fieldTypes[i] = SparseVectorTypeInfo.INSTANCE; + if (fieldTypes[i].getTypeClass().equals(DenseIntDoubleVector.class)) { + fieldTypes[i] = SparseIntDoubleVectorTypeInfo.INSTANCE; } else if (fieldTypes[i].getTypeClass().equals(Double.class)) { fieldTypes[i] = Types.INT; } @@ -273,8 +273,8 @@ public Row map(Row row) { int arity = row.getArity(); for (int i = 0; i < arity; i++) { Object obj = row.getField(i); - if (obj instanceof Vector) { - row.setField(i, ((Vector) obj).toSparse()); + if (obj instanceof IntDoubleVector) { + row.setField(i, ((IntDoubleVector) obj).toSparse()); } else if (obj instanceof Number) { row.setField(i, ((Number) obj).intValue()); } @@ -294,8 +294,8 @@ public static Class[] getColumnDataTypes(Table table) { } /** Note: this comparator imposes orderings that are inconsistent with equals. 
*/ - public static int compare(Vector first, Vector second) { - if (first.size() != second.size()) { + public static int compare(IntDoubleVector first, IntDoubleVector second) { + if (first.size().intValue() != second.size().intValue()) { return Integer.compare(first.size(), second.size()); } else { for (int i = 0; i < first.size(); i++) { diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java index 85c57b774..dc22dac24 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.examples; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -49,7 +49,7 @@ public static void main(String[] args) { for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); Double[] doubleArray = row.getFieldAs("array"); - Vector vector = row.getFieldAs("vector"); + IntDoubleVector vector = row.getFieldAs("vector"); System.out.printf( "Input double array: %s\tOutput vector: %s\n", Arrays.toString(doubleArray), vector); diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java index 733fe45ce..f67a3b742 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.examples; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -54,7 +55,7 @@ public static void main(String[] args) { // Extracts and displays the results. for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - Vector vector = row.getFieldAs("vector"); + IntDoubleVector vector = row.getFieldAs("vector"); Double[] doubleArray = row.getFieldAs("array"); System.out.printf( "Input vector: %s\tOutput double array: %s\n", diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java index 9941c3065..3f3ca9fe4 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.knn.Knn; import org.apache.flink.ml.classification.knn.KnnModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -78,7 +78,8 @@ public static void main(String[] args) { // Extracts and displays the results. 
for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(knn.getFeaturesCol()); double expectedResult = (Double) row.getField(knn.getLabelCol()); double predictionResult = (Double) row.getField(knn.getPredictionCol()); System.out.printf( diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java index f4941fb19..79ff44e74 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -62,11 +62,12 @@ public static void main(String[] args) { // Extracts and displays the results. for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(linearSVC.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(linearSVC.getFeaturesCol()); double expectedResult = (Double) row.getField(linearSVC.getLabelCol()); double predictionResult = (Double) row.getField(linearSVC.getPredictionCol()); - DenseVector rawPredictionResult = - (DenseVector) row.getField(linearSVC.getRawPredictionCol()); + DenseIntDoubleVector rawPredictionResult = + (DenseIntDoubleVector) row.getField(linearSVC.getRawPredictionCol()); System.out.printf( "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n", features, expectedResult, predictionResult, rawPredictionResult); diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java index fc2168d06..6b3206496 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -62,10 +62,12 @@ public static void main(String[] args) { // Extracts and displays the results. 
for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(lr.getFeaturesCol()); double expectedResult = (Double) row.getField(lr.getLabelCol()); double predictionResult = (Double) row.getField(lr.getPredictionCol()); - DenseVector rawPredictionResult = (DenseVector) row.getField(lr.getRawPredictionCol()); + DenseIntDoubleVector rawPredictionResult = + (DenseIntDoubleVector) row.getField(lr.getRawPredictionCol()); System.out.printf( "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n", features, expectedResult, predictionResult, rawPredictionResult); diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java index 1f369200a..e78f14520 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.naivebayes.NaiveBayes; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -69,7 +69,8 @@ public static void main(String[] args) { // Extracts and displays the results. 
for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(naiveBayes.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(naiveBayes.getFeaturesCol()); double predictionResult = (Double) row.getField(naiveBayes.getPredictionCol()); System.out.printf("Features: %s \tPrediction Result: %s\n", features, predictionResult); } diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java index d4e7b2f27..6fb1a0eff 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java @@ -24,9 +24,9 @@ import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; import org.apache.flink.ml.examples.util.PeriodicSourceFunction; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -81,7 +81,7 @@ public static void main(String[] args) { RowTypeInfo typeInfo = new RowTypeInfo( - new TypeInformation[] {DenseVectorTypeInfo.INSTANCE, Types.DOUBLE}, + new TypeInformation[] {DenseIntDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE}, new String[] {"features", "label"}); SourceFunction trainSource = @@ -96,7 +96,8 @@ public static void main(String[] args) { // Creates an online LogisticRegression object and initializes its parameters and initial // model data. - Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L); + Row initModelData = + Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L); Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData)); OnlineLogisticRegression olr = new OnlineLogisticRegression() @@ -119,10 +120,12 @@ public static void main(String[] args) { // would change over time. 
for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(olr.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(olr.getFeaturesCol()); Double expectedResult = (Double) row.getField(olr.getLabelCol()); Double predictionResult = (Double) row.getField(olr.getPredictionCol()); - DenseVector rawPredictionResult = (DenseVector) row.getField(olr.getRawPredictionCol()); + DenseIntDoubleVector rawPredictionResult = + (DenseIntDoubleVector) row.getField(olr.getRawPredictionCol()); System.out.printf( "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n", features, expectedResult, predictionResult, rawPredictionResult); diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java index c48448f1d..54fa8851e 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java @@ -21,7 +21,7 @@ import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering; import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -37,7 +37,7 @@ public static void main(String[] args) { StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); // Generates input data. - DataStream<DenseVector> inputStream = + DataStream<DenseIntDoubleVector> inputStream = env.fromElements( Vectors.dense(1, 1), Vectors.dense(1, 4), @@ -61,8 +61,8 @@ // Extracts and displays the clustering results.
for (CloseableIterator it = outputs[0].execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = - (DenseVector) row.getField(agglomerativeClustering.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(agglomerativeClustering.getFeaturesCol()); int clusterId = (Integer) row.getField(agglomerativeClustering.getPredictionCol()); System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId); } diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java index 62edf2e0d..5f5678fe1 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.clustering.kmeans.KMeans; import org.apache.flink.ml.clustering.kmeans.KMeansModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -36,7 +36,7 @@ public static void main(String[] args) { StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); // Generates input data. - DataStream<DenseVector> inputStream = + DataStream<DenseIntDoubleVector> inputStream = env.fromElements( Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 0.3), @@ -58,7 +58,8 @@ // Extracts and displays the results. for (CloseableIterator it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(kmeans.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(kmeans.getFeaturesCol()); int clusterId = (Integer) row.getField(kmeans.getPredictionCol()); System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId); } diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java index b7b7eb3bb..579f0c113 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java @@ -23,9 +23,9 @@ import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; import org.apache.flink.ml.examples.util.PeriodicSourceFunction; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -73,13 +73,14 @@ SourceFunction trainSource = new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2)); DataStream trainStream = - env.addSource(trainSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE)); + env.addSource(trainSource, new
+                env.addSource(trainSource, new RowTypeInfo(DenseIntDoubleVectorTypeInfo.INSTANCE));
         Table trainTable = tEnv.fromDataStream(trainStream).as("features");
 
         SourceFunction<Row> predictSource =
                 new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
         DataStream<Row> predictStream =
-                env.addSource(predictSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
+                env.addSource(
+                        predictSource, new RowTypeInfo(DenseIntDoubleVectorTypeInfo.INSTANCE));
         Table predictTable = tEnv.fromDataStream(predictStream).as("features");
 
         // Creates an online K-means object and initializes its parameters and initial model data.
@@ -102,10 +103,12 @@ public static void main(String[] args) {
         // would change over time.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row1 = it.next();
-            DenseVector features1 = (DenseVector) row1.getField(onlineKMeans.getFeaturesCol());
+            DenseIntDoubleVector features1 =
+                    (DenseIntDoubleVector) row1.getField(onlineKMeans.getFeaturesCol());
             Integer clusterId1 = (Integer) row1.getField(onlineKMeans.getPredictionCol());
             Row row2 = it.next();
-            DenseVector features2 = (DenseVector) row2.getField(onlineKMeans.getFeaturesCol());
+            DenseIntDoubleVector features2 =
+                    (DenseIntDoubleVector) row2.getField(onlineKMeans.getFeaturesCol());
             Integer clusterId2 = (Integer) row2.getField(onlineKMeans.getPredictionCol());
             if (Objects.equals(clusterId1, clusterId2)) {
                 System.out.printf("%s and %s are now in the same cluster.\n", features1, features2);
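Both `addSource` calls above pass an explicit `RowTypeInfo`: a `SourceFunction<Row>` erases its field types, so the vector column's type information has to be supplied by hand. A minimal sketch of the pattern, assuming an environment and source like those in the example (this block is an illustration, not part of the patch):

```java
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.types.Row;

class TypedVectorSource {
    // Wraps env.addSource with a RowTypeInfo declaring a single dense-vector field,
    // so downstream operators know the type of the "features" column.
    static DataStream<Row> withVectorType(
            StreamExecutionEnvironment env, SourceFunction<Row> source) {
        return env.addSource(source, new RowTypeInfo(DenseIntDoubleVectorTypeInfo.INSTANCE));
    }
}
```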
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
index fb1287caa..b28a37869 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -62,7 +62,8 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
             String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
-            SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
+            SparseIntDoubleVector outputValue =
+                    (SparseIntDoubleVector) row.getField(countVectorizer.getOutputCol());
             System.out.printf(
                     "Input Value: %-15s \tOutput Value: %s\n",
                     Arrays.toString(inputValue), outputValue.toString());
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java
index 2b3d68316..afb4987cd 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.dct.DCT;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -37,7 +37,7 @@ public static void main(String[] args) {
         StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
 
         // Generates input data.
-        List<Vector> inputData =
+        List<IntDoubleVector> inputData =
                 Arrays.asList(
                         Vectors.dense(1.0, 1.0, 1.0, 1.0), Vectors.dense(1.0, 0.0, -1.0, 0.0));
         Table inputTable = tEnv.fromDataStream(env.fromCollection(inputData)).as("input");
@@ -52,8 +52,8 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = row.getFieldAs(dct.getInputCol());
-            Vector outputValue = row.getFieldAs(dct.getOutputCol());
+            IntDoubleVector inputValue = row.getFieldAs(dct.getInputCol());
+            IntDoubleVector outputValue = row.getFieldAs(dct.getOutputCol());
 
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java
index 927133330..386b9da71 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,8 +56,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = (Vector) row.getField(elementwiseProduct.getInputCol());
-            Vector outputValue = (Vector) row.getField(elementwiseProduct.getOutputCol());
+            IntDoubleVector inputValue =
+                    (IntDoubleVector) row.getField(elementwiseProduct.getInputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(elementwiseProduct.getOutputCol());
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
index c0f81c62f..fbbb1827c 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -61,7 +61,8 @@ public static void main(String[] args) {
             for (int i = 0; i < inputValues.length; i++) {
                 inputValues[i] = row.getField(featureHash.getInputCols()[i]);
             }
-            Vector outputValue = (Vector) row.getField(featureHash.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(featureHash.getOutputCol());
 
             System.out.printf(
                     "Input Values: %s \tOutput Value: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
index 213ebd5cd..4927c4a50 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.hashingtf.HashingTF;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -60,7 +60,8 @@ public static void main(String[] args) {
             Row row = it.next();
 
             List inputValue = (List) row.getField(hashingTF.getInputCol());
-            SparseVector outputValue = (SparseVector) row.getField(hashingTF.getOutputCol());
+            SparseIntDoubleVector outputValue =
+                    (SparseIntDoubleVector) row.getField(hashingTF.getOutputCol());
 
             System.out.printf(
                     "Input Value: %s \tOutput Value: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
index ffa4d710a..53bd93ec5 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.idf.IDF;
 import org.apache.flink.ml.feature.idf.IDFModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,8 +56,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(idf.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(idf.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java
index 6acb89038..1115baefa 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.interaction.Interaction;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -58,7 +58,8 @@ public static void main(String[] args) {
             for (int i = 0; i < inputValues.length; i++) {
                 inputValues[i] = row.getField(interaction.getInputCols()[i]);
             }
-            Vector outputValue = (Vector) row.getField(interaction.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(interaction.getOutputCol());
 
             System.out.printf(
                     "Input Values: %s \tOutput Value: %s\n",
                     Arrays.toString(inputValues), outputValue);
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
index 1478f8c84..bcc05684d 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
@@ -21,7 +21,7 @@
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -63,8 +63,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(kBinsDiscretizer.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(kBinsDiscretizer.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(kBinsDiscretizer.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(kBinsDiscretizer.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
index cd53394f9..33826c989 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -64,8 +64,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(maxAbsScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(maxAbsScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(maxAbsScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(maxAbsScaler.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinHashLSHExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinHashLSHExample.java
index 89c6bd624..c155ff11b 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinHashLSHExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinHashLSHExample.java
@@ -22,9 +22,9 @@
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.ml.feature.lsh.MinHashLSH;
 import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -77,7 +77,7 @@ public static void main(String[] args) throws Exception {
                                 Types.ROW_NAMED(
                                         new String[] {"id", "vec"},
                                         Types.INT,
-                                        TypeInformation.of(SparseVector.class))));
+                                        TypeInformation.of(SparseIntDoubleVector.class))));
 
         Table dataB =
                 tEnv.fromDataStream(
@@ -104,7 +104,7 @@ public static void main(String[] args) throws Exception {
                                 Types.ROW_NAMED(
                                         new String[] {"id", "vec"},
                                         Types.INT,
-                                        TypeInformation.of(SparseVector.class))));
+                                        TypeInformation.of(SparseIntDoubleVector.class))));
 
         // Creates a MinHashLSH estimator object and initializes its parameters.
         MinHashLSH lsh =
@@ -124,14 +124,15 @@ public static void main(String[] args) throws Exception {
         List<String> fieldNames = output.getResolvedSchema().getColumnNames();
         for (Row result :
                 (List<Row>) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) {
-            Vector inputValue = result.getFieldAs(fieldNames.indexOf(lsh.getInputCol()));
-            DenseVector[] outputValue = result.getFieldAs(fieldNames.indexOf(lsh.getOutputCol()));
+            IntDoubleVector inputValue = result.getFieldAs(fieldNames.indexOf(lsh.getInputCol()));
+            DenseIntDoubleVector[] outputValue =
+                    result.getFieldAs(fieldNames.indexOf(lsh.getOutputCol()));
             System.out.printf(
                     "Vector: %s \tHash values: %s\n", inputValue, Arrays.toString(outputValue));
         }
 
         // Finds approximate nearest neighbors of the key.
-        Vector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.});
+        IntDoubleVector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.});
         output = model.approxNearestNeighbors(dataA, key, 2).select($("id"), $("distCol"));
         for (Row result :
                 (List<Row>) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) {
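For reference, the key built above with `Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.})` is a 6-dimensional vector that is zero everywhere except at indices 1 and 3. A small sketch of the equivalence (illustration only, not part of the patch):

```java
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

class SparseKeyExample {
    static void show() {
        // sparse(size, indices, values): entries not listed in `indices` default to zero.
        IntDoubleVector sparseKey = Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.});
        // The same key written densely.
        IntDoubleVector denseKey = Vectors.dense(0.0, 1.0, 0.0, 1.0, 0.0, 0.0);
        System.out.println(sparseKey + " == " + denseKey);
    }
}
```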
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java
index 98e908ff4..cde8541c1 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -64,8 +64,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(minMaxScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(minMaxScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(minMaxScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(minMaxScaler.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java
index b0c6332b2..26a813b4e 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.normalizer.Normalizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -52,9 +52,9 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = (Vector) row.getField(normalizer.getInputCol());
+            IntDoubleVector inputValue = (IntDoubleVector) row.getField(normalizer.getInputCol());
 
-            Vector outputValue = (Vector) row.getField(normalizer.getOutputCol());
+            IntDoubleVector outputValue = (IntDoubleVector) row.getField(normalizer.getOutputCol());
 
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java
index 1ebc8d95b..976ffa922 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -56,8 +56,8 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
             Double inputValue = (Double) row.getField(oneHotEncoder.getInputCols()[0]);
-            SparseVector outputValue =
-                    (SparseVector) row.getField(oneHotEncoder.getOutputCols()[0]);
+            SparseIntDoubleVector outputValue =
+                    (SparseIntDoubleVector) row.getField(oneHotEncoder.getOutputCols()[0]);
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java
index d16ed12e3..8675e1445 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java
@@ -24,9 +24,9 @@
 import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.DataTypes;
@@ -73,7 +73,10 @@ public static void main(String[] args) {
                         inputStreamWithEventTime,
                         Schema.newBuilder()
                                 .column("f0", DataTypes.BIGINT())
-                                .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
+                                .column(
+                                        "f1",
+                                        DataTypes.RAW(
+                                                DenseIntDoubleVectorTypeInfo.INSTANCE))
                                 .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
                                 .watermark("rowtime", "SOURCE_WATERMARK()")
                                 .build())
@@ -94,9 +97,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(onlineStandardScaler.getInputCol());
-            DenseVector outputValue =
-                    (DenseVector) row.getField(onlineStandardScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(onlineStandardScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(onlineStandardScaler.getOutputCol());
             long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol());
             System.out.printf(
                     "Input Value: %s\tOutput Value: %-65s\tModel Version: %s\n",
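The schema above registers the vector column as a RAW type, since the Table API has no built-in data type for Flink ML vectors; `DataTypes.RAW` wraps the column's `TypeInformation`. A sketch of that pattern on its own, with the event-time metadata column and watermark left out (illustration only):

```java
import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;

class VectorColumnSchema {
    // A bigint id column plus a vector column backed by its TypeInformation.
    static Schema build() {
        return Schema.newBuilder()
                .column("f0", DataTypes.BIGINT())
                .column("f1", DataTypes.RAW(DenseIntDoubleVectorTypeInfo.INSTANCE))
                .build();
    }
}
```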
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java
index e4e6e20d1..632981978 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,9 +56,11 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = (Vector) row.getField(polynomialExpansion.getInputCol());
+            IntDoubleVector inputValue =
+                    (IntDoubleVector) row.getField(polynomialExpansion.getInputCol());
 
-            Vector outputValue = (Vector) row.getField(polynomialExpansion.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(polynomialExpansion.getOutputCol());
 
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
index 04f082375..98b6e0396 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.robustscaler.RobustScaler;
 import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -67,8 +67,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(robustScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(robustScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(robustScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(robustScaler.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java
index 571a58a97..5d38290bc 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.standardscaler.StandardScaler;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -55,8 +55,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(standardScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(standardScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(standardScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(standardScaler.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
index 4d4c07f6a..a5a82751f 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -67,10 +67,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue =
-                    (DenseVector) row.getField(univariateFeatureSelector.getFeaturesCol());
-            DenseVector outputValue =
-                    (DenseVector) row.getField(univariateFeatureSelector.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(univariateFeatureSelector.getFeaturesCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(univariateFeatureSelector.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java
index d441a3bcb..2ec96fb9f 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -67,10 +67,10 @@ public static void main(String[] args) {
         System.out.printf("Variance Threshold: %s\n", threshold);
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue =
-                    (DenseVector) row.getField(varianceThresholdSelector.getInputCol());
-            DenseVector outputValue =
-                    (DenseVector) row.getField(varianceThresholdSelector.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(varianceThresholdSelector.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(varianceThresholdSelector.getOutputCol());
             System.out.printf("Input Values: %-15s\tOutput Values: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
index 0c146251e..5480c58ed 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -71,7 +71,8 @@ public static void main(String[] args) {
                 inputValues[i] = row.getField(vectorAssembler.getInputCols()[i]);
             }
 
-            Vector outputValue = (Vector) row.getField(vectorAssembler.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(vectorAssembler.getOutputCol());
 
             System.out.printf(
                     "Input Values: %s \tOutput Value: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
index 7c138992c..b5e7d91af 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -52,9 +52,10 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = (Vector) row.getField(vectorSlicer.getInputCol());
+            IntDoubleVector inputValue = (IntDoubleVector) row.getField(vectorSlicer.getInputCol());
 
-            Vector outputValue = (Vector) row.getField(vectorSlicer.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(vectorSlicer.getOutputCol());
 
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java
index bc39f3adb..73599c54b 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.ml.examples.regression;
 
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.regression.linearregression.LinearRegression;
 import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
@@ -60,7 +60,8 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(lr.getFeaturesCol());
             double expectedResult = (Double) row.getField(lr.getLabelCol());
             double predictionResult = (Double) row.getField(lr.getPredictionCol());
             System.out.printf(
diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml
index 8050dd78c..37f05ea72 100644
--- a/flink-ml-lib/pom.xml
+++ b/flink-ml-lib/pom.xml
@@ -138,6 +138,12 @@ under the License.
             <scope>test</scope>
             <type>test-jar</type>
         </dependency>
+
+        <dependency>
+            <groupId>fastutil</groupId>
+            <artifactId>fastutil</artifactId>
+            <version>5.0.9</version>
+        </dependency>
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java
index 694b35e4f..a798852a3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java
@@ -18,10 +18,11 @@
 
 package org.apache.flink.ml;
 
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.table.api.ApiExpression;
 import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.catalog.DataTypeFactory;
@@ -37,18 +38,18 @@
 /** Built-in table functions for data transformations. */
 @SuppressWarnings("unused")
 public class Functions {
-    /** Converts a column of {@link Vector}s into a column of double arrays. */
+    /** Converts a column of {@link IntDoubleVector}s into a column of double arrays. */
     public static ApiExpression vectorToArray(Object... arguments) {
         return call(VectorToArrayFunction.class, arguments);
     }
 
     /**
-     * A {@link ScalarFunction} that converts a column of {@link Vector}s into a column of double
-     * arrays.
+     * A {@link ScalarFunction} that converts a column of {@link IntDoubleVector}s into a column of
+     * double arrays.
      */
     public static class VectorToArrayFunction extends ScalarFunction {
         public double[] eval(Vector vector) {
-            return vector.toArray();
+            return (double[]) vector.toArray();
         }
 
         @Override
@@ -66,7 +67,8 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
     }
 
     /**
-     * Converts a column of arrays of numeric type into a column of {@link DenseVector} instances.
+     * Converts a column of arrays of numeric type into a column of {@link DenseIntDoubleVector}
+     * instances.
      */
     public static ApiExpression arrayToVector(Object... arguments) {
         return call(ArrayToVectorFunction.class, arguments);
     }
@@ -74,18 +76,18 @@ public static ApiExpression arrayToVector(Object... arguments) {
 
     /**
      * A {@link ScalarFunction} that converts a column of arrays of numeric type into a column of
-     * {@link DenseVector} instances.
+     * {@link DenseIntDoubleVector} instances.
      */
     public static class ArrayToVectorFunction extends ScalarFunction {
-        public DenseVector eval(double[] array) {
+        public DenseIntDoubleVector eval(double[] array) {
             return Vectors.dense(array);
         }
 
-        public DenseVector eval(Double[] array) {
+        public DenseIntDoubleVector eval(Double[] array) {
             return eval(ArrayUtils.toPrimitive(array));
         }
 
-        public DenseVector eval(Number[] array) {
+        public DenseIntDoubleVector eval(Number[] array) {
             double[] doubles = new double[array.length];
             for (int i = 0; i < array.length; i++) {
                 doubles[i] = array[i].doubleValue();
@@ -99,7 +101,7 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
                     .outputTypeStrategy(
                             callContext ->
                                     Optional.of(
-                                            DataTypes.of(DenseVectorTypeInfo.INSTANCE)
+                                            DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)
                                                     .toDataType(typeFactory)))
                     .build();
         }
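A hedged usage sketch for the two table functions above; the input table and the column name "features" are assumptions for illustration, not part of the patch:

```java
import static org.apache.flink.table.api.Expressions.$;

import org.apache.flink.ml.Functions;
import org.apache.flink.table.api.Table;

class VectorArrayRoundTrip {
    // Converts the vector column to a double[] column and then back to vectors.
    static Table roundTrip(Table inputTable) {
        Table withArray =
                inputTable.select(Functions.vectorToArray($("features")).as("featureArray"));
        return withArray.select(Functions.arrayToVector($("featureArray")).as("features"));
    }
}
```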
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
index 8ad15e276..2c9124ba3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
@@ -24,9 +24,9 @@
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.DenseMatrix;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -63,7 +63,7 @@ public KnnModel fit(Table... inputs) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
         /* Tuple3 : <feature, label, norm square> */
-        DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNorm =
+        DataStream<Tuple3<DenseIntDoubleVector, Double, Double>> inputDataWithNorm =
                 computeNormSquare(tEnv.toDataStream(inputs[0]));
         DataStream<KnnModelData> modelData = genModelData(inputDataWithNorm);
         KnnModel model = new KnnModel().setModelData(tEnv.fromDataStream(modelData));
@@ -96,29 +96,32 @@ public static Knn load(StreamTableEnvironment tEnv, String path) throws IOExcept
      * @return Knn model.
      */
     private static DataStream<KnnModelData> genModelData(
-            DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNormSqare) {
+            DataStream<Tuple3<DenseIntDoubleVector, Double, Double>> inputDataWithNormSqare) {
         DataStream<KnnModelData> modelData =
                 DataStreamUtils.mapPartition(
                         inputDataWithNormSqare,
                         new RichMapPartitionFunction<
-                                Tuple3<DenseVector, Double, Double>, KnnModelData>() {
+                                Tuple3<DenseIntDoubleVector, Double, Double>, KnnModelData>() {
                             @Override
                             public void mapPartition(
-                                    Iterable<Tuple3<DenseVector, Double, Double>> dataPoints,
+                                    Iterable<Tuple3<DenseIntDoubleVector, Double, Double>>
+                                            dataPoints,
                                     Collector<KnnModelData> out) {
-                                List<Tuple3<DenseVector, Double, Double>> bufferedDataPoints =
-                                        new ArrayList<>();
-                                for (Tuple3<DenseVector, Double, Double> dataPoint : dataPoints) {
+                                List<Tuple3<DenseIntDoubleVector, Double, Double>>
+                                        bufferedDataPoints = new ArrayList<>();
+                                for (Tuple3<DenseIntDoubleVector, Double, Double> dataPoint :
+                                        dataPoints) {
                                     bufferedDataPoints.add(dataPoint);
                                 }
                                 int featureDim = bufferedDataPoints.get(0).f0.size();
                                 DenseMatrix packedFeatures =
                                         new DenseMatrix(featureDim, bufferedDataPoints.size());
-                                DenseVector normSquares =
-                                        new DenseVector(bufferedDataPoints.size());
-                                DenseVector labels = new DenseVector(bufferedDataPoints.size());
+                                DenseIntDoubleVector normSquares =
+                                        new DenseIntDoubleVector(bufferedDataPoints.size());
+                                DenseIntDoubleVector labels =
+                                        new DenseIntDoubleVector(bufferedDataPoints.size());
                                 int offset = 0;
-                                for (Tuple3<DenseVector, Double, Double> dataPoint :
+                                for (Tuple3<DenseIntDoubleVector, Double, Double> dataPoint :
                                         bufferedDataPoints) {
                                     System.arraycopy(
                                             dataPoint.f0.values,
@@ -142,14 +145,15 @@ public void mapPartition(
      * @param inputData Input data.
      * @return Input data with norm square.
      */
-    private DataStream<Tuple3<DenseVector, Double, Double>> computeNormSquare(
+    private DataStream<Tuple3<DenseIntDoubleVector, Double, Double>> computeNormSquare(
             DataStream<Row> inputData) {
         return inputData.map(
-                new MapFunction<Row, Tuple3<DenseVector, Double, Double>>() {
+                new MapFunction<Row, Tuple3<DenseIntDoubleVector, Double, Double>>() {
                     @Override
-                    public Tuple3<DenseVector, Double, Double> map(Row value) {
+                    public Tuple3<DenseIntDoubleVector, Double, Double> map(Row value) {
                         Double label = ((Number) value.getField(getLabelCol())).doubleValue();
-                        DenseVector feature = ((Vector) value.getField(getFeaturesCol())).toDense();
+                        DenseIntDoubleVector feature =
+                                ((IntDoubleVector) value.getField(getFeaturesCol())).toDense();
                         return Tuple3.of(feature, label, Math.pow(BLAS.norm2(feature), 2));
                     }
                 });
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
index 3194fbb29..eb32e20fc 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
@@ -26,8 +26,8 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -130,7 +130,7 @@ private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
         private KnnModelData knnModelData;
         private final int k;
         private final String broadcastKey;
-        private DenseVector distanceVector;
+        private DenseIntDoubleVector distanceVector;
 
         public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
             this.k = k;
@@ -144,14 +144,14 @@ public Row map(Row row) {
                 knnModelData =
                         (KnnModelData)
                                 getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
-                distanceVector = new DenseVector(knnModelData.labels.size());
+                distanceVector = new DenseIntDoubleVector(knnModelData.labels.size());
             }
-            DenseVector feature = ((Vector) row.getField(featureCol)).toDense();
+            DenseIntDoubleVector feature = ((IntDoubleVector) row.getField(featureCol)).toDense();
             double prediction = predictLabel(feature);
             return Row.join(row, Row.of(prediction));
         }
 
-        private double predictLabel(DenseVector feature) {
+        private double predictLabel(DenseIntDoubleVector feature) {
             double normSquare = Math.pow(BLAS.norm2(feature), 2);
             BLAS.gemv(-2.0, knnModelData.packedFeatures, true, feature, 0.0, distanceVector);
             for (int i = 0; i < distanceVector.size(); i++) {
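predictLabel above avoids computing each distance from scratch: the training norm squares are precomputed by computeNormSquare at fit time, and the gemv call produces -2 * x^T y_i for every training point y_i in a single matrix-vector product. The squared Euclidean distance then follows from a standard identity; a minimal sketch:

```java
class SquaredDistanceIdentity {
    // ||x - y||^2 = ||x||^2 - 2 * (x . y) + ||y||^2, which is what predictLabel assembles
    // from the query's norm square, the gemv output, and the precomputed training norms.
    static double squaredDistance(double xNormSquare, double dotXY, double yNormSquare) {
        return xNormSquare - 2.0 * dotXY + yNormSquare;
    }
}
```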
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
index 8e4a5c524..d146e8a3d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
@@ -27,12 +27,12 @@
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.DenseMatrix;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Matrix;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -51,13 +51,15 @@
 public class KnnModelData {
 
     public DenseMatrix packedFeatures;
-    public DenseVector featureNormSquares;
-    public DenseVector labels;
+    public DenseIntDoubleVector featureNormSquares;
+    public DenseIntDoubleVector labels;
 
     public KnnModelData() {}
 
     public KnnModelData(
-            DenseMatrix packedFeatures, DenseVector featureNormSquares, DenseVector labels) {
+            DenseMatrix packedFeatures,
+            DenseIntDoubleVector featureNormSquares,
+            DenseIntDoubleVector labels) {
         this.packedFeatures = packedFeatures;
@@ -77,13 +79,14 @@ public static DataStream<KnnModelData> getModelDataStream(Table modelDataTable)
                         x ->
                                 new KnnModelData(
                                         ((Matrix) x.getField(0)).toDense(),
-                                        ((Vector) x.getField(1)).toDense(),
-                                        ((Vector) x.getField(2)).toDense()));
+                                        ((IntDoubleVector) x.getField(1)).toDense(),
+                                        ((IntDoubleVector) x.getField(2)).toDense()));
     }
 
     /** Encoder for {@link KnnModelData}. */
     public static class ModelDataEncoder implements Encoder<KnnModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();
 
         @Override
         public void encode(KnnModelData modelData, OutputStream outputStream) throws IOException {
@@ -102,14 +105,15 @@ public Reader<KnnModelData> createReader(Configuration config, FSDataInputStream
                 private final DataInputView source = new DataInputViewStreamWrapper(stream);
 
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();
 
                 @Override
                 public KnnModelData read() throws IOException {
                     try {
                         DenseMatrix matrix = DenseMatrixSerializer.INSTANCE.deserialize(source);
-                        DenseVector normSquares = serializer.deserialize(source);
-                        DenseVector labels = serializer.deserialize(source);
+                        DenseIntDoubleVector normSquares = serializer.deserialize(source);
+                        DenseIntDoubleVector labels = serializer.deserialize(source);
                         return new KnnModelData(matrix, normSquares, labels);
                     } catch (EOFException e) {
                         return null;
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
index 30c166b43..d03101ed4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
@@ -25,8 +25,8 @@
 import org.apache.flink.ml.common.lossfunc.HingeLoss;
 import org.apache.flink.ml.common.optimizer.Optimizer;
 import org.apache.flink.ml.common.optimizer.SGD;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -77,15 +77,15 @@ public LinearSVCModel fit(Table... inputs) {
                                             || Double.compare(1.0, label) == 0,
                                     "LinearSVC only supports binary classification. But detected label: %s.",
                                     label);
-                            DenseVector features =
-                                    ((Vector) dataPoint.getField(getFeaturesCol()))
-                                            .toDense();
+                            DenseIntDoubleVector features =
+                                    ((IntDoubleVector) dataPoint.getField(getFeaturesCol()))
+                                            .toDense();
                             return new LabeledPointWithWeight(features, label, weight);
                         });
 
-        DataStream<DenseVector> initModelData =
+        DataStream<DenseIntDoubleVector> initModelData =
                 DataStreamUtils.reduce(
-                        trainData.map(x -> x.getFeatures().size()),
+                        trainData.map(x -> (Integer) x.features.size()),
                         (ReduceFunction<Integer>)
                                 (t0, t1) -> {
                                     Preconditions.checkState(
@@ -93,7 +93,7 @@ public LinearSVCModel fit(Table... inputs) {
                                             "The training data should all have same dimensions.");
                                     return t0;
                                 })
-                        .map(DenseVector::new);
+                        .map(DenseIntDoubleVector::new);
 
         Optimizer optimizer =
                 new SGD(
@@ -103,7 +103,7 @@ public LinearSVCModel fit(Table... inputs) {
                         getTol(),
                         getReg(),
                         getElasticNet());
-        DataStream<DenseVector> rawModelData =
+        DataStream<DenseIntDoubleVector> rawModelData =
                 optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE);
         DataStream<LinearSVCModelData> modelData = rawModelData.map(LinearSVCModelData::new);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
index a7da2e31c..ce9073578 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
@@ -25,10 +25,10 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -75,7 +75,7 @@ public Table[] transform(Table... inputs) {
                         ArrayUtils.addAll(
                                 inputTypeInfo.getFieldTypes(),
                                 BasicTypeInfo.DOUBLE_TYPE_INFO,
-                                DenseVectorTypeInfo.INSTANCE),
+                                DenseIntDoubleVectorTypeInfo.INSTANCE),
                         ArrayUtils.addAll(
                                 inputTypeInfo.getFieldNames(),
                                 getPredictionCol(),
@@ -136,7 +136,7 @@ private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
 
         private final double threshold;
-        private DenseVector coefficient;
+        private DenseIntDoubleVector coefficient;
 
         public PredictLabelFunction(
                 String broadcastModelKey, String featuresCol, double threshold) {
@@ -153,7 +153,8 @@ public Row map(Row dataPoint) {
                         getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                 coefficient = modelData.coefficient;
             }
-            DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense();
+            DenseIntDoubleVector features =
+                    ((IntDoubleVector) dataPoint.getField(featuresCol)).toDense();
             Row predictionResult = predictOneDataPoint(features, coefficient, threshold);
             return Row.join(dataPoint, predictionResult);
         }
@@ -168,7 +169,7 @@ public Row map(Row dataPoint) {
      * @return The prediction label and the raw predictions.
      */
     private static Row predictOneDataPoint(
-            DenseVector feature, DenseVector coefficient, double threshold) {
+            DenseIntDoubleVector feature, DenseIntDoubleVector coefficient, double threshold) {
         double dotValue = BLAS.dot(feature, coefficient);
         return Row.of(dotValue >= threshold ? 1.0 : 0.0, Vectors.dense(dotValue, -dotValue));
     }
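predictOneDataPoint reduces LinearSVC inference to a single dot product: the label is 1.0 when the margin reaches the threshold, and the raw prediction carries the signed margin for both classes. A compact restatement of the rule (illustration only):

```java
class LinearSVCDecisionRule {
    // `dotValue` is the dot product of the feature and coefficient vectors; the raw
    // prediction pairs it with its negation, as in Vectors.dense(dotValue, -dotValue).
    static double label(double dotValue, double threshold) {
        return dotValue >= threshold ? 1.0 : 0.0;
    }
}
```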
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
index 771c10473..3f07e0f6e 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
@@ -25,9 +25,9 @@
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -45,9 +45,9 @@
  */
 public class LinearSVCModelData {
 
-    public DenseVector coefficient;
+    public DenseIntDoubleVector coefficient;
 
-    public LinearSVCModelData(DenseVector coefficient) {
+    public LinearSVCModelData(DenseIntDoubleVector coefficient) {
         this.coefficient = coefficient;
     }
 
@@ -63,12 +63,13 @@ public static DataStream<LinearSVCModelData> getModelDataStream(Table modelData)
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
         return tEnv.toDataStream(modelData)
-                .map(x -> new LinearSVCModelData(((Vector) x.getField(0)).toDense()));
+                .map(x -> new LinearSVCModelData(((IntDoubleVector) x.getField(0)).toDense()));
     }
 
     /** Data encoder for {@link LinearSVCModel}. */
     public static class ModelDataEncoder implements Encoder<LinearSVCModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();
 
         @Override
         public void encode(LinearSVCModelData modelData, OutputStream outputStream)
@@ -85,12 +86,13 @@ public static class ModelDataDecoder extends SimpleStreamFormat<LinearSVCModelData>
         public Reader<LinearSVCModelData> createReader(
                 Configuration configuration, FSDataInputStream inputStream) {
             return new Reader<LinearSVCModelData>() {
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();
 
                 @Override
                 public LinearSVCModelData read() throws IOException {
                     try {
-                        DenseVector coefficient =
+                        DenseIntDoubleVector coefficient =
                                 serializer.deserialize(new DataInputViewStreamWrapper(inputStream));
                         return new LinearSVCModelData(coefficient);
                     } catch (EOFException e) {
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 87cc650c3..8d5c6ce2a 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -25,8 +25,8 @@
 import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
 import org.apache.flink.ml.common.optimizer.Optimizer;
 import org.apache.flink.ml.common.optimizer.SGD;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -85,15 +85,15 @@ public LogisticRegressionModel fit(Table... inputs) {
                                 throw new RuntimeException(
                                         "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
                             }
-                            DenseVector features =
-                                    ((Vector) dataPoint.getField(getFeaturesCol()))
-                                            .toDense();
+                            IntDoubleVector features =
+                                    ((IntDoubleVector)
+                                            dataPoint.getField(getFeaturesCol()));
                             return new LabeledPointWithWeight(features, label, weight);
                         });
 
-        DataStream<DenseVector> initModelData =
+        DataStream<DenseIntDoubleVector> initModelData =
                 DataStreamUtils.reduce(
-                        trainData.map(x -> x.getFeatures().size()),
+                        trainData.map(x -> (Integer) x.features.size()),
                         (ReduceFunction<Integer>)
                                 (t0, t1) -> {
                                     Preconditions.checkState(
@@ -101,7 +101,7 @@ public LogisticRegressionModel fit(Table... inputs) {
                                             "The training data should all have same dimensions.");
                                     return t0;
                                 })
-                        .map(DenseVector::new);
+                        .map(DenseIntDoubleVector::new);
 
         Optimizer optimizer =
                 new SGD(
@@ -111,11 +111,11 @@ public LogisticRegressionModel fit(Table... inputs) {
inputs) { getTol(), getReg(), getElasticNet()); - DataStream rawModelData = + DataStream rawModelData = optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE); - DataStream modelData = - rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0L)); + DataStream modelData = + rawModelData.map(vector -> new LogisticRegressionModelDataSegment(vector, 0L)); LogisticRegressionModel model = new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); ParamUtils.updateExistingParams(model, paramMap); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index e777c5faa..8081ec263 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -43,6 +43,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** A Model which classifies data using the model data computed by {@link LogisticRegression}. */ @@ -66,7 +67,7 @@ public Table[] transform(Table... inputs) { (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); DataStream inputStream = tEnv.toDataStream(inputs[0]); final String broadcastModelKey = "broadcastModelKey"; - DataStream modelDataStream = + DataStream modelDataStream = LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); RowTypeInfo outputTypeInfo = @@ -74,7 +75,7 @@ public Table[] transform(Table... 
inputs) { ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO, - TypeInformation.of(DenseVector.class)), + TypeInformation.of(DenseIntDoubleVector.class)), ArrayUtils.addAll( inputTypeInfo.getFieldNames(), getPredictionCol(), @@ -147,15 +148,22 @@ public PredictLabelFunction(String broadcastModelKey, Map, Object> para @Override public Row map(Row dataPoint) { if (servable == null) { - LogisticRegressionModelData modelData = - (LogisticRegressionModelData) - getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); - servable = new LogisticRegressionModelServable(modelData); + List modelData = + getRuntimeContext().getBroadcastVariable(broadcastModelKey); + + if (modelData.size() == 1) { + servable = new LogisticRegressionModelServable(modelData.get(0)); + } else { + LogisticRegressionModelDataSegment mergedModel = + LogisticRegressionModelDataSegment.mergeSegments(modelData); + servable = new LogisticRegressionModelServable(mergedModel); + } ParamUtils.updateExistingParams(servable, params); } - Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); + IntDoubleVector features = + (IntDoubleVector) dataPoint.getField(servable.getFeaturesCol()); - Tuple2 predictionResult = servable.transform(features); + Tuple2 predictionResult = servable.transform(features); return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java index e6acb7c73..6a10f98c3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java @@ -25,7 +25,7 @@ import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; @@ -45,8 +45,8 @@ public class LogisticRegressionModelDataUtil { /** - * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly - * generated coefficient. + * Generates a Table containing a {@link LogisticRegressionModelDataSegment} instance with + * randomly generated coefficient. * * @param tEnv The environment where to create the table. * @param dim The size of generated coefficient. 
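With the segmented model data introduced in this patch, an initial model table now carries four columns rather than two. Below is a minimal sketch of bootstrapping such a table; the column order (coefficient, startIndex, endIndex, modelVersion) is an assumption read off the four `getFieldAs(...)` calls in `getModelDataStream` below, and the constants are illustrative only.

```java
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class InitSegmentedModelData {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // One segment holding coefficients for indices [0, 2) at model version 0.
        // Field order is assumed: (coefficient, startIndex, endIndex, modelVersion).
        Row segment = Row.of(Vectors.dense(0.1, -0.1), 0L, 2L, 0L);
        Table initModelDataTable = tEnv.fromDataStream(env.fromElements(segment));
    }
}
```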
@@ -59,7 +59,7 @@ public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim } private static class RandomModelDataGenerator - implements MapFunction { + implements MapFunction { private final int dim; private final int seed; @@ -69,13 +69,13 @@ public RandomModelDataGenerator(int dim, int seed) { } @Override - public LogisticRegressionModelData map(Integer integer) throws Exception { - DenseVector vector = new DenseVector(dim); + public LogisticRegressionModelDataSegment map(Integer integer) throws Exception { + DenseIntDoubleVector vector = new DenseIntDoubleVector(dim); Random random = new Random(seed); for (int j = 0; j < dim; j++) { vector.values[j] = random.nextDouble(); } - return new LogisticRegressionModelData(vector, 0L); + return new LogisticRegressionModelDataSegment(vector, 0L); } } @@ -85,11 +85,18 @@ public LogisticRegressionModelData map(Integer integer) throws Exception { * @param modelData The table model data. * @return The data stream model data. */ - public static DataStream getModelDataStream(Table modelData) { + public static DataStream getModelDataStream( + Table modelData) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); return tEnv.toDataStream(modelData) - .map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1))); + .map( + x -> + new LogisticRegressionModelDataSegment( + x.getFieldAs(0), + x.getFieldAs(1), + x.getFieldAs(2), + x.getFieldAs(3))); } /** @@ -105,9 +112,12 @@ public static DataStream getModelDataByteStream(Table modelDataTable) { return tEnv.toDataStream(modelDataTable) .map( x -> { - LogisticRegressionModelData modelData = - new LogisticRegressionModelData( - x.getFieldAs(0), x.getFieldAs(1)); + LogisticRegressionModelDataSegment modelData = + new LogisticRegressionModelDataSegment( + x.getFieldAs(0), + x.getFieldAs(1), + x.getFieldAs(2), + x.getFieldAs(3)); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); modelData.encode(outputStream); @@ -116,26 +126,27 @@ public static DataStream getModelDataByteStream(Table modelDataTable) { } /** Data encoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. */ - public static class ModelDataEncoder implements Encoder { + public static class ModelDataEncoder implements Encoder { @Override - public void encode(LogisticRegressionModelData modelData, OutputStream outputStream) + public void encode(LogisticRegressionModelDataSegment modelData, OutputStream outputStream) throws IOException { modelData.encode(outputStream); } } /** Data decoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. 
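As a quick reference for the codec above: `encode` writes to any `OutputStream` (as `getModelDataByteStream` does with a `ByteArrayOutputStream`), and `decode` is assumed here to accept any `InputStream`, since the decoder below feeds it an `FSDataInputStream`. A hedged round-trip sketch:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

// Round-trips a segment through the byte-level codec used above.
static LogisticRegressionModelDataSegment roundTrip(LogisticRegressionModelDataSegment segment)
        throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    segment.encode(out);
    return LogisticRegressionModelDataSegment.decode(new ByteArrayInputStream(out.toByteArray()));
}
```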
*/ - public static class ModelDataDecoder extends SimpleStreamFormat { + public static class ModelDataDecoder + extends SimpleStreamFormat { @Override - public Reader createReader( + public Reader createReader( Configuration configuration, FSDataInputStream inputStream) { - return new Reader() { + return new Reader() { @Override - public LogisticRegressionModelData read() throws IOException { + public LogisticRegressionModelDataSegment read() throws IOException { try { - return LogisticRegressionModelData.decode(inputStream); + return LogisticRegressionModelDataSegment.decode(inputStream); } catch (EOFException e) { return null; } @@ -149,8 +160,8 @@ public void close() throws IOException { } @Override - public TypeInformation getProducedType() { - return TypeInformation.of(LogisticRegressionModelData.class); + public TypeInformation getProducedType() { + return TypeInformation.of(LogisticRegressionModelDataSegment.class); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java new file mode 100644 index 000000000..5299185ce --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; +import org.apache.flink.ml.common.ps.training.ComputeGradients; +import org.apache.flink.ml.common.ps.training.ComputeIndices; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.MiniBatchMLSession; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.common.ps.training.SerializableConsumer; +import org.apache.flink.ml.common.ps.training.TrainingUtils; +import org.apache.flink.ml.common.ps.updater.FTRL; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.LongDoubleVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; +import org.apache.flink.util.function.SerializableSupplier; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer. + * + *
<p>
See https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegressionWithFtrl + implements Estimator, + LogisticRegressionWithFtrlParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public LogisticRegressionWithFtrl() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public LogisticRegressionModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + String classificationType = getMultiClass(); + Preconditions.checkArgument( + "auto".equals(classificationType) || "binomial".equals(classificationType), + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream trainData = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction) + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : ((Number) + dataPoint.getField( + getWeightCol())) + .doubleValue(); + double label = + ((Number) dataPoint.getField(getLabelCol())) + .doubleValue(); + boolean isBinomial = + Double.compare(0., label) == 0 + || Double.compare(1., label) == 0; + if (!isBinomial) { + throw new RuntimeException( + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + } + Vector features = + dataPoint.getFieldAs(getFeaturesCol()); + return new LabeledPointWithWeight( + features, label, weight); + }); + + DataStream maxKey; + if (getModelDim() > 0) { + maxKey = trainData.getExecutionEnvironment().fromElements(getModelDim() - 1); + } else { + maxKey = + DataStreamUtils.reduce( + trainData.map( + x -> { + Vector feature = x.features; + long dim; + if (feature instanceof IntDoubleVector) { + dim = ((IntDoubleVector) feature).size(); + } else { + dim = ((LongDoubleVector) feature).size(); + } + return dim; + }), + (ReduceFunction) Math::max) + .map((MapFunction) value -> value - 1); + } + + MiniBatchMLSession mlSession = + new MiniBatchMLSession<>( + getGlobalBatchSize(), TypeInformation.of(LabeledPointWithWeight.class)); + + IterationStageList> iterationStages = + new IterationStageList<>(mlSession) + .addStage(new ComputeIndices()) + .addStage( + new PullStage( + (SerializableSupplier) () -> mlSession.pullIndices, + (SerializableConsumer) + x -> mlSession.pulledValues = x)) + .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) + .addStage( + new PushStage( + (SerializableSupplier) () -> mlSession.pushIndices, + (SerializableSupplier) + () -> mlSession.pushValues)) + .setTerminationCriteria( + (SerializableFunction< + MiniBatchMLSession, + Boolean>) + o -> o.iterationId >= getMaxIter()); + FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); + + DataStreamList resultList = + TrainingUtils.train( + trainData, + iterationStages, + maxKey, + new TupleTypeInfo<>( + Types.LONG, + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + ftrl, + getNumServers()); + + DataStream> rawModelData = resultList.get(0); + + final long modelVersion = 0L; + + DataStream modelData = + rawModelData.map( + tuple3 -> + new LogisticRegressionModelDataSegment( + Vectors.dense(tuple3.f2), + tuple3.f0, + tuple3.f1, + modelVersion)); + + LogisticRegressionModel model = + new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); + ParamUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, 
path); + } + + public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java new file mode 100644 index 000000000..4b2bd72a9 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.ml.common.param.HasElasticNet; +import org.apache.flink.ml.common.param.HasGlobalBatchSize; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasMultiClass; +import org.apache.flink.ml.common.param.HasReg; +import org.apache.flink.ml.common.param.HasTol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.LongParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** Params for {@link LogisticRegressionWithFtrl}. 
*/ +public interface LogisticRegressionWithFtrlParams + extends HasLabelCol, + HasWeightCol, + HasGlobalBatchSize, + HasReg, + HasElasticNet, + HasMultiClass, + HasMaxIter, + HasTol, + LogisticRegressionModelParams { + + Param NUM_SERVERS = + new IntParam( + "numServers", + "Number of servers to store model parameters.", + 1, + ParamValidators.gtEq(1)); + + Param ALPHA = + new DoubleParam( + "alpha", + "The alpha parameter of the FTRL optimizer.", + 0.1, + ParamValidators.gt(0.0)); + + Param BETA = + new DoubleParam( + "beta", "The beta parameter of the FTRL optimizer.", 0.1, ParamValidators.gt(0.0)); + + Param MODEL_DIM = + new LongParam( + "modelDim", "Number of features of the input data.", 0L, ParamValidators.gtEq(0)); + + default int getNumServers() { + return get(NUM_SERVERS); + } + + default T setNumServers(Integer value) { + return set(NUM_SERVERS, value); + } + + default double getAlpha() { + return get(ALPHA); + } + + default T setAlpha(Double value) { + return set(ALPHA, value); + } + + default double getBeta() { + return get(BETA); + } + + default T setBeta(Double value) { + return set(BETA, value); + } + + default long getModelDim() { + return get(MODEL_DIM); + } + + default T setModelDim(long value) { + return set(MODEL_DIM, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java index 1bc19938f..57d9b23a8 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java @@ -35,9 +35,9 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -88,7 +88,7 @@ public OnlineLogisticRegressionModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream modelDataStream = + DataStream modelDataStream = LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); @@ -114,9 +114,9 @@ public OnlineLogisticRegressionModel fit(Table... inputs) { getFeaturesCol(), getLabelCol(), getWeightCol()), pointTypeInfo); - DataStream initModelData = + DataStream initModelData = modelDataStream.map( - (MapFunction) + (MapFunction) value -> value.coefficient); initModelData.getTransformation().setParallelism(1); @@ -125,7 +125,7 @@ public OnlineLogisticRegressionModel fit(Table...
inputs) { new FtrlIterationBody( getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet()); - DataStream onlineModelData = + DataStream onlineModelData = Iterations.iterateUnboundedStreams( DataStreamList.of(initModelData), DataStreamList.of(points), body) .get(0); @@ -189,7 +189,7 @@ public FtrlIterationBody( @Override public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { - DataStream modelData = variableStreams.get(0); + DataStream modelData = variableStreams.get(0); DataStream points = dataStreams.get(0); int parallelism = points.getParallelism(); @@ -198,17 +198,17 @@ public IterationBodyResult process( "There are more subtasks in the training process than the number " + "of elements in each batch. Some subtasks might be idling forever."); - DataStream newGradient = + DataStream newGradient = DataStreamUtils.generateBatchData(points, parallelism, batchSize) .connect(modelData.broadcast()) .transform( "LocalGradientCalculator", - TypeInformation.of(DenseVector[].class), + TypeInformation.of(DenseIntDoubleVector[].class), new CalculateLocalGradient()) .setParallelism(parallelism) .countWindowAll(parallelism) .reduce( - (ReduceFunction) + (ReduceFunction) (gradientInfo, newGradientInfo) -> { BLAS.axpy(1.0, gradientInfo[0], newGradientInfo[0]); BLAS.axpy(1.0, gradientInfo[1], newGradientInfo[1]); @@ -217,15 +217,15 @@ public IterationBodyResult process( } return newGradientInfo; }); - DataStream feedbackModelData = + DataStream feedbackModelData = newGradient .transform( "ModelDataUpdater", - TypeInformation.of(DenseVector.class), + TypeInformation.of(DenseIntDoubleVector.class), new UpdateModel(alpha, beta, l1, l2)) .setParallelism(1); - DataStream outputModelData = + DataStream outputModelData = feedbackModelData.map(new CreateLrModelData()).setParallelism(1); return new IterationBodyResult( DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData)); @@ -233,13 +233,15 @@ public IterationBodyResult process( } private static class CreateLrModelData - implements MapFunction, CheckpointedFunction { + implements MapFunction, + CheckpointedFunction { private Long modelVersion = 1L; private transient ListState modelVersionState; @Override - public LogisticRegressionModelData map(DenseVector denseVector) throws Exception { - return new LogisticRegressionModelData(denseVector, modelVersion++); + public LogisticRegressionModelDataSegment map(DenseIntDoubleVector denseVector) + throws Exception { + return new LogisticRegressionModelDataSegment(denseVector, modelVersion++); } @Override @@ -258,8 +260,8 @@ public void initializeState(FunctionInitializationContext context) throws Except } /** Updates model. 
*/ - private static class UpdateModel extends AbstractStreamOperator - implements OneInputStreamOperator { + private static class UpdateModel extends AbstractStreamOperator + implements OneInputStreamOperator { private ListState nParamState; private ListState zParamState; private final double alpha; @@ -288,8 +290,9 @@ public void initializeState(StateInitializationContext context) throws Exception } @Override - public void processElement(StreamRecord streamRecord) throws Exception { - DenseVector[] gradientInfo = streamRecord.getValue(); + public void processElement(StreamRecord streamRecord) + throws Exception { + DenseIntDoubleVector[] gradientInfo = streamRecord.getValue(); double[] coefficient = gradientInfo[2].values; double[] g = gradientInfo[0].values; for (int i = 0; i < g.length; ++i) { @@ -317,13 +320,14 @@ public void processElement(StreamRecord streamRecord) throws Exce / ((beta + Math.sqrt(nParam[i])) / alpha + l2); } } - output.collect(new StreamRecord<>(new DenseVector(coefficient))); + output.collect(new StreamRecord<>(new DenseIntDoubleVector(coefficient))); } } - private static class CalculateLocalGradient extends AbstractStreamOperator - implements TwoInputStreamOperator { - private ListState modelDataState; + private static class CalculateLocalGradient + extends AbstractStreamOperator + implements TwoInputStreamOperator { + private ListState modelDataState; private ListState localBatchDataState; private double[] gradient; private double[] weightSum; @@ -334,7 +338,8 @@ public void initializeState(StateInitializationContext context) throws Exception modelDataState = context.getOperatorStateStore() .getListState( - new ListStateDescriptor<>("modelData", DenseVector.class)); + new ListStateDescriptor<>( + "modelData", DenseIntDoubleVector.class)); TypeInformation type = ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class)); localBatchDataState = @@ -353,7 +358,7 @@ private void calculateGradient() throws Exception { || !localBatchDataState.get().iterator().hasNext()) { return; } - DenseVector modelData = + DenseIntDoubleVector modelData = OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get(); modelDataState.clear(); @@ -362,7 +367,7 @@ private void calculateGradient() throws Exception { localBatchDataState.update(pointsList); for (Row point : points) { - Vector vec = point.getFieldAs(0); + IntDoubleVector vec = point.getFieldAs(0); double label = point.getFieldAs(1); double weight = point.getArity() == 2 ? 
1.0 : point.getFieldAs(2); if (gradient == null) { @@ -371,14 +376,14 @@ private void calculateGradient() throws Exception { } double p = BLAS.dot(modelData, vec); p = 1 / (1 + Math.exp(-p)); - if (vec instanceof DenseVector) { - DenseVector dvec = (DenseVector) vec; + if (vec instanceof DenseIntDoubleVector) { + DenseIntDoubleVector dvec = (DenseIntDoubleVector) vec; for (int i = 0; i < modelData.size(); ++i) { gradient[i] += (p - label) * dvec.values[i]; weightSum[i] += 1.0; } } else { - SparseVector svec = (SparseVector) vec; + SparseIntDoubleVector svec = (SparseIntDoubleVector) vec; for (int i = 0; i < svec.indices.length; ++i) { int idx = svec.indices[i]; gradient[idx] += (p - label) * svec.values[i]; @@ -390,9 +395,9 @@ private void calculateGradient() throws Exception { if (points.length > 0) { output.collect( new StreamRecord<>( - new DenseVector[] { - new DenseVector(gradient), - new DenseVector(weightSum), + new DenseIntDoubleVector[] { + new DenseIntDoubleVector(gradient), + new DenseIntDoubleVector(weightSum), (getRuntimeContext().getIndexOfThisSubtask() == 0) ? modelData : null @@ -403,7 +408,8 @@ private void calculateGradient() throws Exception { } @Override - public void processElement2(StreamRecord modelDataRecord) throws Exception { + public void processElement2(StreamRecord modelDataRecord) + throws Exception { modelDataState.add(modelDataRecord.getValue()); calculateGradient(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java index f06086132..a0e6a096f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java @@ -27,8 +27,8 @@ import org.apache.flink.metrics.Gauge; import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -76,7 +76,7 @@ public Table[] transform(Table... inputs) { ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), Types.DOUBLE, - TypeInformation.of(DenseVector.class), + TypeInformation.of(DenseIntDoubleVector.class), Types.LONG), ArrayUtils.addAll( inputTypeInfo.getFieldNames(), @@ -99,12 +99,12 @@ public Table[] transform(Table... inputs) { /** A utility operator used for prediction. 
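The `UpdateModel` operator above applies the per-coordinate FTRL-Proximal rule (McMahan et al., 2013): it accumulates gradient statistics `z` and `n`, zeroes out coordinates whose `|z|` falls under the L1 threshold, and otherwise solves the proximal step in closed form. A self-contained sketch of that weight formula, with all names local to the sketch:

```java
// Closed-form FTRL-Proximal weight for one coordinate, given the running
// statistics z (shifted gradient sum) and n (sum of squared gradients).
static double ftrlWeight(double z, double n, double alpha, double beta, double l1, double l2) {
    if (Math.abs(z) <= l1) {
        return 0.0; // The L1 term pins small coordinates to exactly zero (sparsity).
    }
    return -(z - Math.signum(z) * l1) / ((beta + Math.sqrt(n)) / alpha + l2);
}
```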
*/ private static class PredictLabelOperator extends AbstractStreamOperator - implements TwoInputStreamOperator { + implements TwoInputStreamOperator { private final RowTypeInfo inputTypeInfo; private final Map, Object> params; private ListState bufferedPointsState; - private DenseVector coefficient; + private DenseIntDoubleVector coefficient; private long modelDataVersion = 0L; private LogisticRegressionModelServable servable; @@ -139,9 +139,9 @@ public void processElement1(StreamRecord streamRecord) throws Exception { } @Override - public void processElement2(StreamRecord streamRecord) + public void processElement2(StreamRecord streamRecord) throws Exception { - LogisticRegressionModelData modelData = streamRecord.getValue(); + LogisticRegressionModelDataSegment modelData = streamRecord.getValue(); coefficient = modelData.coefficient; modelDataVersion = modelData.modelVersion; for (Row dataPoint : bufferedPointsState.get()) { @@ -159,11 +159,12 @@ public void processElement(StreamRecord streamRecord) throws Exception { if (servable == null) { servable = new LogisticRegressionModelServable( - new LogisticRegressionModelData(coefficient, 0L)); + new LogisticRegressionModelDataSegment(coefficient, 0L)); ParamUtils.updateExistingParams(servable, params); } - Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); - Tuple2 predictionResult = servable.transform(features); + IntDoubleVector features = + (IntDoubleVector) dataPoint.getField(servable.getFeaturesCol()); + Tuple2 predictionResult = servable.transform(features); output.collect( new StreamRecord<>( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java index 404c5d19a..8c3935e9b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java @@ -27,9 +27,9 @@ import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -74,10 +74,10 @@ public NaiveBayesModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream> input = + DataStream> input = tEnv.toDataStream(inputs[0]) .map( - (MapFunction>) + (MapFunction>) row -> { Number number = (Number) row.getField(labelCol); Preconditions.checkNotNull( @@ -87,7 +87,7 @@ public NaiveBayesModel fit(Table... inputs) { number.intValue() == number.doubleValue(), "Label value should be indexed number."); return new Tuple2<>( - (Vector) row.getField(featuresCol), + (IntDoubleVector) row.getField(featuresCol), number.doubleValue()); }, Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE)); @@ -129,8 +129,8 @@ public NaiveBayesModel fit(Table... 
inputs) { DataTypes.ARRAY( DataTypes.MAP( DataTypes.DOUBLE(), DataTypes.DOUBLE())))) - .column("piArray", DataTypes.of(DenseVectorTypeInfo.INSTANCE)) - .column("labels", DataTypes.of(DenseVectorTypeInfo.INSTANCE)) + .column("piArray", DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)) + .column("labels", DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)) .build(); NaiveBayesModel model = @@ -165,10 +165,11 @@ public Map, Object> getParamMap() { * */ private static class ExtractFeatureFunction - implements FlatMapFunction, Tuple3> { + implements FlatMapFunction< + Tuple2, Tuple3> { @Override public void flatMap( - Tuple2 value, + Tuple2 value, Collector> collector) { Preconditions.checkNotNull(value.f1); for (int i = 0; i < value.f0.size(); i++) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java index 1ac9c00dd..05c7b5564 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -146,13 +146,13 @@ public Row map(Row row) { (NaiveBayesModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); } - Vector vector = (Vector) row.getField(featuresCol); + IntDoubleVector vector = (IntDoubleVector) row.getField(featuresCol); double label = findMaxProbLabel(calculateProb(modelData, vector), modelData.labels); return Row.join(row, Row.of(label)); } } - private static double findMaxProbLabel(DenseVector prob, Vector label) { + private static double findMaxProbLabel(DenseIntDoubleVector prob, IntDoubleVector label) { double result = 0.; int probSize = prob.size(); double maxVal = Double.NEGATIVE_INFINITY; @@ -167,9 +167,10 @@ private static double findMaxProbLabel(DenseVector prob, Vector label) { } /** Calculate probability of the input data. 
*/ - private static DenseVector calculateProb(NaiveBayesModelData modelData, Vector data) { + private static DenseIntDoubleVector calculateProb( + NaiveBayesModelData modelData, IntDoubleVector data) { int labelSize = modelData.labels.size(); - DenseVector probs = new DenseVector(new double[labelSize]); + DenseIntDoubleVector probs = new DenseIntDoubleVector(new double[labelSize]); for (int i = 0; i < labelSize; i++) { Map[] labelData = modelData.theta[i]; for (int j = 0; j < data.size(); j++) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java index fac5dab32..03fc9f392 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java @@ -29,10 +29,10 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -60,8 +60,8 @@ public class NaiveBayesModelData { fields.put( "theta", Types.OBJECT_ARRAY(Types.OBJECT_ARRAY(Types.MAP(Types.DOUBLE, Types.DOUBLE)))); - fields.put("piArray", DenseVectorTypeInfo.INSTANCE); - fields.put("labels", DenseVectorTypeInfo.INSTANCE); + fields.put("piArray", DenseIntDoubleVectorTypeInfo.INSTANCE); + fields.put("labels", DenseIntDoubleVectorTypeInfo.INSTANCE); } public static final TypeInformation TYPE_INFO = @@ -74,13 +74,15 @@ public class NaiveBayesModelData { public Map[][] theta; /** Log of class priors, whose dimension is C (number of classes). */ - public DenseVector piArray; + public DenseIntDoubleVector piArray; /** Value of labels. */ - public DenseVector labels; + public DenseIntDoubleVector labels; public NaiveBayesModelData( - Map[][] theta, DenseVector piArray, DenseVector labels) { + Map[][] theta, + DenseIntDoubleVector piArray, + DenseIntDoubleVector labels) { this.theta = theta; this.piArray = piArray; this.labels = labels; @@ -103,14 +105,15 @@ public static DataStream getModelDataStream(Table modelData row -> new NaiveBayesModelData( (Map[][]) row.getField(0), - ((Vector) row.getField(1)).toDense(), - ((Vector) row.getField(2)).toDense()), + ((IntDoubleVector) row.getField(1)).toDense(), + ((IntDoubleVector) row.getField(2)).toDense()), TYPE_INFO); } /** Data encoder for the {@link NaiveBayesModelData}. 
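`calculateProb` above scores each label by combining the model's stored distributions; the sketch below shows the underlying argmax rule, assuming `theta[i][j]` maps a feature value to its log conditional probability and `logPriors` plays the role of `piArray` ("log of class priors", per the model-data javadoc). The sentinel for unseen feature values is an assumption of the sketch.

```java
import java.util.Map;

// Score = log prior + sum of per-feature log likelihoods; return the best label.
static double predict(
        Map<Double, Double>[][] theta, double[] logPriors, double[] labels, double[] features) {
    int best = 0;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < labels.length; i++) {
        double score = logPriors[i];
        for (int j = 0; j < features.length; j++) {
            Double logLikelihood = theta[i][j].get(features[j]);
            score += (logLikelihood == null) ? Double.NEGATIVE_INFINITY : logLikelihood;
        }
        if (score > bestScore) {
            bestScore = score;
            best = i;
        }
    }
    return labels[best];
}
```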
*/ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(NaiveBayesModelData modelData, OutputStream outputStream) @@ -141,7 +144,8 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration config, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public NaiveBayesModelData read() throws IOException { @@ -152,9 +156,11 @@ public NaiveBayesModelData read() throws IOException { new MapSerializer<>( DoubleSerializer.INSTANCE, DoubleSerializer.INSTANCE); - DenseVector labels = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector labels = + serializer.deserialize(inputViewStreamWrapper); - DenseVector piArray = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector piArray = + serializer.deserialize(inputViewStreamWrapper); int featureSize = inputViewStreamWrapper.readInt(); int numLabels = inputViewStreamWrapper.readInt(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java index 6fbf39d57..12901edb1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java @@ -43,10 +43,10 @@ import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound; import org.apache.flink.ml.common.iteration.TerminateOnMaxIter; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -89,11 +89,12 @@ public KMeansModel fit(Table... 
inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream points = + DataStream points = tEnv.toDataStream(inputs[0]) - .map(row -> ((Vector) row.getField(getFeaturesCol())).toDense()); + .map(row -> ((IntDoubleVector) row.getField(getFeaturesCol())).toDense()); - DataStream initCentroids = selectRandomCentroids(points, getK(), getSeed()); + DataStream initCentroids = + selectRandomCentroids(points, getK(), getSeed()); IterationConfig config = IterationConfig.newBuilder() @@ -144,20 +145,20 @@ public KMeansIterationBody(int maxIterationNum, DistanceMeasure distanceMeasure) @Override public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { - DataStream centroids = variableStreams.get(0); - DataStream points = dataStreams.get(0); + DataStream centroids = variableStreams.get(0); + DataStream points = dataStreams.get(0); DataStream terminationCriteria = centroids.flatMap(new TerminateOnMaxIter(maxIterationNum)); - DataStream> centroidIdAndPoints = + DataStream> centroidIdAndPoints = points.connect(centroids.broadcast()) .transform( "CentroidsUpdateAccumulator", new TupleTypeInfo<>( BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, ObjectArrayTypeInfo.getInfoFor( - DenseVectorTypeInfo.INSTANCE)), + DenseIntDoubleVectorTypeInfo.INSTANCE)), new CentroidsUpdateAccumulator(distanceMeasure)); DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints, 100); @@ -169,7 +170,7 @@ public IterationBodyResult process( .reduce(new CentroidsUpdateReducer()) .map(new ModelDataGenerator()); - DataStream newCentroids = + DataStream newCentroids = newModelData.map(x -> x.centroids).setParallelism(1); DataStream finalModelData = @@ -183,10 +184,11 @@ public IterationBodyResult process( } private static class CentroidsUpdateReducer - implements ReduceFunction> { + implements ReduceFunction> { @Override - public Tuple2 reduce( - Tuple2 tuple2, Tuple2 t1) + public Tuple2 reduce( + Tuple2 tuple2, + Tuple2 t1) throws Exception { for (int i = 0; i < tuple2.f0.length; i++) { tuple2.f0[i] += t1.f0[i]; @@ -198,28 +200,31 @@ public Tuple2 reduce( } private static class ModelDataGenerator - implements MapFunction, KMeansModelData> { + implements MapFunction, KMeansModelData> { @Override - public KMeansModelData map(Tuple2 tuple2) throws Exception { + public KMeansModelData map(Tuple2 tuple2) + throws Exception { double[] weights = new double[tuple2.f0.length]; for (int i = 0; i < tuple2.f0.length; i++) { BLAS.scal(1.0 / tuple2.f0[i], tuple2.f1[i]); weights[i] = tuple2.f0[i]; } - return new KMeansModelData(tuple2.f1, new DenseVector(weights)); + return new KMeansModelData(tuple2.f1, new DenseIntDoubleVector(weights)); } } private static class CentroidsUpdateAccumulator - extends AbstractStreamOperator> + extends AbstractStreamOperator> implements TwoInputStreamOperator< - DenseVector, DenseVector[], Tuple2>, - IterationListener> { + DenseIntDoubleVector, + DenseIntDoubleVector[], + Tuple2>, + IterationListener> { private final DistanceMeasure distanceMeasure; - private ListState centroids; + private ListState centroids; private ListStateWithCache points; @@ -232,8 +237,8 @@ public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) { public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - TypeInformation type = - ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + TypeInformation type = + 
ObjectArrayTypeInfo.getInfoFor(DenseIntDoubleVectorTypeInfo.INSTANCE); centroids = context.getOperatorStateStore() @@ -255,12 +260,14 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } @Override - public void processElement1(StreamRecord streamRecord) throws Exception { + public void processElement1(StreamRecord streamRecord) + throws Exception { points.add(new VectorWithNorm(streamRecord.getValue())); } @Override - public void processElement2(StreamRecord streamRecord) throws Exception { + public void processElement2(StreamRecord streamRecord) + throws Exception { Preconditions.checkState(!centroids.get().iterator().hasNext()); centroids.add(streamRecord.getValue()); } @@ -269,9 +276,9 @@ public void processElement2(StreamRecord streamRecord) throws Exc public void onEpochWatermarkIncremented( int epochWatermark, Context context, - Collector> out) + Collector> out) throws Exception { - DenseVector[] centroidValues = + DenseIntDoubleVector[] centroidValues = Objects.requireNonNull( OperatorStateUtils.getUniqueElement(centroids, "centroids") .orElse(null)); @@ -281,11 +288,11 @@ public void onEpochWatermarkIncremented( centroidsWithNorm[i] = new VectorWithNorm(centroidValues[i]); } - DenseVector[] newCentroids = new DenseVector[centroidValues.length]; + DenseIntDoubleVector[] newCentroids = new DenseIntDoubleVector[centroidValues.length]; Integer[] counts = new Integer[centroidValues.length]; Arrays.fill(counts, 0); for (int i = 0; i < centroidValues.length; i++) { - newCentroids[i] = new DenseVector(centroidValues[i].size()); + newCentroids[i] = new DenseIntDoubleVector(centroidValues[i].size()); } for (VectorWithNorm point : points.get()) { @@ -301,25 +308,25 @@ public void onEpochWatermarkIncremented( @Override public void onIterationTerminated( - Context context, Collector> collector) { + Context context, Collector> collector) { centroids.clear(); points.clear(); } } - public static DataStream selectRandomCentroids( - DataStream data, int k, long seed) { - DataStream resultStream = + public static DataStream selectRandomCentroids( + DataStream data, int k, long seed) { + DataStream resultStream = DataStreamUtils.mapPartition( DataStreamUtils.sample(data, k, seed), - new MapPartitionFunction() { + new MapPartitionFunction() { @Override public void mapPartition( - Iterable iterable, - Collector collector) { - List list = new ArrayList<>(); + Iterable iterable, + Collector collector) { + List list = new ArrayList<>(); iterable.iterator().forEachRemaining(list::add); - collector.collect(list.toArray(new DenseVector[0])); + collector.collect(list.toArray(new DenseIntDoubleVector[0])); } }); resultStream.getTransformation().setParallelism(1); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java index 5aa57aec5..cac0846ad 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; 
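The `onEpochWatermarkIncremented` hook above performs one Lloyd step: each point is assigned to its closest centroid, per-centroid sums and counts are accumulated, and `ModelDataGenerator` divides the sums by the counts. A compact standalone sketch of that step under squared Euclidean distance:

```java
// One k-means step: assign each point to its nearest centroid, then
// replace every centroid with the mean of its assigned points.
static double[][] updateCentroids(double[][] centroids, double[][] points) {
    int k = centroids.length;
    int dim = centroids[0].length;
    double[][] sums = new double[k][dim];
    int[] counts = new int[k];
    for (double[] p : points) {
        int closest = 0;
        double minDist = Double.MAX_VALUE;
        for (int c = 0; c < k; c++) {
            double d = 0.0;
            for (int j = 0; j < dim; j++) {
                double diff = p[j] - centroids[c][j];
                d += diff * diff;
            }
            if (d < minDist) {
                minDist = d;
                closest = c;
            }
        }
        counts[closest]++;
        for (int j = 0; j < dim; j++) {
            sums[closest][j] += p[j];
        }
    }
    for (int c = 0; c < k; c++) {
        if (counts[c] == 0) {
            sums[c] = centroids[c].clone(); // Keep empty clusters unchanged.
        } else {
            for (int j = 0; j < dim; j++) {
                sums[c][j] /= counts[c];
            }
        }
    }
    return sums;
}
```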
import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -137,7 +137,8 @@ public Row map(Row dataPoint) { centroids[i] = new VectorWithNorm(modelData.centroids[i]); } } - DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense(); + DenseIntDoubleVector point = + ((IntDoubleVector) dataPoint.getField(featuresCol)).toDense(); int closestCentroidId = distanceMeasure.findClosest(centroids, new VectorWithNorm(point)); return Row.join(dataPoint, Row.of(closestCentroidId)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java index eb00b1ea3..b10bd471b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java @@ -28,9 +28,9 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; @@ -52,7 +52,7 @@ */ public class KMeansModelData { - public DenseVector[] centroids; + public DenseIntDoubleVector[] centroids; /** * The weight of the centroids. It is used when updating the model data in online training @@ -61,9 +61,9 @@ public class KMeansModelData { *
<p>
KMeansModelData objects generated during {@link KMeans#fit(Table...)} also contain this * field, so that it can be used as the initial model data of the online training process. */ - public DenseVector weights; + public DenseIntDoubleVector weights; - public KMeansModelData(DenseVector[] centroids, DenseVector weights) { + public KMeansModelData(DenseIntDoubleVector[] centroids, DenseIntDoubleVector weights) { Preconditions.checkArgument(centroids.length == weights.size()); this.centroids = centroids; this.weights = weights; @@ -103,15 +103,15 @@ private RandomCentroidsCreator(int k, int dim, double weight, long seed) { @Override public KMeansModelData map(Integer integer) { - DenseVector[] centroids = new DenseVector[k]; + DenseIntDoubleVector[] centroids = new DenseIntDoubleVector[k]; Random random = new Random(seed); for (int i = 0; i < k; i++) { - centroids[i] = new DenseVector(dim); + centroids[i] = new DenseIntDoubleVector(dim); for (int j = 0; j < dim; j++) { centroids[i].values[j] = random.nextDouble(); } } - DenseVector weights = new DenseVector(k); + DenseIntDoubleVector weights = new DenseIntDoubleVector(k); Arrays.fill(weights.values, weight); return new KMeansModelData(centroids, weights); } @@ -130,15 +130,16 @@ public static DataStream getModelDataStream(Table modelData) { .map( x -> new KMeansModelData( - Arrays.stream(((Vector[]) x.getField(0))) - .map(Vector::toDense) - .toArray(DenseVector[]::new), - ((Vector) x.getField(1)).toDense())); + Arrays.stream(((IntDoubleVector[]) x.getField(0))) - .map(IntDoubleVector::toDense) hmm + .map(IntDoubleVector::toDense) + .toArray(DenseIntDoubleVector[]::new), + ((IntDoubleVector) x.getField(1)).toDense())); } /** Data encoder for {@link KMeansModelData}. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(KMeansModelData modelData, OutputStream outputStream) @@ -146,7 +147,7 @@ public void encode(KMeansModelData modelData, OutputStream outputStream) DataOutputViewStreamWrapper outputViewStreamWrapper = new DataOutputViewStreamWrapper(outputStream); IntSerializer.INSTANCE.serialize(modelData.centroids.length, outputViewStreamWrapper); - for (DenseVector denseVector : modelData.centroids) { + for (DenseIntDoubleVector denseVector : modelData.centroids) { serializer.serialize(denseVector, new DataOutputViewStreamWrapper(outputStream)); } serializer.serialize(modelData.weights, new DataOutputViewStreamWrapper(outputStream)); @@ -159,7 +160,8 @@ public static class ModelDataDecoder extends SimpleStreamFormat public Reader createReader( Configuration config, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public KMeansModelData read() throws IOException { @@ -168,11 +170,13 @@ public KMeansModelData read() throws IOException { new DataInputViewStreamWrapper(inputStream); int numDenseVectors = IntSerializer.INSTANCE.deserialize(inputViewStreamWrapper); - DenseVector[] centroids = new DenseVector[numDenseVectors]; + DenseIntDoubleVector[] centroids = + new DenseIntDoubleVector[numDenseVectors]; for (int i = 0; i < numDenseVectors; i++) { centroids[i] = serializer.deserialize(inputViewStreamWrapper); } - DenseVector weights =
serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector weights = + serializer.deserialize(inputViewStreamWrapper); return new KMeansModelData(centroids, weights); } catch (EOFException e) { return null; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java index d1783a792..0d407b8df 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java @@ -33,10 +33,10 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -89,7 +89,7 @@ public OnlineKMeansModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream points = + DataStream points = tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); DataStream initModelData = @@ -156,7 +156,7 @@ public OnlineKMeansIterationBody( public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { DataStream modelData = variableStreams.get(0); - DataStream points = dataStreams.get(0); + DataStream points = dataStreams.get(0); int parallelism = points.getParallelism(); @@ -188,10 +188,10 @@ public IterationBodyResult process( private static class ModelDataGlobalReducer implements ReduceFunction { @Override public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { - DenseVector weights = modelData.weights; - DenseVector[] centroids = modelData.centroids; - DenseVector newWeights = newModelData.weights; - DenseVector[] newCentroids = newModelData.centroids; + DenseIntDoubleVector weights = modelData.weights; + DenseIntDoubleVector[] centroids = modelData.centroids; + DenseIntDoubleVector newWeights = newModelData.weights; + DenseIntDoubleVector[] newCentroids = newModelData.centroids; int k = newCentroids.length; int dim = newCentroids[0].size(); @@ -225,11 +225,12 @@ public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newMode * */ private static class ModelDataLocalUpdater extends AbstractStreamOperator - implements TwoInputStreamOperator { + implements TwoInputStreamOperator< + DenseIntDoubleVector[], KMeansModelData, KMeansModelData> { private final DistanceMeasure distanceMeasure; private final int k; private final double decayFactor; - private ListState localBatchDataState; + private ListState localBatchDataState; private ListState modelDataState; private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) { @@ -242,8 +243,8 @@ private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double dec public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - TypeInformation type = - 
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + TypeInformation type = + ObjectArrayTypeInfo.getInfoFor(DenseIntDoubleVectorTypeInfo.INSTANCE); localBatchDataState = context.getOperatorStateStore() .getListState(new ListStateDescriptor<>("localBatch", type)); @@ -255,7 +256,8 @@ public void initializeState(StateInitializationContext context) throws Exception } @Override - public void processElement1(StreamRecord pointsRecord) throws Exception { + public void processElement1(StreamRecord pointsRecord) + throws Exception { localBatchDataState.add(pointsRecord.getValue()); alignAndComputeModelData(); } @@ -276,30 +278,30 @@ private void alignAndComputeModelData() throws Exception { KMeansModelData modelData = OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get(); - DenseVector[] centroids = modelData.centroids; + DenseIntDoubleVector[] centroids = modelData.centroids; VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[modelData.centroids.length]; for (int i = 0; i < centroidsWithNorm.length; i++) { centroidsWithNorm[i] = new VectorWithNorm(modelData.centroids[i]); } - DenseVector weights = modelData.weights; + DenseIntDoubleVector weights = modelData.weights; modelDataState.clear(); - List pointsList = + List pointsList = IteratorUtils.toList(localBatchDataState.get().iterator()); - DenseVector[] points = pointsList.remove(0); + DenseIntDoubleVector[] points = pointsList.remove(0); localBatchDataState.update(pointsList); int dim = centroids[0].size(); int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); // Computes new centroids. - DenseVector[] sums = new DenseVector[k]; + DenseIntDoubleVector[] sums = new DenseIntDoubleVector[k]; int[] counts = new int[k]; for (int i = 0; i < k; i++) { - sums[i] = new DenseVector(dim); + sums[i] = new DenseIntDoubleVector(dim); } - for (DenseVector point : points) { + for (DenseIntDoubleVector point : points) { int closestCentroidId = distanceMeasure.findClosest(centroidsWithNorm, new VectorWithNorm(point)); counts[closestCentroidId]++; @@ -313,7 +315,7 @@ private void alignAndComputeModelData() throws Exception { continue; } - DenseVector centroid = centroids[i]; + DenseIntDoubleVector centroid = centroids[i]; weights.values[i] = weights.values[i] + counts[i]; double lambda = counts[i] / weights.values[i]; @@ -325,7 +327,7 @@ private void alignAndComputeModelData() throws Exception { } } - private static class FeaturesExtractor implements MapFunction { + private static class FeaturesExtractor implements MapFunction { private final String featuresCol; private FeaturesExtractor(String featuresCol) { @@ -333,8 +335,8 @@ private FeaturesExtractor(String featuresCol) { } @Override - public DenseVector map(Row row) { - return ((Vector) row.getField(featuresCol)).toDense(); + public DenseIntDoubleVector map(Row row) { + return ((IntDoubleVector) row.getField(featuresCol)).toDense(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java index 43742f186..91cc7633f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; 
-import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -173,7 +173,9 @@ public void processElement1(StreamRecord streamRecord) throws Exception { bufferedPointsState.add(dataPoint); return; } - DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense(); + DenseIntDoubleVector point = + (DenseIntDoubleVector) + (((IntDoubleVector) dataPoint.getField(featuresCol)).toDense()); int closestCentroidId = distanceMeasure.findClosest(centroids, new VectorWithNorm(point)); output.collect(new StreamRecord<>(Row.join(dataPoint, Row.of(closestCentroidId)))); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java index cd24c0684..23c8b438d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java @@ -22,7 +22,8 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** The loss function for binary logistic loss. See {@link LogisticRegression} for example. */ @Internal @@ -32,19 +33,33 @@ public class BinaryLogisticLoss implements LossFunc { private BinaryLogisticLoss() {} @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - return dataPoint.getWeight() * Math.log(1 + Math.exp(-dot * labelScaled)); + public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { + double dot = BLAS.dot((IntDoubleVector) dataPoint.features, coefficient); + double labelScaled = 2 * dataPoint.label - 1; + return dataPoint.weight * Math.log(1 + Math.exp(-dot * labelScaled)); } @Override public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - double multiplier = - dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); - BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient, dataPoint.getFeatures().size()); + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient) { + IntDoubleVector feature = (IntDoubleVector) dataPoint.features; + double dot = BLAS.dot(feature, coefficient); + double labelScaled = 2 * dataPoint.label - 1; + double multiplier = dataPoint.weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); + BLAS.axpy(multiplier, feature, cumGradient, feature.size()); + } + + @Override + public double computeLoss(double label, double prediction) { + double labelScaled = 2 * label - 1; + return Math.log(1 + Math.exp(-prediction * labelScaled)); + } + + @Override + public double computeGradient(double label, double prediction) { + double labelScaled = 2 * label - 1; + return 
-labelScaled / (Math.exp(prediction * labelScaled) + 1); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java index eb0f3bf58..38f7a43a8 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java @@ -22,7 +22,8 @@ import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** * The loss function for hinge loss. See {@link LinearSVC} for example. @@ -36,23 +37,22 @@ public class HingeLoss implements LossFunc { private HingeLoss() {} @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot); + public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { + double dot = BLAS.dot((IntDoubleVector) dataPoint.features, coefficient); + double labelScaled = 2 * dataPoint.label - 1; + return dataPoint.weight * Math.max(0, 1 - labelScaled * dot); } @Override public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient) { + IntDoubleVector feature = (IntDoubleVector) dataPoint.features; + double dot = BLAS.dot(feature, coefficient); + double labelScaled = 2 * dataPoint.label - 1; if (1 - labelScaled * dot > 0) { - BLAS.axpy( - -labelScaled * dataPoint.getWeight(), - dataPoint.getFeatures(), - cumGradient, - dataPoint.getFeatures().size()); + BLAS.axpy(-labelScaled * dataPoint.weight, feature, cumGradient, feature.size()); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java index ea64649b1..d76093189 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java @@ -21,7 +21,8 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.regression.linearregression.LinearRegression; /** The loss function for least square loss. See {@link LinearRegression} for example. 
*/ @@ -32,19 +33,22 @@ public class LeastSquareLoss implements LossFunc { private LeastSquareLoss() {} @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - return dataPoint.getWeight() * 0.5 * Math.pow(dot - dataPoint.getLabel(), 2); + public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { + double dot = BLAS.dot((IntDoubleVector) dataPoint.features, coefficient); + return dataPoint.weight * 0.5 * Math.pow(dot - dataPoint.label, 2); } @Override public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient) { + IntDoubleVector feature = (IntDoubleVector) dataPoint.features; + double dot = BLAS.dot(feature, coefficient); BLAS.axpy( - (dot - dataPoint.getLabel()) * dataPoint.getWeight(), - dataPoint.getFeatures(), + (dot - dataPoint.label) * dataPoint.weight, + (IntDoubleVector) dataPoint.features, cumGradient, - dataPoint.getFeatures().size()); + feature.size()); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java index a90967a73..777130142 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java @@ -20,7 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import java.io.Serializable; @@ -37,7 +37,7 @@ public interface LossFunc extends Serializable { * @param coefficient The model parameters. * @return The loss of the input data. */ - double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient); + double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient); /** * Computes the gradient on the given data point and adds the computed gradient to cumGradient. @@ -47,5 +47,17 @@ public interface LossFunc extends Serializable { * @param cumGradient The accumulated gradient. */ void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient); + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient); + + /** Computes loss using the label and the prediction. */ + default double computeLoss(double label, double prediction) { + throw new UnsupportedOperationException("Not supported yet."); + } + + /** Computes gradient using the label and the prediction. 
*/ + default double computeGradient(double label, double prediction) { + throw new UnsupportedOperationException("Not supported yet."); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java index 647741d13..3cabc8ff6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.common.lossfunc.LossFunc; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; /** @@ -39,8 +39,8 @@ public interface Optimizer { * @param lossFunc The loss function to optimize. * @return The fitted model data. */ - DataStream optimize( - DataStream initModelData, + DataStream optimize( + DataStream initModelData, DataStream trainData, LossFunc lossFunc); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java index 3d36d9aba..69f1a83f2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java @@ -20,7 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; /** * A utility class for algorithms that need to handle regularization. The regularization term is @@ -45,7 +45,7 @@ class RegularizationUtils { * @return The loss introduced by regularization. */ public static double regularize( - DenseVector coefficient, + DenseIntDoubleVector coefficient, final double reg, final double elasticNet, final double learningRate) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java index 2f7800482..471d24cdb 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java @@ -37,8 +37,8 @@ import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; import org.apache.flink.ml.common.lossfunc.LossFunc; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.regression.linearregression.LinearRegression; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -79,8 +79,8 @@ public SGD( } @Override - public DataStream optimize( - DataStream initModelData, + public DataStream optimize( + DataStream initModelData, DataStream trainData, LossFunc lossFunc) { DataStreamList resultList = @@ -111,8 +111,8 @@ public IterationBodyResult process( // totalLoss]. 
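The comment ending above describes the record layout that this iteration feeds back: a single `double[]` carrying the model update in its first `modelDim` entries, plus two trailing slots for the aggregated weight and loss of the mini-batch (matching `feedbackArray = new double[coefficient.size() + 2]` later in this file). A minimal, self-contained sketch of that packing convention; all names here are illustrative, not taken from the patch:

```java
// Illustrative sketch of the [modelUpdate..., totalWeight, totalLoss] layout.
public class FeedbackBufferSketch {
    public static void main(String[] args) {
        int modelDim = 3;
        double[] feedback = new double[modelDim + 2];

        // Pretend these were accumulated over one mini-batch.
        double[] cumGradient = {0.1, -0.2, 0.05};
        double totalWeight = 8.0;
        double totalLoss = 1.7;

        // Pack: model update first, then the two aggregate statistics.
        System.arraycopy(cumGradient, 0, feedback, 0, modelDim);
        feedback[modelDim] = totalWeight;
        feedback[modelDim + 1] = totalLoss;

        // Unpack on the receiving side of the iteration.
        System.out.println(
                "weight = " + feedback[feedback.length - 2]
                        + ", loss = " + feedback[feedback.length - 1]);
    }
}
```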
DataStream variableStream = variableStreams.get(0); DataStream trainData = dataStreams.get(0); - final OutputTag modelDataOutputTag = - new OutputTag("MODEL_OUTPUT") {}; + final OutputTag modelDataOutputTag = + new OutputTag("MODEL_OUTPUT") {}; SingleOutputStreamOperator modelUpdateAndWeightAndLoss = trainData @@ -164,7 +164,7 @@ private static class CacheDataAndDoTrain extends AbstractStreamOperator modelDataOutputTag; + private final OutputTag modelDataOutputTag; /** The cached training data. */ private List trainData; @@ -177,9 +177,9 @@ private static class CacheDataAndDoTrain extends AbstractStreamOperator nextBatchOffsetState; /** The model coefficient. */ - private DenseVector coefficient; + private DenseIntDoubleVector coefficient; - private ListState coefficientState; + private ListState coefficientState; /** The dimension of the coefficient. */ private int coefficientDim; @@ -196,7 +196,9 @@ private static class CacheDataAndDoTrain extends AbstractStreamOperator modelDataOutputTag) { + LossFunc lossFunc, + SGDParams params, + OutputTag modelDataOutputTag) { this.lossFunc = lossFunc; this.params = params; this.modelDataOutputTag = modelDataOutputTag; @@ -232,7 +234,7 @@ private void updateModel() { if (getTotalWeight() > 0) { BLAS.axpy( -params.learningRate / getTotalWeight(), - new DenseVector(feedbackArray), + new DenseIntDoubleVector(feedbackArray), coefficient, coefficientDim); double regLoss = @@ -247,7 +249,7 @@ public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { if (epochWatermark == 0) { - coefficient = new DenseVector(feedbackArray); + coefficient = new DenseIntDoubleVector(feedbackArray); coefficientDim = coefficient.size(); feedbackArray = new double[coefficient.size() + 2]; } else { @@ -271,11 +273,11 @@ public void onEpochWatermarkIncremented( Arrays.fill(feedbackArray, 0); double totalLoss = 0; double totalWeight = 0; - DenseVector cumGradientsWrapper = new DenseVector(feedbackArray); + DenseIntDoubleVector cumGradientsWrapper = new DenseIntDoubleVector(feedbackArray); for (LabeledPointWithWeight dataPoint : miniBatchData) { totalLoss += lossFunc.computeLoss(dataPoint, coefficient); lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper); - totalWeight += dataPoint.getWeight(); + totalWeight += dataPoint.weight; } setTotalLoss(totalLoss); setTotalWeight(totalWeight); @@ -311,7 +313,8 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "coefficientState", DenseVectorTypeInfo.INSTANCE)); + "coefficientState", + DenseIntDoubleVectorTypeInfo.INSTANCE)); OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState") .ifPresent(x -> coefficient = x); if (coefficient != null) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java new file mode 100644 index 000000000..8165db869 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared optimizer param.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface HasOptimizer<T> extends WithParams<T> {
+    Param<Optimizer> OPTIMIZER =
+            new Param<>("optimizer", Optimizer.class, "The optimizer", new SGD(0.1, 100), null);
+
+    default Optimizer getOptimizerParam() {
+        return get(OPTIMIZER);
+    }
+
+    default T setOptimizer(Optimizer optimizer) {
+        return set(OPTIMIZER, optimizer);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java
new file mode 100644
index 000000000..970c34d2e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java
@@ -0,0 +1,4 @@
+package org.apache.flink.ml.common.param;
+
+/** Marker interface for optimizer parameter values (e.g. {@link SGD}). */
+public interface Optimizer {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java
new file mode 100644
index 000000000..e2b47996d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.ml.common.param;
+
+/** Parameter value that selects stochastic gradient descent as the optimizer. */
+public class SGD implements Optimizer {
+    public final double stepSize;
+    public final int globalBatchSize;
+
+    public SGD(double stepSize, int globalBatchSize) {
+        this.stepSize = stepSize;
+        this.globalBatchSize = globalBatchSize;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ResponseAssemblerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ResponseAssemblerOperator.java
new file mode 100644
index 000000000..29e7c4b3e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ResponseAssemblerOperator.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.Message; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +/** + * Assembles the responses from different servers for one pull request. + * + *

Note that for each single-thread worker, there are exactly #numServers segments for each
+ * pull request in the responses.
+ */
+public class ResponseAssemblerOperator extends AbstractStreamOperator<byte[]>
+        implements OneInputStreamOperator<Tuple2<Integer, byte[]>, byte[]> {
+    private final int numServers;
+
+    private int workerId;
+
+    private int numResponsesReceived = 0;
+    private ListState<Integer> numResponsesReceivedState;
+
+    private ListState<byte[]> responsesReceived;
+
+    public ResponseAssemblerOperator(int numServers) {
+        this.numServers = numServers;
+    }
+
+    @Override
+    public void open() throws Exception {
+        super.open();
+        this.workerId = getRuntimeContext().getIndexOfThisSubtask();
+    }
+
+    @Override
+    public void processElement(StreamRecord<Tuple2<Integer, byte[]>> element) throws Exception {
+        Preconditions.checkState(element.getValue().f0 == workerId);
+        responsesReceived.add(element.getValue().f1);
+        numResponsesReceived++;
+
+        if (numResponsesReceived == numServers) {
+            Message message = Message.assembleMessages(responsesReceived.get().iterator());
+            output.collect(new StreamRecord<>(message.bytes));
+            responsesReceived.clear();
+            numResponsesReceived = 0;
+        }
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        responsesReceived =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "responsesReceivedState",
+                                        PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO));
+        numResponsesReceivedState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>("numResponsesReceivedState", Types.INT));
+        numResponsesReceived =
+                OperatorStateUtils.getUniqueElement(
+                                numResponsesReceivedState, "numResponsesReceived")
+                        .orElse(0);
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        responsesReceived.clear();
+        if (numResponsesReceived > 0) {
+            numResponsesReceivedState.clear();
+            numResponsesReceivedState.add(numResponsesReceived);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java
new file mode 100644
index 000000000..99f8afab0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.ps.message.Message; +import org.apache.flink.ml.common.ps.message.MessageType; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; + +/** ServerAgent resides on each worker. It serves as an agent for workers to talk with servers. */ +public class ServerAgent { + /** Index of the worker that this agent resides on. */ + private final int workerId; + /** Number of servers to talk to. */ + private int numServers; + /** Key ranges of each server. */ + private long[] ranges; + /** The collector on this worker. */ + private final Output>> output; + + ServerAgent(int workerId, Output>> output) { + this.workerId = workerId; + this.output = output; + } + + void open(int numServers, long maxKey) { + this.numServers = numServers; + this.ranges = new long[numServers + 1]; + long shardSize = (maxKey + 1) / numServers; + + for (int serverId = 0; serverId < numServers; serverId++) { + ranges[serverId] = shardSize * serverId; + } + ranges[numServers] = maxKey + 1; + } + + /** Sends a request to servers to initialize key range on each server. */ + void initialize() { + for (int serverId = 0; serverId < numServers; serverId++) { + long start = ranges[serverId]; + long end = ranges[serverId + 1]; + Message message = + new Message( + serverId, + workerId, + MessageType.INITIALIZE, + new long[] {start, end}, + new double[0]); + output.collect(new StreamRecord<>(Tuple2.of(serverId, message.bytes))); + } + } + + /** Pushes a key-value arrays to servers. */ + void push(long[] indices, double[] values) { + Iterator> requests = sliceRequest(indices, values); + while (requests.hasNext()) { + Tuple3 request = requests.next(); + Message message = + new Message(request.f0, workerId, MessageType.PUSH, request.f1, request.f2); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, message.bytes))); + } + } + + /** Pulls the values from servers with the specified indices. */ + void pull(long[] indices) { + Iterator> requests = sliceRequest(indices, null); + while (requests.hasNext()) { + Tuple3 request = requests.next(); + Message message = + new Message(request.f0, workerId, MessageType.PULL, request.f1, new double[0]); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, message.bytes))); + } + } + + /** + * Pushes the values to servers to apply all reduce operation. + * + *

Note that the values pushed by this function are not going to update the model, but just + * perform an all reduce operation. + */ + void allReduce(V[] values, TypeSerializer typeSerializer) throws IOException { + int messageSize = values.length / numServers + 1; + for (int serverId = 0; serverId < numServers; serverId++) { + int s = Math.min(serverId * messageSize, values.length); + int e = Math.min(s + messageSize, values.length); + V[] segment = Arrays.copyOfRange(values, s, e); + Message message = + new Message( + workerId, + serverId, + MessageType.ALL_REDUCE, + new long[0], + segment, + typeSerializer); + output.collect(new StreamRecord<>(Tuple2.of(serverId, message.bytes))); + } + } + + /** + * Splits the push/pull request according to the given sorted indices and the corresponding + * values. + * + * @param indices sorted indices of push/pull request. + * @param values the push values if not null. + * @return the split requests for each server. + */ + private Iterator> sliceRequest( + long[] indices, @Nullable double[] values) { + return new RequestsIterator(numServers, indices, values, ranges); + } + + private static class RequestsIterator implements Iterator> { + private final int numServers; + private final long[] indices; + private final double[] values; + /** + * Number of values per key. If the model data is a vector, numValuesPerKey is one. If the + * model data is a matrix, numValuesPerKey is the number of columns. + */ + private final int numValuesPerKey; + + private final long[] ranges; + + private int serverId = 0; + + private int s = 0; + + public RequestsIterator( + int numServers, long[] indices, @Nullable double[] values, long[] ranges) { + this.numServers = numServers; + this.indices = indices; + this.values = values; + this.ranges = ranges; + if (indices.length != 0 && values != null) { + numValuesPerKey = values.length / indices.length; + Preconditions.checkArgument( + numValuesPerKey * indices.length == values.length, + String.format( + "The size of values [%d] cannot be divided by size of keys [%d].", + values.length, indices.length)); + } else { + numValuesPerKey = 1; + } + } + + @Override + public boolean hasNext() { + return serverId < numServers; + } + + @Override + public Tuple3 next() { + int e = s; + while (e < indices.length && indices[e] < ranges[serverId + 1]) { + e++; + } + + long[] splitIndices = new long[0]; + double[] splitValues = values == null ? null : new double[0]; + if (s < e) { + splitIndices = Arrays.copyOfRange(indices, s, e); + splitValues = + values == null + ? null + : Arrays.copyOfRange( + values, s * numValuesPerKey, e * numValuesPerKey); + } + s = e; + serverId++; + return Tuple3.of(serverId - 1, splitIndices, splitValues); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java new file mode 100644 index 000000000..1d557adb4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.Message; +import org.apache.flink.ml.common.ps.message.MessageType; +import org.apache.flink.ml.common.ps.training.AllReduceStage; +import org.apache.flink.ml.common.ps.training.IterationStage; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * The server operator maintains the shared parameters. The shared parameters can be modeled as a + * collection of {key:value} pairs. By default, the keys are evenly distributed across servers + * through range partitioning. For example, if there are two servers and the keys are {1,2,3,4,5,6}, + * then server-0 maintains keys {1,2,3} and server-1 maintains keys {4,5,6}. + * + *
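The range partitioning sketched in the preceding example is computed the same way on the worker side (see `ServerAgent#open` earlier in this patch): the key space `[0, maxKey]` is cut into `numServers` contiguous shards, and a key belongs to the server whose half-open range contains it. A small standalone illustration of that routing rule; the class name is hypothetical, while the shard arithmetic mirrors `ServerAgent#open`:

```java
import java.util.Arrays;

// Illustrative sketch: keys in [ranges[s], ranges[s + 1]) are owned by server s.
public class RangePartitionSketch {
    public static void main(String[] args) {
        int numServers = 2;
        long maxKey = 6;

        long[] ranges = new long[numServers + 1];
        long shardSize = (maxKey + 1) / numServers;
        for (int serverId = 0; serverId < numServers; serverId++) {
            ranges[serverId] = shardSize * serverId;
        }
        // The last shard absorbs the remainder of the integer division.
        ranges[numServers] = maxKey + 1;

        System.out.println("ranges = " + Arrays.toString(ranges)); // [0, 3, 7]
        for (long key = 0; key <= maxKey; key++) {
            int server = 0;
            while (key >= ranges[server + 1]) {
                server++;
            }
            System.out.println("key " + key + " -> server " + server);
        }
    }
}
```

Note that with this arithmetic the split is only approximately even (here server-0 owns [0, 3) and server-1 owns [3, 7)), so the {1,2,3} / {4,5,6} split above should be read as an idealized example.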

The server receives push/pull/allreduce requests from {@link WorkerOperator} and sends the + * answer request to {@link ResponseAssemblerOperator}. It works closely with {@link ModelUpdater} + * in the following way: + * + *

+ * <ul>
+ *   <li>The server operator deals with the message from workers and decides when to process the
+ *       received message.
+ *   <li>The server operator calls {@link ModelUpdater#update(long[], double[])} and {@link
+ *       ModelUpdater#get(long[])} to process the messages in detail.
+ *   <li>The server operator triggers checkpoint for {@link ModelUpdater}.
+ *   <li>The server operator outputs the final output parameters by calling {@link
+ *       ModelUpdater#getModelSegments()}.
+ * </ul>
+ *
Moreover, it accepts all-reduce requests from workers and returns the reduced result to all
+ * workers. Note that the input of the all-reduce operation is not going to be used in {@link
+ * ModelUpdater}.
+ *
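As a rough illustration of the server-side reduction described here: the value arrays arriving from each worker are folded pairwise with the stage's `ReduceFunction` (see `processAllReduceRequest` below) until one reduced array remains, which is then sent back to every worker. A standalone sketch of the folding step, with illustrative names; the real path also goes through `Message` (de)serialization:

```java
import org.apache.flink.api.common.functions.ReduceFunction;

// Illustrative sketch: folding per-worker value arrays into one reduced result.
public class AllReduceFoldSketch {
    public static void main(String[] args) throws Exception {
        // Element-wise sum, as a typical all-reduce reducer.
        ReduceFunction<double[]> sum =
                (a, b) -> {
                    double[] out = new double[a.length];
                    for (int i = 0; i < a.length; i++) {
                        out[i] = a[i] + b[i];
                    }
                    return out;
                };

        double[][] fromWorkers = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
        double[] reduced = null;
        for (double[] received : fromWorkers) {
            reduced = (reduced == null) ? received : sum.reduce(received, reduced);
        }
        // reduced == {9.0, 12.0}; this is what would be broadcast back to all workers.
        System.out.println(reduced[0] + ", " + reduced[1]);
    }
}
```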

TODO: Add support for asynchronous operations on servers. + * + * @param output format of model data. + */ +public class ServerOperator extends AbstractStreamOperator> + implements OneInputStreamOperator, Tuple2>, + IterationListener> { + /** The iterationStage list that asks responses from servers. */ + private final List stageList; + /** Number of workers to communicate with. */ + private final int numWorkers; + /** The logic to answer push/pull request from workers. */ + private final ModelUpdater modelUpdater; + /** Format of model data. */ + private final OutputTag modelOutputTag; + /** Index of the server task. */ + private int serverId = -1; + /** + * Thread pool to answer push/pull requests, to decouple the network traffic and computation + * logic. + */ + private transient ExecutorService singleThreadExecutor; + /** The future objects of thread calls in one epoch. */ + private final List> futuresInEpoch = new ArrayList<>(); + /** The merger for push requests. */ + private final PushRequestMerger pushRequestMerger; + /** The pending pull requests. */ + private ListState pendingPulls; + + /** The pending allreduce requests. */ + private ListState pendingAllReduces; + + public ServerOperator( + IterationStageList iterationStageList, + int numWorkers, + ModelUpdater modelUpdater, + OutputTag modelOutputTag) { + this.stageList = new ArrayList<>(); + for (IterationStage stage : iterationStageList.stageList) { + if (stage instanceof PullStage || stage instanceof AllReduceStage) { + stageList.add(stage); + } + } + this.numWorkers = numWorkers; + this.modelUpdater = modelUpdater; + this.modelOutputTag = modelOutputTag; + this.pushRequestMerger = new PushRequestMerger(); + } + + @Override + public void open() throws Exception { + super.open(); + this.serverId = getRuntimeContext().getIndexOfThisSubtask(); + this.singleThreadExecutor = Executors.newSingleThreadExecutor(); + } + + @Override + public void processElement(StreamRecord> element) throws Exception { + byte[] request = element.getValue().f1; + Message message = new Message(element.getValue().f1); + MessageType type = message.getMessageType(); + switch (type) { + case INITIALIZE: + long[] indices = message.getKeys(); + Preconditions.checkState(serverId == message.getServerId() && indices.length == 2); + if (message.getWorkerId() == 0) { + modelUpdater.open(indices[0], indices[1]); + } + break; + case PUSH: + futuresInEpoch.add( + singleThreadExecutor.submit( + () -> pushRequestMerger.processPushRequest(message))); + break; + case PULL: + pendingPulls.add(request); + break; + case ALL_REDUCE: + pendingAllReduces.add(request); + break; + default: + throw new UnsupportedOperationException("Unsupported message type: " + type + "."); + } + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector> collector) + throws Exception { + // Waits until all pushes have been processed. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + + // Processes the pushes first. + Tuple2 kvs = pushRequestMerger.toKvArrays(); + pushRequestMerger.accumulatedKvsForMatrix.clear(); + pushRequestMerger.accumulatedKvsForVector.clear(); + if (kvs.f0.length > 0) { + // There are pushes at this epoch. + modelUpdater.update(kvs.f0, kvs.f1); + } + + Iterator pullsIterator = pendingPulls.get().iterator(); + if (pullsIterator.hasNext()) { + // This is a pull stage. 
+ while (pullsIterator.hasNext()) { + byte[] pull = pullsIterator.next(); + futuresInEpoch.add( + singleThreadExecutor.submit(() -> processPullRequest(new Message(pull)))); + } + } + Iterator allreduceIterator = pendingAllReduces.get().iterator(); + if (allreduceIterator.hasNext()) { + int stageId = epochWatermark % stageList.size(); + AllReduceStage allReduceStage = (AllReduceStage) stageList.get(stageId); + Message reducedResult = processAllReduceRequest(allReduceStage, allreduceIterator); + for (int workerId = 0; workerId < numWorkers; workerId++) { + reducedResult.setWorkerId(workerId); + output.collect(new StreamRecord<>(Tuple2.of(workerId, reducedResult.bytes))); + } + } + + for (Future future : futuresInEpoch) { + future.get(); + } + pendingPulls.clear(); + pendingAllReduces.clear(); + futuresInEpoch.clear(); + } + + private Message processAllReduceRequest(AllReduceStage stage, Iterator requests) + throws Exception { + ReduceFunction reduceFunction = stage.reducer; + V[] reducedResult = null; + while (requests.hasNext()) { + byte[] allreduceRequest = requests.next(); + Message message = new Message(allreduceRequest); + V[] receivedResult = message.getValues(stage.typeSerializer); + if (reducedResult == null) { + reducedResult = receivedResult; + } else { + reducedResult = reduceFunction.reduce(receivedResult, reducedResult); + } + } + + return new Message( + -1, -1, MessageType.ALL_REDUCE, new long[0], reducedResult, stage.typeSerializer); + } + + @Override + public void onIterationTerminated( + Context context, Collector> collector) { + Iterator modelSegments = modelUpdater.getModelSegments(); + while (modelSegments.hasNext()) { + MT modelSegment = modelSegments.next(); + output.collect(modelOutputTag, new StreamRecord<>(modelSegment)); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + pendingPulls = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pendingPulls", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + pendingAllReduces = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pendingAllReduces", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + modelUpdater.initializeState(context); + pushRequestMerger.initializeState(context); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + + // Waits until the futures to finish. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + modelUpdater.snapshotState(context); + pushRequestMerger.snapshotState(context); + } + + private Object processPullRequest(Message message) { + Preconditions.checkState(serverId == message.getServerId()); + int workerId = message.getWorkerId(); + double[] pulledValues = modelUpdater.get(message.getKeys()); + Message pulledMessage = + new Message(serverId, workerId, MessageType.PULL, new long[0], pulledValues); + StreamRecord> record = + new StreamRecord<>(Tuple2.of(workerId, pulledMessage.bytes)); + + output.collect(record); + return new Object(); + } + + /** Utility class to merge the push request from different workers. */ + private static class PushRequestMerger implements Serializable { + /** + * The accumulated kv if the push request is for a vector. If the value is a double, we use + * {@link Long2DoubleOpenHashMap} for better efficiency. 
+ */ + private final Long2DoubleOpenHashMap accumulatedKvsForVector; + /** The accumulated kv if the push request is for a matrix. */ + private final Map accumulatedKvsForMatrix; + /** The state for accumulated kv. */ + private ListState accumulatedKvsState; + + public PushRequestMerger() { + this.accumulatedKvsForVector = new Long2DoubleOpenHashMap(); + this.accumulatedKvsForMatrix = new HashMap<>(); + } + + private Object processPushRequest(Message message) { + long[] keys = message.getKeys(); + double[] values = message.getValuesInDoubleArray(); + + if (values.length == keys.length) { + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForVector.merge(keys[i], values[i], Double::sum); + } + } else { + int valuesPerKey = values.length / keys.length; + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForMatrix.putIfAbsent(keys[i], new double[valuesPerKey]); + double[] partialValue = accumulatedKvsForMatrix.get(keys[i]); + for (int j = 0; j < valuesPerKey; j++) { + partialValue[j] += values[i * valuesPerKey + j]; + } + } + } + return new Object(); + } + + /** Transforms the processed push request to kv arrays. */ + private Tuple2 toKvArrays() { + long[] indices = new long[0]; + double[] values = new double[0]; + if (accumulatedKvsForVector.size() != 0) { + indices = new long[accumulatedKvsForVector.size()]; + values = new double[indices.length]; + + int idx = 0; + for (Map.Entry entry : accumulatedKvsForVector.entrySet()) { + indices[idx] = entry.getKey(); + values[idx] = entry.getValue(); + idx++; + } + } else if (accumulatedKvsForMatrix.size() != 0) { + indices = new long[accumulatedKvsForMatrix.size()]; + int numValuesPerKey = + accumulatedKvsForMatrix.entrySet().iterator().next().getValue().length; + values = new double[indices.length * numValuesPerKey]; + int idx = 0; + for (Map.Entry entry : accumulatedKvsForMatrix.entrySet()) { + indices[idx] = entry.getKey(); + System.arraycopy( + entry.getValue(), 0, values, idx * numValuesPerKey, numValuesPerKey); + idx++; + } + } + return Tuple2.of(indices, values); + } + + private void initializeState(StateInitializationContext context) throws Exception { + accumulatedKvsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accumulatedKvs", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + byte[] accumulatedKvsInBytes = + OperatorStateUtils.getUniqueElement(accumulatedKvsState, "accumulatedKvs") + .orElse(null); + + if (accumulatedKvsInBytes != null) { + Tuple2 kvs = Bits.getLongDoubleArray(accumulatedKvsInBytes, 0); + long[] keys = kvs.f0; + double[] values = kvs.f1; + int numValuesPerKey = values.length / keys.length; + if (numValuesPerKey == 1) { + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForVector.put(keys[i], values[i]); + } + } else { + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForMatrix.put( + keys[i], + Arrays.copyOfRange( + values, + i * numValuesPerKey, + i * numValuesPerKey + numValuesPerKey)); + } + } + } + } + + private void snapshotState(StateSnapshotContext context) throws Exception { + Tuple2 kvs = toKvArrays(); + accumulatedKvsState.clear(); + if (kvs.f0.length > 0) { + byte[] bytes = new byte[Bits.getLongDoubleArraySizeInBytes(kvs)]; + Bits.putLongDoubleArray(kvs, bytes, 0); + accumulatedKvsState.add(bytes); + } + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java new file mode 100644 index 000000000..54855c961 
--- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.Message; +import org.apache.flink.ml.common.ps.training.AllReduceStage; +import org.apache.flink.ml.common.ps.training.IterationStage; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.MLSession; +import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.ProxySideOutput; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +import java.io.IOException; +import java.util.Iterator; + +/** + * The worker operator that executes the machine learning training process following {@link + * IterationStageList}. + * + *

In detail, the worker operator is responsible for the following: + * + *

+ * <ul>
+ *   <li>Caches the training data.
+ *   <li>Initializes the {@link MLSession}.
+ *   <li>Splits the {@link IterationStageList} by {@link PullStage} and {@link AllReduceStage} into
+ *       multiple sequences and maps them into flink-ml-iterations.
+ *   <li>Executes the process function in each {@link ProcessStage}.
+ *   <li>Executes the push/pull requests in {@link PushStage} and {@link PullStage} and talks to
+ *       servers, by reading/writing {@link MLSession}.
+ * </ul>
+ */ +public class WorkerOperator + extends AbstractStreamOperator> + implements TwoInputStreamOperator>, + IterationListener> { + /** Number of servers that this worker needs to talk to. */ + private final int numServers; + + /** The user defined iteration logic. */ + private final IterationStageList iterationStages; + + /** The agent for each worker to talk with servers. */ + private ServerAgent serverAgent; + + /** + * Iteration id in terms of {@link IterationStageList}. When we finished processing all stages + * in stageList, the iteration id increments by one. + */ + private int iterationId; + + private ListState iterationIdState; + + /** The id of the stages to execute in iterationStages. */ + private int nextStageToExecute = 0; + + private ListState nextStageToExecuteState; + + /** The cached training data. */ + private ListStateWithCache
trainDataState;
+
+    /** The feedback array from iterations. */
+    private byte[] feedback;
+
+    private ListState<byte[]> feedbackState;
+
+    /** Dimension of the model. */
+    private long modelDim = 0;
+
+    private ListState<Long> modelDimState;
+
+    public WorkerOperator(IterationStageList<SessionT> iterationStages, int numServers) {
+        this.iterationStages = iterationStages;
+        this.numServers = numServers;
+    }
+
+    @Override
+    public void open() {
+        int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+        int workerId = getRuntimeContext().getIndexOfThisSubtask();
+        this.serverAgent = new ServerAgent(workerId, output);
+        iterationStages.session.setWorldInfo(workerId, numTasks);
+        iterationStages.session.setOutput(new ProxySideOutput(output));
+    }
+
+    @Override
+    public void onEpochWatermarkIncremented(
+            int epochWatermark, Context context, Collector<Tuple2<Integer, byte[]>> collector)
+            throws Exception {
+        if (epochWatermark == 0) {
+            modelDim = Bits.getLong(feedback, 0);
+            serverAgent.open(numServers, modelDim - 1);
+            serverAgent.initialize();
+            iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState));
+            nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages);
+        }
+    }
+
+    @Override
+    public void onIterationTerminated(
+            Context context, Collector<Tuple2<Integer, byte[]>> collector) {
+        trainDataState.clear();
+    }
+
+    @Override
+    public void processElement1(StreamRecord<DT> streamRecord) throws Exception {
+        trainDataState.add(streamRecord.getValue());
+    }
+
+    @Override
+    public void processElement2(StreamRecord<byte[]> streamRecord) throws Exception {
+        feedback = streamRecord.getValue();
+        if (modelDim > 0) {
+            Message message = new Message(streamRecord.getValue());
+            IterationStage stage = iterationStages.stageList.get(nextStageToExecute);
+            if (stage instanceof PullStage) {
+                PullStage pullStage = (PullStage) stage;
+                pullStage.valuesConsumer.accept(message.getValuesInDoubleArray());
+            } else if (stage instanceof AllReduceStage) {
+                AllReduceStage<?> allReduceStage = (AllReduceStage<?>) stage;
+                processAllReduceStage(allReduceStage, message);
+            } else {
+                throw new IllegalStateException(
+                        String.format("Illegal stage type: %s", stage.getClass().getSimpleName()));
+            }
+
+            nextStageToExecute++;
+            nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages);
+        }
+    }
+
+    private <V> void processAllReduceStage(AllReduceStage<V> stage, Message message)
+            throws IOException {
+        V[] reducedResult = message.getValues(stage.typeSerializer);
+        stage.valuesConsumer.accept(reducedResult);
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        super.initializeState(context);
+        feedbackState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "feedbackState",
+                                        PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO));
+        OperatorStateUtils.getUniqueElement(feedbackState, "feedbackState")
+                .ifPresent(x -> feedback = x);
+
+        trainDataState =
+                new ListStateWithCache<>(
+                        (getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader())),
+                        getContainingTask(),
+                        getRuntimeContext(),
+                        context,
+                        config.getOperatorID());
+
+        nextStageToExecuteState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>("nextStageToExecuteState", Types.INT));
+        nextStageToExecute =
+                OperatorStateUtils.getUniqueElement(
+                                nextStageToExecuteState, "nextStageToExecuteState")
+                        .orElse(0);
+
+        modelDimState =
+                context.getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("modelDimState", Types.LONG));
+        modelDim = OperatorStateUtils.getUniqueElement(modelDimState, "modelDimState").orElse(0L);
+
+        iterationIdState =
+                context.getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("iterationIdState", Types.INT));
+        iterationId =
+                OperatorStateUtils.getUniqueElement(iterationIdState, "iterationIdState").orElse(0);
+
+        if (modelDim > 0) {
+            serverAgent.open(numServers, modelDim - 1);
+        }
+
+        iterationStages.session.initializeState(context);
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        feedbackState.clear();
+        if (feedback != null) {
+            feedbackState.add(feedback);
+        }
+
+        nextStageToExecuteState.clear();
+        nextStageToExecuteState.add(nextStageToExecute);
+        modelDimState.clear();
+        modelDimState.add(modelDim);
+        iterationIdState.clear();
+        iterationIdState.add(iterationId);
+
+        trainDataState.snapshotState(context);
+        iterationStages.session.snapshotState(context);
+    }
+
+    /**
+     * Processes the stages described in the given iterationStages from the given nextStage id.
+     * This function processes the stages until it meets a {@link PullStage} or an {@link
+     * AllReduceStage}.
+     *
+     * @param nextStageToExecute id of the next stage to execute in the given iteration stages.
+     * @param iterationStages iteration stages used to describe the training logic.
+     * @return the id of the next stage to execute.
+     */
+    @SuppressWarnings("unchecked")
+    private int processIterationStages(
+            int nextStageToExecute, IterationStageList<SessionT> iterationStages)
+            throws Exception {
+        while (true) {
+            if (nextStageToExecute >= iterationStages.stageList.size()) {
+                iterationId++;
+                iterationStages.session.setIterationId(iterationId);
+                if (iterationStages.shouldTerminate.apply(iterationStages.session)) {
+                    return -1;
+                }
+                nextStageToExecute -= iterationStages.stageList.size();
+            }
+            IterationStage stage = iterationStages.stageList.get(nextStageToExecute);
+
+            // We are not incrementing nextStageToExecute for PullStage and AllReduceStage, since
+            // we will need to receive values from servers.
+            if (stage instanceof PullStage) {
+                PullStage pullStage = ((PullStage) stage);
+                serverAgent.pull(pullStage.keysSupplier.get());
+                return nextStageToExecute;
+
+            } else if (stage instanceof AllReduceStage) {
+                AllReduceStage<Object> allReduceStage = (AllReduceStage<Object>) stage;
+                serverAgent.allReduce(
+                        allReduceStage.valuesSupplier.get(), allReduceStage.typeSerializer);
+                return nextStageToExecute;
+
+            } else if (stage instanceof PushStage) {
+                PushStage pushStage = (PushStage) stage;
+                serverAgent.push(pushStage.keysSupplier.get(), pushStage.valuesSupplier.get());
+                nextStageToExecute++;
+
+            } else if (stage instanceof ProcessStage) {
+                ((ProcessStage<SessionT>) stage).process(iterationStages.session);
+                nextStageToExecute++;
+
+            } else {
+                throw new IllegalStateException(
+                        "Illegal type of IterationStage: " + stage.getClass().getSimpleName());
+            }
+        }
+    }
+
+    /** A resettable iterator for {@link ListStateWithCache}. */
+    private static class ResettableTrainDataIterator<T> implements ResettableIterator<T> {
+        private final ListStateWithCache<T> data;
+        private Iterator<T> dataIterator;
+
+        public ResettableTrainDataIterator(ListStateWithCache<T> data) throws Exception {
+            this.data = data;
+            this.dataIterator = data.get().iterator();
+        }
+
+        @Override
+        public void reset() {
+            try {
+                this.dataIterator = data.get().iterator();
+            } catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        @Override
+        public boolean hasNext() {
+            return dataIterator.hasNext();
+        }
+
+        @Override
+        public T next() {
+            return dataIterator.next();
+        }
+    }
+}
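For orientation, the epoch-0 handshake above is just a `long` model dimension packed at offset 0 of a byte array. A minimal round-trip sketch using the `Bits` utility (the values are illustrative):

```java
import org.apache.flink.ml.util.Bits;

long maxKey = 9L;
byte[] buffer = new byte[Long.BYTES];
// What the training driver broadcasts as the first feedback: maxKey + 1.
Bits.putLong(buffer, 0, maxKey + 1);
// What WorkerOperator reads back in epoch 0 before opening its ServerAgent.
long modelDim = Bits.getLong(buffer, 0); // 10
```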
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java
new file mode 100644
index 000000000..67997a7f4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.util.Bits;
+import org.apache.flink.util.Preconditions;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * {@link Message} is responsible for encoding all messages exchanged between a worker and a
+ * server. The message format follows this structure:
+ *
+ * <pre>workerId serverId messageType keyLength keys valuesLength values</pre>
+ *
+ * <p>where the message fields include the worker ID, server ID, message type, length of the keys,
+ * the keys themselves, length of the values, and the values.
+ */
+public class Message {
+    private static final int WORKER_ID_OFFSET = 0;
+    private static final int SERVER_ID_OFFSET = Integer.BYTES;
+    private static final int MESSAGE_TYPE_OFFSET = Integer.BYTES + SERVER_ID_OFFSET;
+    private static final int KVS_OFFSET = Integer.BYTES + MESSAGE_TYPE_OFFSET;
+
+    /** The storage of the message in bytes. */
+    public final byte[] bytes;
+
+    /** Constructs a message instance from the bytes. */
+    public Message(byte[] bytes) {
+        this.bytes = bytes;
+    }
+
+    /** Constructs a message instance from long keys and double values. */
+    public Message(
+            int serverId, int workerId, MessageType messageType, long[] keys, double[] values) {
+        int sizeInBytes = KVS_OFFSET + Bits.getLongDoubleArraySizeInBytes(Tuple2.of(keys, values));
+        bytes = new byte[sizeInBytes];
+        Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
+        Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
+        Bits.putInt(bytes, MESSAGE_TYPE_OFFSET, messageType.type);
+        Bits.putLongDoubleArray(Tuple2.of(keys, values), bytes, KVS_OFFSET);
+    }
+
+    /** Constructs a message instance from long keys and generic values. */
+    public <V> Message(
+            int serverId,
+            int workerId,
+            MessageType messageType,
+            long[] keys,
+            V[] values,
+            TypeSerializer<V> serializer)
+            throws IOException {
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                new DataOutputViewStreamWrapper(byteArrayOutputStream);
+        dataOutputViewStreamWrapper.writeInt(workerId);
+        dataOutputViewStreamWrapper.writeInt(serverId);
+        dataOutputViewStreamWrapper.writeInt(messageType.type);
+
+        dataOutputViewStreamWrapper.writeInt(keys.length);
+        for (long key : keys) {
+            dataOutputViewStreamWrapper.writeLong(key);
+        }
+        dataOutputViewStreamWrapper.writeInt(values.length);
+        for (V value : values) {
+            serializer.serialize(value, dataOutputViewStreamWrapper);
+        }
+        bytes = byteArrayOutputStream.toByteArray();
+    }
+
+    /** Retrieves the keys. */
+    public long[] getKeys() {
+        return Bits.getLongArray(bytes, KVS_OFFSET);
+    }
+
+    /** Retrieves the values using the given serializer. */
+    public <V> V[] getValues(TypeSerializer<V> serializer) throws IOException {
+        int numIndices = Bits.getInt(bytes, KVS_OFFSET);
+        int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES;
+        int numValues = Bits.getInt(bytes, offset);
+        offset += Integer.BYTES;
+
+        // Since generics are erased at runtime, we use reflection to create the array.
+        V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues);
+        ByteArrayInputStream byteArrayInputStream =
+                new ByteArrayInputStream(bytes, offset, bytes.length - offset);
+        DataInputViewStreamWrapper dataInputViewStreamWrapper =
+                new DataInputViewStreamWrapper(byteArrayInputStream);
+        for (int i = 0; i < numValues; i++) {
+            result[i] = serializer.deserialize(dataInputViewStreamWrapper);
+        }
+        return result;
+    }
+
+    /**
+     * Retrieves the values in double array format.
+     *
+     * <p>Note that retrieving the double array via {@link Bits#getDoubleArray(byte[], int)} is up
+     * to 2.3x faster than {@link Message#getValues}.
+     */
+    public double[] getValuesInDoubleArray() {
+        int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES;
+        return Bits.getDoubleArray(bytes, offset);
+    }
+
+    /** Retrieves the worker id. */
+    public int getWorkerId() {
+        return Bits.getInt(bytes, WORKER_ID_OFFSET);
+    }
+
+    /** Sets the worker id. */
+    public void setWorkerId(int workerId) {
+        Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
+    }
+
+    /** Retrieves the server id. */
+    public int getServerId() {
+        return Bits.getInt(bytes, SERVER_ID_OFFSET);
+    }
+
+    /** Sets the server id. */
+    public void setServerId(int serverId) {
+        Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
+    }
+
+    /** Retrieves the message type. */
+    public MessageType getMessageType() {
+        return MessageType.valueOf(Bits.getInt(bytes, MESSAGE_TYPE_OFFSET));
+    }
+
+    /**
+     * Assembles the received messages from servers according to the server id. Note that these
+     * messages should come from the same request.
+     */
+    public static Message assembleMessages(Iterator<byte[]> messageIterator) {
+        List<Message> messages = new ArrayList<>();
+        while (messageIterator.hasNext()) {
+            messages.add(new Message(messageIterator.next()));
+        }
+        messages.sort(Comparator.comparingInt(Message::getServerId));
+
+        int numMessages = messages.size();
+        int numKeys = 0, numValues = 0;
+        int numAssembledBytes = 0;
+        int workerId = -1;
+        for (Message message : messages) {
+            Preconditions.checkState(workerId == -1 || workerId == message.getWorkerId());
+            workerId = message.getWorkerId();
+            numKeys += message.getNumKeys();
+            numValues += message.getNumValues();
+            numAssembledBytes += message.bytes.length;
+        }
+        numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2);
+        byte[] assembledBytes = new byte[numAssembledBytes];
+        int keysOffset = KVS_OFFSET;
+        Bits.putInt(assembledBytes, keysOffset, numKeys);
+        keysOffset += Integer.BYTES;
+        int valuesOffset = keysOffset + numKeys * Long.BYTES;
+        Bits.putInt(assembledBytes, valuesOffset, numValues);
+        valuesOffset += Integer.BYTES;
+
+        for (Message message : messages) {
+            Tuple2<Integer, Integer> keysOffsetAndLength = message.getKeysOffsetAndLength();
+            System.arraycopy(
+                    message.bytes,
+                    keysOffsetAndLength.f0,
+                    assembledBytes,
+                    keysOffset,
+                    keysOffsetAndLength.f1);
+            keysOffset += keysOffsetAndLength.f1;
+            Tuple2<Integer, Integer> valuesOffsetAndLength = message.getValuesOffsetAndLength();
+            System.arraycopy(
+                    message.bytes,
+                    valuesOffsetAndLength.f0,
+                    assembledBytes,
+                    valuesOffset,
+                    valuesOffsetAndLength.f1);
+            valuesOffset += valuesOffsetAndLength.f1;
+        }
+
+        Message message = new Message(assembledBytes);
+        message.setServerId(-1);
+        message.setWorkerId(workerId);
+        return message;
+    }
+
+    private Tuple2<Integer, Integer> getKeysOffsetAndLength() {
+        int start = KVS_OFFSET + Integer.BYTES;
+        int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES;
+        return Tuple2.of(start, numBytes);
+    }
+
+    private Tuple2<Integer, Integer> getValuesOffsetAndLength() {
+        int start =
+                Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES
+                        + KVS_OFFSET
+                        + Integer.BYTES
+                        + Integer.BYTES;
+        return Tuple2.of(start, bytes.length - start);
+    }
+
+    private int getNumKeys() {
+        return Bits.getInt(bytes, KVS_OFFSET);
+    }
+
+    private int getNumValues() {
+        return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys());
+    }
+}
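As a sanity check of the layout above, a dense-payload message encodes and decodes symmetrically. A minimal sketch using only the API from this file (ids and payload are illustrative):

```java
import java.util.Arrays;

long[] keys = {0L, 3L, 7L};
double[] values = {0.1, 0.2, 0.3};

// Worker 0 pushes three (key, value) pairs to server 1.
Message message = new Message(/* serverId = */ 1, /* workerId = */ 0, MessageType.PUSH, keys, values);

// The fixed-width header and the key/value blocks read back unchanged.
assert message.getWorkerId() == 0;
assert message.getServerId() == 1;
assert message.getMessageType() == MessageType.PUSH;
assert Arrays.equals(message.getKeys(), keys);
assert Arrays.equals(message.getValuesInDoubleArray(), values);
```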
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java
new file mode 100644
index 000000000..de0e4f6fe
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+/** Message type between workers and servers. */
+public enum MessageType {
+    /** The initialization request. */
+    INITIALIZE(0),
+    /** The push request. */
+    PUSH(1),
+    /** The pull request. */
+    PULL(2),
+    /** The all-reduce request. */
+    ALL_REDUCE(3);
+
+    public final int type;
+
+    MessageType(int type) {
+        this.type = type;
+    }
+
+    public static MessageType valueOf(int value) {
+        switch (value) {
+            case 0:
+                return MessageType.INITIALIZE;
+            case 1:
+                return MessageType.PUSH;
+            case 2:
+                return MessageType.PULL;
+            case 3:
+                return MessageType.ALL_REDUCE;
+            default:
+                throw new UnsupportedOperationException("Unsupported message type: " + value);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java
new file mode 100644
index 000000000..aaabbe633
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+/** A communication stage that conducts an all-reduce on the given array. */
+public final class AllReduceStage<V> implements IterationStage {
+    public final Supplier<V[]> valuesSupplier;
+    public final Consumer<V[]> valuesConsumer;
+    public final ReduceFunction<V[]> reducer;
+    public final TypeSerializer<V> typeSerializer;
+
+    public AllReduceStage(
+            Supplier<V[]> valuesSupplier,
+            Consumer<V[]> valuesConsumer,
+            ReduceFunction<V[]> reducer,
+            TypeSerializer<V> typeSerializer) {
+        this.valuesSupplier = valuesSupplier;
+        this.valuesConsumer = valuesConsumer;
+        this.reducer = reducer;
+        this.typeSerializer = typeSerializer;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java
new file mode 100644
index 000000000..45f175dea
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.SparseLongDoubleVector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * An iteration stage that uses the pulled model values and batch data to compute the gradients.
+ */
+public class ComputeGradients extends ProcessStage<MiniBatchMLSession<LabeledPointWithWeight>> {
+    private final LossFunc lossFunc;
+
+    public ComputeGradients(LossFunc lossFunc) {
+        this.lossFunc = lossFunc;
+    }
+
+    @Override
+    public void process(MiniBatchMLSession<LabeledPointWithWeight> session) throws IOException {
+        long[] indices = ComputeIndices.getSortedIndices(session.batchData);
+        double[] modelValues = session.pulledValues;
+        double[] gradients =
+                computeGradient(
+                        session.batchData, Tuple2.of(indices, modelValues), session.numWorkers);
+
+        session.pushIndices = indices;
+        session.pushValues = gradients;
+    }
+
+    private double[] computeGradient(
+            List<LabeledPointWithWeight> batchData,
+            Tuple2<long[], double[]> modelData,
+            int numWorkers) {
+        long[] modelIndices = modelData.f0;
+        double[] modelValues = modelData.f1;
+        Long2DoubleOpenHashMap modelInMap = new Long2DoubleOpenHashMap(modelIndices.length);
+        for (int i = 0; i < modelIndices.length; i++) {
+            modelInMap.put(modelIndices[i], modelValues[i]);
+        }
+        Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(modelIndices.length);
+
+        for (LabeledPointWithWeight dataPoint : batchData) {
+            SparseLongDoubleVector feature = (SparseLongDoubleVector) dataPoint.features;
+            double dot = dot(feature, modelInMap);
+            double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight;
+
+            long[] featureIndices = feature.indices;
+            double[] featureValues = feature.values;
+            double z;
+            for (int i = 0; i < featureIndices.length; i++) {
+                long currentIndex = featureIndices[i];
+                z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.);
+                cumGradients.put(currentIndex, z);
+            }
+        }
+        double[] cumGradientValues = new double[modelIndices.length];
+        for (int i = 0; i < modelIndices.length; i++) {
+            cumGradientValues[i] = cumGradients.get(modelIndices[i]);
+        }
+        BLAS.scal(1.0 / batchData.size() / numWorkers, Vectors.dense(cumGradientValues));
+        return cumGradientValues;
+    }
+
+    private static double dot(SparseLongDoubleVector feature, Long2DoubleOpenHashMap coefficient) {
+        double dot = 0;
+        for (int i = 0; i < feature.indices.length; i++) {
+            dot += feature.values[i] * coefficient.get(feature.indices[i]);
+        }
+        return dot;
+    }
+}
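The inner loop of `computeGradient` accumulates per-key contributions into a hash map. The same pattern in isolation, with made-up numbers (`addTo` is equivalent to the getOrDefault-then-put idiom used above):

```java
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;

Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap();
long[] featureIndices = {2L, 7L};
double[] featureValues = {1.0, 0.5};
double multiplier = 0.3; // lossFunc.computeGradient(label, dot) * weight for one data point

for (int i = 0; i < featureIndices.length; i++) {
    cumGradients.addTo(featureIndices[i], featureValues[i] * multiplier);
}
// cumGradients now holds {2 -> 0.3, 7 -> 0.15}.
```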
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java
new file mode 100644
index 000000000..d624887a6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.SparseLongDoubleVector;
+
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * An iteration stage that samples a batch of training data and computes the indices needed to
+ * compute gradients.
+ */
+public class ComputeIndices extends ProcessStage<MiniBatchMLSession<LabeledPointWithWeight>> {
+
+    @Override
+    public void process(MiniBatchMLSession<LabeledPointWithWeight> context) throws Exception {
+        context.readInNextBatchData();
+        context.pullIndices = getSortedIndices(context.batchData);
+    }
+
+    public static long[] getSortedIndices(List<LabeledPointWithWeight> dataPoints) {
+        LongOpenHashSet indices = new LongOpenHashSet();
+        for (LabeledPointWithWeight dataPoint : dataPoints) {
+            SparseLongDoubleVector feature = (SparseLongDoubleVector) dataPoint.features;
+            long[] nonZeros = feature.indices;
+            for (long index : nonZeros) {
+                indices.add(index);
+            }
+        }
+
+        long[] sortedIndices = new long[indices.size()];
+        Iterator<Long> iterator = indices.iterator();
+        int i = 0;
+        while (iterator.hasNext()) {
+            sortedIndices[i++] = iterator.next();
+        }
+        Arrays.sort(sortedIndices);
+        return sortedIndices;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java
new file mode 100644
index 000000000..4db772c25
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import java.io.Serializable;
+
+/**
+ * Iterative machine learning training usually involves a local computation step (e.g., computing
+ * gradients) and a global communication step (e.g., an all-reduce, or parameter servers that
+ * aggregate the updates from workers).
+ *
+ * <p>To describe such a training process, we model it as a sequence of iteration stages. An
+ * iteration stage is either a local computation or a global communication.
+ */
+public interface IterationStage extends Serializable {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java
new file mode 100644
index 000000000..9c430d17d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.util.function.SerializableFunction;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * A list of iteration stages to express the logic of an iterative machine learning training
+ * process.
+ */
+public class IterationStageList<T extends MLSession> implements Serializable {
+    public final T session;
+    public Function<T, Boolean> shouldTerminate;
+    public List<IterationStage> stageList;
+
+    public IterationStageList(T session) {
+        this.stageList = new ArrayList<>();
+        this.session = session;
+    }
+
+    /** Sets the termination criterion of the training process. */
+    public IterationStageList<T> setTerminationCriteria(
+            SerializableFunction<T, Boolean> shouldTerminate) {
+        this.shouldTerminate = shouldTerminate;
+        return this;
+    }
+
+    /** Appends an iteration stage to the stage list. */
+    public IterationStageList<T> addStage(IterationStage stage) {
+        stageList.add(stage);
+        return this;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java
new file mode 100644
index 000000000..21a65d2c8
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.ml.common.ps.WorkerOperator;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.util.OutputTag;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Stores the session information that is alive during the training process on {@link
+ * WorkerOperator}. Note that the session information will be updated by each {@link
+ * IterationStage}.
+ *
+ * <p>Note that subclasses should take care of snapshotting an object stored in {@link MLSession}
+ * if the write of that object is followed by a {@link PullStage} or an {@link AllReduceStage} and
+ * the object is read again by later stages.
+ */
+public interface MLSession extends Serializable {
+    /** Sets the current iteration id. */
+    default void setIterationId(int iterationId) {}
+
+    /** Sets the worker id and the total number of workers. */
+    default void setWorldInfo(int workerId, int numWorkers) {}
+
+    /** Sets the training data. */
+    default void setInputData(ResettableIterator<?> inputData) {}
+
+    /** Sets the collector through which users can output records to downstream tasks. */
+    default void setOutput(ProxySideOutput collector) {}
+
+    /**
+     * Retrieves the output tags from the {@link MLSession}, which can be used to output records
+     * from the worker operator.
+     */
+    default List<OutputTag<?>> getOutputTags() {
+        return null;
+    }
+
+    /** Recovers from state. */
+    default void initializeState(StateInitializationContext context) throws Exception {}
+
+    /** Snapshots to state. */
+    default void snapshotState(StateSnapshotContext context) throws Exception {}
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java
new file mode 100644
index 000000000..196fbd215
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.util.OutputTag;
+
+import java.util.List;
+
+/**
+ * The default implementation of {@link MLSession}.
+ *
+ * @param <DT> Data type of the input data.
+ */
+public class MLSessionImpl<DT> implements MLSession {
+    /** Current iteration id. */
+    public int iterationId;
+    /** Index of this worker. */
+    public int workerId;
+    /** Total number of workers in this distributed ML job. */
+    public int numWorkers;
+    /** The input data. */
+    public ResettableIterator<DT> inputData;
+
+    public List<OutputTag<?>> outputTags;
+
+    /** Constructs an instance with side outputs. */
+    public MLSessionImpl(List<OutputTag<?>> outputTags) {
+        this.outputTags = outputTags;
+    }
+
+    /** Constructs an instance without side outputs. */
+    public MLSessionImpl() {}
+
+    @Override
+    public List<OutputTag<?>> getOutputTags() {
+        return outputTags;
+    }
+
+    @Override
+    public void setIterationId(int iterationId) {
+        this.iterationId = iterationId;
+    }
+
+    @Override
+    public void setWorldInfo(int workerId, int numWorkers) {
+        this.workerId = workerId;
+        this.numWorkers = numWorkers;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public void setInputData(ResettableIterator<?> inputData) {
+        this.inputData = (ResettableIterator<DT>) inputData;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java
new file mode 100644
index 000000000..de5d5da45
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/** The ML session for machine learning algorithms that adopt mini-batch training. */
+public class MiniBatchMLSession<DT> extends MLSessionImpl<DT> {
+
+    /** The placeholder for the indices to pull for each iteration. */
+    public long[] pullIndices;
+    /** The placeholder for the pulled values for each iteration. */
+    public double[] pulledValues;
+    /** The placeholder for the indices to push for each iteration. */
+    public long[] pushIndices;
+    /** The placeholder for the values to push for each iteration. */
+    public double[] pushValues;
+
+    /** The batch of training data for computing gradients. */
+    public List<DT> batchData;
+
+    private ListState<DT> batchDataState;
+    /** The global batch size. */
+    private final int globalBatchSize;
+    /** The local batch size. */
+    private int localBatchSize;
+    /** Type information of the input data. */
+    private final TypeInformation<DT> typeInformation;
+
+    public MiniBatchMLSession(int globalBatchSize, TypeInformation<DT> typeInformation) {
+        this.globalBatchSize = globalBatchSize;
+        this.typeInformation = typeInformation;
+    }
+
+    public MiniBatchMLSession(
+            int globalBatchSize,
+            TypeInformation<DT> typeInformation,
+            List<OutputTag<?>> outputTags) {
+        super(outputTags);
+        this.globalBatchSize = globalBatchSize;
+        this.typeInformation = typeInformation;
+    }
+
+    @Override
+    public void setWorldInfo(int workerId, int numWorkers) {
+        super.setWorldInfo(workerId, numWorkers);
+        // Splits the global batch size evenly; the remainder goes to workers with low ids.
+        this.localBatchSize = globalBatchSize / numWorkers;
+        if (globalBatchSize % numWorkers > workerId) {
+            localBatchSize++;
+        }
+        this.batchData = new ArrayList<>(localBatchSize);
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        batchDataState =
+                context.getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("batchDataState", typeInformation));
+
+        Iterator<DT> batchDataIterator = batchDataState.get().iterator();
+        if (batchDataIterator.hasNext()) {
+            batchData = IteratorUtils.toList(batchDataIterator);
+        }
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        batchDataState.clear();
+        if (batchData.size() > 0) {
+            batchDataState.addAll(batchData);
+        }
+    }
+
+    /** Reads in the next batch of training data, rewinding the input when it is exhausted. */
+    public void readInNextBatchData() throws IOException {
+        batchData.clear();
+        int i = 0;
+        while (i < localBatchSize && inputData.hasNext()) {
+            batchData.add(inputData.next());
+            i++;
+        }
+        if (!inputData.hasNext()) {
+            inputData.reset();
+        }
+    }
+}
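For concreteness, the split in `setWorldInfo` assigns the remainder of the global batch to the lowest worker ids. A small sketch with illustrative numbers:

```java
int globalBatchSize = 10;
int numWorkers = 4;

// Prints: worker 0 -> 3, worker 1 -> 3, worker 2 -> 2, worker 3 -> 2
for (int workerId = 0; workerId < numWorkers; workerId++) {
    int localBatchSize =
            globalBatchSize / numWorkers + (globalBatchSize % numWorkers > workerId ? 1 : 0);
    System.out.println("worker " + workerId + " -> " + localBatchSize);
}
```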
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java
new file mode 100644
index 000000000..8a2daa751
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+/**
+ * A local computation stage of the training process. The input and output of a {@link
+ * ProcessStage} can be accessed via the {@link MLSession}.
+ *
+ * @param <T> Type of the {@link MLSession} that this stage operates on.
+ */
+public abstract class ProcessStage<T extends MLSession> implements IterationStage {
+    /**
+     * Performs the local computation logic using the information from the given session, e.g.,
+     * computing gradients.
+     */
+    public abstract void process(T session) throws Exception;
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java
new file mode 100644
index 000000000..6c8036204
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+/** A collector that can only output records using an {@link OutputTag}. */
+public final class ProxySideOutput {
+    private final Output<?> output;
+
+    public ProxySideOutput(Output<?> output) {
+        this.output = output;
+    }
+
+    public <X> void output(OutputTag<X> outputTag, StreamRecord<X> record) {
+        Preconditions.checkNotNull(outputTag);
+        output.collect(outputTag, record);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java
new file mode 100644
index 000000000..fec86d87e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+/**
+ * A communication stage that pulls the values of the keys provided by {@code keysSupplier} from
+ * the servers and hands the pulled values over to {@code valuesConsumer}.
+ */
+public final class PullStage implements IterationStage {
+    public final Supplier<long[]> keysSupplier;
+    public final Consumer<double[]> valuesConsumer;
+
+    public PullStage(Supplier<long[]> keysSupplier, Consumer<double[]> valuesConsumer) {
+        this.keysSupplier = keysSupplier;
+        this.valuesConsumer = valuesConsumer;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java
new file mode 100644
index 000000000..814aa5b96
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import java.util.function.Supplier;
+
+/**
+ * A communication stage that pushes (keys, values) to servers.
+ *
+ * <p>Note that the length of the values array must be a multiple of the length of the keys array,
+ * i.e., each key is pushed together with a value segment of the same length. For instance, given
+ * the keys {1, 4} and the values {1,2,3,4,5,6}, the value segment for key 1 is {1,2,3} and the
+ * value segment for key 4 is {4,5,6}.
+ */
+public class PushStage implements IterationStage {
+    public final Supplier<long[]> keysSupplier;
+    public final Supplier<double[]> valuesSupplier;
+
+    public PushStage(Supplier<long[]> keysSupplier, Supplier<double[]> valuesSupplier) {
+        this.keysSupplier = keysSupplier;
+        this.valuesSupplier = valuesSupplier;
+    }
+}
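Putting the stage types together, a training loop is declared as an ordered stage list over a session. A minimal sketch, assuming a `MiniBatchMLSession<LabeledPointWithWeight>` named `session`, a `LossFunc` named `lossFunc`, and that the lambdas passed here are serializable; the iteration cap of 100 is illustrative:

```java
IterationStageList<MiniBatchMLSession<LabeledPointWithWeight>> stages =
        new IterationStageList<>(session);
stages.addStage(new ComputeIndices())
        .addStage(new PullStage(() -> session.pullIndices, values -> session.pulledValues = values))
        .addStage(new ComputeGradients(lossFunc))
        .addStage(new PushStage(() -> session.pushIndices, () -> session.pushValues))
        .setTerminationCriteria(s -> s.iterationId >= 100);
```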
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java
new file mode 100644
index 000000000..ec09d3de3
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import java.io.Serializable;
+import java.util.function.Consumer;
+
+/**
+ * A serializable {@link Consumer}.
+ *
+ * @param <T> the type of the input consumed by this consumer.
+ */
+public interface SerializableConsumer<T> extends Consumer<T>, Serializable {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java
new file mode 100644
index 000000000..cfb327c4a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.ps.ResponseAssemblerOperator;
+import org.apache.flink.ml.common.ps.ServerOperator;
+import org.apache.flink.ml.common.ps.WorkerOperator;
+import org.apache.flink.ml.common.ps.updater.ModelUpdater;
+import org.apache.flink.ml.util.Bits;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.util.OutputTag;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Utility functions to describe an iterative training process. */
+public final class TrainingUtils {
+    /**
+     * Executes the training logic described in {@link IterationStageList} and returns the fitted
+     * model data as well as the outputs from the worker operator. The outputs from the worker
+     * operator are specified via {@link MLSession#getOutputTags()}.
+     *
+     * @param inputData the input data.
+     * @param iterationStages the iterative processing logic.
+     * @param maxKey max value of the key. For example, in LogisticRegression the maxKey should be
+     *     the max feature index.
+     * @param modelDataType output type information of the model data.
+     * @param modelUpdater the logic to update the model on servers.
+     * @param numServers number of servers.
+     * @return the fitted model data as well as the outputs from the worker operator. The order is
+     *     {modelData, side outputs from workers}. Note that the outputs from workers share the
+     *     same order as {@link MLSession#getOutputTags()}.
+     * @param <DT> data type of the input data.
+     * @param <MT> data type of the output model data.
+     */
+    public static <DT, MT> DataStreamList train(
+            DataStream<DT> inputData,
+            IterationStageList<? extends MLSession> iterationStages,
+            DataStream<Long> maxKey,
+            TypeInformation<MT> modelDataType,
+            ModelUpdater<MT> modelUpdater,
+            int numServers) {
+        // TODO: Support incremental training.
+
+        DataStream<byte[]> variableStream =
+                maxKey.broadcast()
+                        .map(
+                                (MapFunction<Long, byte[]>)
+                                        value -> {
+                                            byte[] buffer = new byte[Long.BYTES];
+                                            Bits.putLong(buffer, 0, value + 1);
+                                            return buffer;
+                                        });
+
+        return Iterations.iterateBoundedStreamsUntilTermination(
+                DataStreamList.of(variableStream),
+                ReplayableDataStreamList.notReplay(inputData),
+                IterationConfig.newBuilder().build(),
+                new TrainIterationBody<>(modelUpdater, modelDataType, iterationStages, numServers));
+    }
+
+    /** The iteration implementation for the training process. */
+    private static class TrainIterationBody<MT> implements IterationBody {
+        private final ModelUpdater<MT> modelUpdater;
+
+        private final TypeInformation<MT> modelType;
+        private final IterationStageList<? extends MLSession> iterationStages;
+        private final int numServers;
+
+        public TrainIterationBody(
+                ModelUpdater<MT> modelUpdater,
+                TypeInformation<MT> modelType,
+                IterationStageList<? extends MLSession> iterationStages,
+                int numServers) {
+            this.iterationStages = iterationStages;
+            this.modelType = modelType;
+            this.modelUpdater = modelUpdater;
+            this.numServers = numServers;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<byte[]> variableStream = variableStreams.get(0);
+            DataStream<?> trainData = dataStreams.get(0);
+            final OutputTag<MT> modelDataOutputTag = new OutputTag<>("MODEL_OUTPUT", modelType);
+
+            SingleOutputStreamOperator<Tuple2<Integer, byte[]>> messageToServer =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "workerNode",
+                                    new TupleTypeInfo<>(
+                                            Types.INT,
+                                            PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO),
+                                    new WorkerOperator(iterationStages, numServers))
+                            .name("WorkerOp");
+            int numWorkers = messageToServer.getParallelism();
+
+            SingleOutputStreamOperator<Tuple2<Integer, byte[]>> messageToWorker =
+                    messageToServer
+                            .partitionCustom(
+                                    (Partitioner<Integer>)
+                                            (key, numPartitions) -> key % numPartitions,
+                                    (KeySelector<Tuple2<Integer, byte[]>, Integer>)
+                                            value -> value.f0)
+                            .transform(
+                                    "ServerOp",
+                                    new TupleTypeInfo<>(
+                                            Types.INT,
+                                            PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO),
+                                    new ServerOperator<>(
+                                            iterationStages,
+                                            numWorkers,
+                                            modelUpdater,
+                                            modelDataOutputTag));
+            messageToWorker.setParallelism(numServers);
+
+            DataStream<byte[]> combinedMessageToWorker =
+                    messageToWorker
+                            .partitionCustom(
+                                    (Partitioner<Integer>)
+                                            (key, numPartitions) -> key % numPartitions,
+                                    (KeySelector<Tuple2<Integer, byte[]>, Integer>)
+                                            value -> value.f0)
+                            .transform(
+                                    "MirrorWorkerOp",
+                                    PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new ResponseAssemblerOperator(numServers))
+                            .setParallelism(numWorkers);
+
+            DataStream<MT> model = messageToWorker.getSideOutput(modelDataOutputTag);
+
+            List<DataStream<?>> result = new ArrayList<>();
+            result.add(model);
+
+            List<OutputTag<?>> sideOutputTags = iterationStages.session.getOutputTags();
+            if (sideOutputTags != null) {
+                for (OutputTag<?> outputTag : sideOutputTags) {
+                    result.add(messageToServer.getSideOutput(outputTag));
+                }
+            }
+
+            return new IterationBodyResult(
+                    DataStreamList.of(combinedMessageToWorker), new DataStreamList(result), null);
+        }
+    }
+}
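End to end, the utility is wired as follows. A hedged sketch: `trainData`, `maxKey`, and `stages` are the streams and stage list from the previous examples, and the FTRL hyperparameters and server count are placeholders:

```java
DataStreamList result =
        TrainingUtils.train(
                trainData, // DataStream<LabeledPointWithWeight>
                stages,
                maxKey, // DataStream<Long> holding the max feature index
                new TupleTypeInfo<>(
                        Types.LONG,
                        Types.LONG,
                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
                new FTRL(/* alpha */ 0.1, /* beta */ 1.0, /* lambda1 */ 0.1, /* lambda2 */ 0.1),
                /* numServers = */ 2);

// The fitted model segments come first; worker side outputs (if any) follow.
DataStream<Tuple3<Long, Long, double[]>> modelData = result.get(0);
```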
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java
new file mode 100644
index 000000000..2f403e4b3
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.updater;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * The FTRL (Follow-the-Regularized-Leader) algorithm is an optimization algorithm used for
+ * large-scale linear models. It aims to minimize the sum of a loss function over all training
+ * examples, subject to a regularization constraint.
+ *
+ * <p>FTRL is well-suited for sparse data and can handle problems with billions of features.
+ */
+public class FTRL implements ModelUpdater<Tuple3<Long, Long, double[]>> {
+    private final double alpha;
+    private final double beta;
+    private final double lambda1;
+    private final double lambda2;
+
+    // ------ Model data of the FTRL optimizer. ------
+    private long startIndex;
+    private long endIndex;
+    private double[] weight;
+    private double[] sigma;
+    private double[] z;
+    private double[] n;
+
+    private ListState<Long> boundaryState;
+    private ListState<double[]> modelDataState;
+
+    public FTRL(double alpha, double beta, double lambda1, double lambda2) {
+        this.alpha = alpha;
+        this.beta = beta;
+        this.lambda1 = lambda1;
+        this.lambda2 = lambda2;
+    }
+
+    @Override
+    public void open(long startKeyIndex, long endKeyIndex) {
+        this.startIndex = startKeyIndex;
+        this.endIndex = endKeyIndex;
+        int modelShardSize = (int) (endIndex - startIndex);
+        weight = new double[modelShardSize];
+        sigma = new double[modelShardSize];
+        z = new double[modelShardSize];
+        n = new double[modelShardSize];
+    }
+
+    @Override
+    public void update(long[] keys, double[] values) {
+        for (int i = 0; i < keys.length; i++) {
+            int index = (int) (keys[i] - startIndex);
+            double gi = values[i];
+            updateModelOnOneDim(gi, index, weight);
+        }
+    }
+
+    private void updateModelOnOneDim(double gi, int index, double[] weight) {
+        double gigi = gi * gi;
+        sigma[index] = 1 / alpha * (Math.sqrt(n[index] + gigi) - Math.sqrt(n[index]));
+        z[index] += gi - sigma[index] * weight[index];
+        n[index] += gigi;
+
+        if (Math.abs(z[index]) <= lambda1) {
+            weight[index] = 0;
+        } else {
+            weight[index] =
+                    -(z[index] - Math.signum(z[index]) * lambda1)
+                            / ((beta + Math.sqrt(n[index])) / alpha + lambda2);
+        }
+    }
+
+    @Override
+    public double[] get(long[] keys) {
+        double[] values = new double[keys.length];
+        for (int i = 0; i < keys.length; i++) {
+            values[i] = weight[(int) (keys[i] - startIndex)];
+        }
+        return values;
+    }
+
+    @Override
+    public Iterator<Tuple3<Long, Long, double[]>> getModelSegments() {
+        List<Tuple3<Long, Long, double[]>> modelSegments = new ArrayList<>();
+        modelSegments.add(Tuple3.of(startIndex, endIndex, weight));
+        return modelSegments.iterator();
+    }
+
+    @Override
+    public void initializeState(StateInitializationContext context) throws Exception {
+        boundaryState =
+                context.getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("BoundaryState", Types.LONG));
+
+        Iterator<Long> iterator = boundaryState.get().iterator();
+        if (iterator.hasNext()) {
+            startIndex = iterator.next();
+            endIndex = iterator.next();
+        }
+
+        modelDataState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "modelDataState",
+                                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
+        Iterator<double[]> modelData = modelDataState.get().iterator();
+        if (modelData.hasNext()) {
+            weight = modelData.next();
+            sigma = modelData.next();
+            z = modelData.next();
+            n = modelData.next();
+        }
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        if (weight != null) {
+            boundaryState.clear();
+            boundaryState.add(startIndex);
+            boundaryState.add(endIndex);
+
+            modelDataState.clear();
+            modelDataState.add(weight);
+            modelDataState.add(sigma);
+            modelDataState.add(z);
+            modelDataState.add(n);
+        }
+    }
+}
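For reference, the per-coordinate update implemented by `updateModelOnOneDim` is the standard FTRL-Proximal rule. In the code's notation, for gradient g on coordinate i:

```latex
\sigma_i = \frac{\sqrt{n_i + g^2} - \sqrt{n_i}}{\alpha}, \qquad
z_i \leftarrow z_i + g - \sigma_i w_i, \qquad
n_i \leftarrow n_i + g^2

w_i =
\begin{cases}
0 & \text{if } |z_i| \le \lambda_1, \\
-\dfrac{z_i - \operatorname{sign}(z_i)\,\lambda_1}{(\beta + \sqrt{n_i})/\alpha + \lambda_2} & \text{otherwise.}
\end{cases}
```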
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java
new file mode 100644
index 000000000..0d7ac3ed4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.updater;
+
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to update and retrieve model data.
+ *
+ * <p>Note that the model updater should also ensure that model data is robust to failures, by
+ * writing model data to snapshots.
+ *
+ * @param <MODEL> data type of model.
+ */
+public interface ModelUpdater<MODEL> extends Serializable {
+
+    /** Initializes the model data. */
+    void open(long startKeyIndex, long endKeyIndex);
+
+    /** Applies the push to update the model data, e.g., using gradient to update model. */
+    void update(long[] keys, double[] values);
+
+    /** Retrieves the model data of the given keys. */
+    double[] get(long[] keys);
+
+    /**
+     * Returns model segments. The model segments are continuously updated/retrieved by push/pull
+     * (i.e., {@link ModelUpdater#update(long[], double[])} and {@link ModelUpdater#get(long[])}).
+     */
+    Iterator<MODEL> getModelSegments();
+
+    /** Recovers the model data from state. */
+    void initializeState(StateInitializationContext context) throws Exception;
+
+    /** Snapshots the model data to state. */
+    void snapshotState(StateSnapshotContext context) throws Exception;
+}
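To make the `ModelUpdater` contract concrete before the mechanical renames below, here is a minimal sketch of an alternative implementation. It is not part of this patch; the class name, the fixed learning rate, and the single-segment, weights-only snapshot are illustrative assumptions (FTRL above is the real implementation, and it additionally checkpoints its shard boundaries):

```java
package org.apache.flink.ml.common.ps.updater;

import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;

import java.util.Collections;
import java.util.Iterator;

/** Hypothetical plain-SGD updater, for illustration only. */
public class SgdUpdater implements ModelUpdater<Tuple3<Long, Long, double[]>> {
    private final double learningRate;

    private long startIndex;
    private long endIndex;
    private double[] weight;

    private ListState<double[]> modelDataState;

    public SgdUpdater(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public void open(long startKeyIndex, long endKeyIndex) {
        // This shard of the model covers keys in [startKeyIndex, endKeyIndex).
        this.startIndex = startKeyIndex;
        this.endIndex = endKeyIndex;
        this.weight = new double[(int) (endKeyIndex - startKeyIndex)];
    }

    @Override
    public void update(long[] keys, double[] values) {
        // Pushed values are gradients; take one SGD step per key.
        for (int i = 0; i < keys.length; i++) {
            weight[(int) (keys[i] - startIndex)] -= learningRate * values[i];
        }
    }

    @Override
    public double[] get(long[] keys) {
        // Pulled values are the current weights of the requested keys.
        double[] values = new double[keys.length];
        for (int i = 0; i < keys.length; i++) {
            values[i] = weight[(int) (keys[i] - startIndex)];
        }
        return values;
    }

    @Override
    public Iterator<Tuple3<Long, Long, double[]>> getModelSegments() {
        // This updater keeps the whole shard in a single segment.
        return Collections.singletonList(Tuple3.of(startIndex, endIndex, weight)).iterator();
    }

    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        // A production updater would also checkpoint the shard boundaries,
        // as FTRL does with its boundaryState.
        modelDataState =
                context.getOperatorStateStore()
                        .getListState(
                                new ListStateDescriptor<>(
                                        "sgdModelDataState",
                                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
        Iterator<double[]> modelData = modelDataState.get().iterator();
        if (modelData.hasNext()) {
            weight = modelData.next();
        }
    }

    @Override
    public void snapshotState(StateSnapshotContext context) throws Exception {
        if (weight != null) {
            modelDataState.clear();
            modelDataState.add(weight);
        }
    }
}
```

Keying each shard by a `[startKeyIndex, endKeyIndex)` range is what lets the parameter server split a model across workers while push and pull stay simple `long[]`/`double[]` exchanges.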
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java
index 3d37c1fed..e0da19729 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java
@@ -18,22 +18,22 @@
 package org.apache.flink.ml.common.util;
 
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 
 import java.util.ArrayList;
 import java.util.List;
 
-/** Provides utility functions for {@link Vector}. */
+/** Provides utility functions for {@link IntDoubleVector}. */
 public class VectorUtils {
     /**
      * Selects a subset of the vector based on the indices. Note that the input indices must be
      * sorted in ascending order.
      */
-    public static Vector selectByIndices(Vector vector, int[] sortedIndices) {
-        if (vector instanceof DenseVector) {
-            DenseVector resultVec = new DenseVector(sortedIndices.length);
+    public static IntDoubleVector selectByIndices(IntDoubleVector vector, int[] sortedIndices) {
+        if (vector instanceof DenseIntDoubleVector) {
+            DenseIntDoubleVector resultVec = new DenseIntDoubleVector(sortedIndices.length);
             for (int i = 0; i < sortedIndices.length; i++) {
                 resultVec.set(i, vector.get(sortedIndices[i]));
             }
@@ -42,18 +42,18 @@ public static Vector selectByIndices(Vector vector, int[] sortedIndices) {
         List<Integer> resultIndices = new ArrayList<>();
         List<Double> resultValues = new ArrayList<>();
 
-        int[] indices = ((SparseVector) vector).indices;
+        int[] indices = ((SparseIntDoubleVector) vector).indices;
         for (int i = 0, j = 0; i < indices.length && j < sortedIndices.length; ) {
             if (indices[i] == sortedIndices[j]) {
                 resultIndices.add(j++);
-                resultValues.add(((SparseVector) vector).values[i++]);
+                resultValues.add(((SparseIntDoubleVector) vector).values[i++]);
             } else if (indices[i] > sortedIndices[j]) {
                 j++;
             } else {
                 i++;
             }
         }
-        return new SparseVector(
+        return new SparseIntDoubleVector(
                 sortedIndices.length,
                 resultIndices.stream().mapToInt(Integer::intValue).toArray(),
                 resultValues.stream().mapToDouble(Double::doubleValue).toArray());
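The `VectorUtils` change above is a pure type rename; `selectByIndices` keeps its semantics. A small driver, illustrative only and not part of this patch, assuming the `Vectors.dense`/`Vectors.sparse` factories used elsewhere in this change:

```java
import org.apache.flink.ml.common.util.VectorUtils;
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class SelectByIndicesExample {
    public static void main(String[] args) {
        // Dense input: the values at positions 1 and 3 are copied into a new
        // dense vector of size 2, i.e. the result contains [1.5, 4.5].
        IntDoubleVector dense =
                VectorUtils.selectByIndices(Vectors.dense(0.0, 1.5, 3.0, 4.5), new int[] {1, 3});
        System.out.println(dense);

        // Sparse input: only indices present in both the vector and the
        // (sorted) selection survive, re-indexed against the selection. Here
        // only index 3 matches, so the result has size 2 with 9.0 at index 0.
        IntDoubleVector sparse =
                VectorUtils.selectByIndices(
                        Vectors.sparse(5, new int[] {0, 3}, new double[] {7.0, 9.0}),
                        new int[] {3, 4});
        System.out.println(sparse);
    }
}
```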
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
index d74e40b24..97d648f86 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java
@@ -35,7 +35,7 @@
 import org.apache.flink.ml.api.AlgoOperator;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -659,8 +659,8 @@ public Tuple3 map(Row value) throws Exception {
             double label = ((Number) value.getFieldAs(labelCol)).doubleValue();
             Object probOrigin = value.getField(rawPredictionCol);
             double prob =
-                    probOrigin instanceof Vector
-                            ? ((Vector) probOrigin).get(1)
+                    probOrigin instanceof IntDoubleVector
+                            ? ((IntDoubleVector) probOrigin).get(1)
                             : ((Number) probOrigin).doubleValue();
             double weight =
                     weightCol == null ? 1.0 : ((Number) value.getField(weightCol)).doubleValue();
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
index aafbf6e7b..5384fdbc7 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java
@@ -24,11 +24,11 @@
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
-import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -72,11 +72,11 @@ public Table[] transform(Table... inputs) {
         for (int i = 0; i < inputCols.length; ++i) {
             int idx = inputTypeInfo.getFieldIndex(inputCols[i]);
             Class<?> typeClass = inputTypeInfo.getTypeAt(idx).getTypeClass();
-            if (typeClass.equals(SparseVector.class)) {
-                outputTypes[i] = SparseVectorTypeInfo.INSTANCE;
-            } else if (typeClass.equals(DenseVector.class)) {
-                outputTypes[i] = DenseVectorTypeInfo.INSTANCE;
-            } else if (typeClass.equals(Vector.class)) {
+            if (typeClass.equals(SparseIntDoubleVector.class)) {
+                outputTypes[i] = SparseIntDoubleVectorTypeInfo.INSTANCE;
+            } else if (typeClass.equals(DenseIntDoubleVector.class)) {
+                outputTypes[i] = DenseIntDoubleVectorTypeInfo.INSTANCE;
+            } else if (typeClass.equals(IntDoubleVector.class)) {
                 outputTypes[i] = VectorTypeInfo.INSTANCE;
             } else {
                 outputTypes[i] = Types.DOUBLE;
@@ -119,15 +119,15 @@ public Row map(Row input) {
         }
 
         private Object binarizerFunc(Object obj, double threshold) {
-            if (obj instanceof DenseVector) {
-                DenseVector inputVec = (DenseVector) obj;
-                DenseVector vec = inputVec.clone();
+            if (obj instanceof DenseIntDoubleVector) {
+                DenseIntDoubleVector inputVec = (DenseIntDoubleVector) obj;
+                DenseIntDoubleVector vec = inputVec.clone();
                 for (int i = 0; i < vec.size(); ++i) {
                     vec.values[i] = inputVec.get(i) > threshold ? 1.0 : 0.0;
                 }
                 return vec;
-            } else if (obj instanceof SparseVector) {
-                SparseVector inputVec = (SparseVector) obj;
+            } else if (obj instanceof SparseIntDoubleVector) {
+                SparseIntDoubleVector inputVec = (SparseIntDoubleVector) obj;
                 int[] newIndices = new int[inputVec.indices.length];
                 int pos = 0;
@@ -139,7 +139,8 @@ private Object binarizerFunc(Object obj, double threshold) {
                 double[] newValues = new double[pos];
                 Arrays.fill(newValues, 1.0);
 
-                return new SparseVector(inputVec.size(), Arrays.copyOf(newIndices, pos), newValues);
+                return new SparseIntDoubleVector(
+                        inputVec.size(), Arrays.copyOf(newIndices, pos), newValues);
             } else {
                 return Double.parseDouble(obj.toString()) > threshold ? 1.0 : 0.0;
             }
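Since `binarizerFunc` carries the `Binarizer` logic above, a short usage sketch may help. It is not part of this patch and assumes `Binarizer`'s `inputCols`/`outputCols`/`thresholds` parameters; the dense branch compares every component against the threshold (strictly greater), while the sparse branch keeps only the surviving indices with all values set to 1.0, so the result never stores explicit zeros:

```java
import org.apache.flink.ml.feature.binarizer.Binarizer;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

public class BinarizerExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // One row with a dense and a sparse vector column.
        Table input =
                tEnv.fromDataStream(
                                env.fromElements(
                                        Row.of(
                                                Vectors.dense(0.5, 1.0, 2.0),
                                                Vectors.sparse(
                                                        4,
                                                        new int[] {0, 3},
                                                        new double[] {0.5, 2.0}))))
                        .as("dense", "sparse");

        Binarizer binarizer =
                new Binarizer()
                        .setInputCols("dense", "sparse")
                        .setOutputCols("denseOut", "sparseOut")
                        .setThresholds(1.0, 1.0);

        // Expected: denseOut = [0.0, 0.0, 1.0] (1.0 is not strictly greater
        // than the threshold); sparseOut keeps only index 3, with value 1.0.
        try (CloseableIterator<Row> it = binarizer.transform(input)[0].execute().collect()) {
            it.forEachRemaining(System.out::println);
        }
    }
}
```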
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
index 390d99762..d3fa93f35 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
@@ -23,9 +23,9 @@
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -107,7 +107,8 @@ public Table[] transform(Table... inputs) {
         RowTypeInfo outputTypeInfo =
                 new RowTypeInfo(
                         ArrayUtils.addAll(
-                                inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE),
+                                inputTypeInfo.getFieldTypes(),
+                                SparseIntDoubleVectorTypeInfo.INSTANCE),
                         ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
 
         DataStream<Row> output =
@@ -177,7 +178,7 @@ public Row map(Row row) throws Exception {
                 }
             }
 
-            SparseVector outputVec =
+            SparseIntDoubleVector outputVec =
                     Vectors.sparse(
                             termCounts.length,
                             indices.stream().mapToInt(i -> i).toArray(),
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java
index c343cd763..d2c0d2716 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java
@@ -22,10 +22,10 @@
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -70,7 +70,8 @@ public Table[] transform(Table...
inputs) { RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll( - inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE), + inputTypeInfo.getFieldTypes(), + DenseIntDoubleVectorTypeInfo.INSTANCE), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream stream = @@ -101,7 +102,7 @@ private DCTFunction(String inputCol, boolean isInverse) { @Override public Row map(Row row) throws Exception { - Vector vector = row.getFieldAs(inputCol); + IntDoubleVector vector = row.getFieldAs(inputCol); if (previousVectorSize != vector.size()) { if (isInverse) { @@ -113,7 +114,7 @@ public Row map(Row row) throws Exception { } double[] array = vector.toArray(); - if (vector instanceof DenseVector) { + if (vector instanceof DenseIntDoubleVector) { array = array.clone(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java index 5a49fb64c..4e9c00fe0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java @@ -23,7 +23,7 @@ import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -77,16 +77,16 @@ public Table[] transform(Table... inputs) { private static class ElementwiseProductFunction implements MapFunction { private final String inputCol; - private final Vector scalingVec; + private final IntDoubleVector scalingVec; - public ElementwiseProductFunction(String inputCol, Vector scalingVec) { + public ElementwiseProductFunction(String inputCol, IntDoubleVector scalingVec) { this.inputCol = inputCol; this.scalingVec = scalingVec; } @Override public Row map(Row value) { - Vector inputVec = value.getFieldAs(inputCol); + IntDoubleVector inputVec = value.getFieldAs(inputCol); if (inputVec != null) { if (scalingVec.size() != inputVec.size()) { throw new IllegalArgumentException( @@ -96,7 +96,7 @@ public Row map(Row value) { + inputVec.size() + ")."); } - Vector retVec = inputVec.clone(); + IntDoubleVector retVec = inputVec.clone(); BLAS.hDot(scalingVec, retVec); return Row.join(value, Row.of(retVec)); } else { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java index 4bf612cd1..36aa046d6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.common.param.HasInputCol; import org.apache.flink.ml.common.param.HasOutputCol; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; import org.apache.flink.ml.param.VectorParam; @@ -32,18 +32,18 @@ */ public interface ElementwiseProductParams extends HasInputCol, HasOutputCol { - Param SCALING_VEC = + 
Param SCALING_VEC = new VectorParam( "scalingVec", "The scaling vector to multiply with input vectors using hadamard product.", null, ParamValidators.notNull()); - default Vector getScalingVec() { + default IntDoubleVector getScalingVec() { return get(SCALING_VEC); } - default T setScalingVec(Vector value) { + default T setScalingVec(IntDoubleVector value) { set(SCALING_VEC, value); return (T) this; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java index 3ad50b758..abff9434d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -137,7 +137,7 @@ public Row map(Row row) { values[pos] = entry.getValue(); pos++; } - return Row.join(row, Row.of(new SparseVector(numFeatures, indices, values))); + return Row.join(row, Row.of(new SparseIntDoubleVector(numFeatures, indices, values))); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java index 392019804..95660e694 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java @@ -23,7 +23,7 @@ import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -76,7 +76,8 @@ public Table[] transform(Table... 
inputs) { RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll( - inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE), + inputTypeInfo.getFieldTypes(), + SparseIntDoubleVectorTypeInfo.INSTANCE), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream output = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java index 7c2d051c7..0c3e61c5b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java @@ -24,8 +24,9 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -98,7 +99,7 @@ public static IDF load(StreamTableEnvironment tEnv, String path) throws IOExcept /** The main logic to compute the model data of IDF. */ private static class IDFAggregator - implements AggregateFunction, IDFModelData> { + implements AggregateFunction, IDFModelData> { private final int minDocFreq; public IDFAggregator(int minDocFreq) { @@ -106,37 +107,38 @@ public IDFAggregator(int minDocFreq) { } @Override - public Tuple2 createAccumulator() { - return Tuple2.of(0L, new DenseVector(new double[0])); + public Tuple2 createAccumulator() { + return Tuple2.of(0L, new DenseIntDoubleVector(new double[0])); } @Override - public Tuple2 add( - Vector vector, Tuple2 numDocsAndDocFreq) { + public Tuple2 add( + Vector vector, Tuple2 numDocsAndDocFreq) { + IntDoubleVector intDoubleVector = (IntDoubleVector) vector; if (numDocsAndDocFreq.f0 == 0) { - numDocsAndDocFreq.f1 = new DenseVector(vector.size()); + numDocsAndDocFreq.f1 = new DenseIntDoubleVector(intDoubleVector.size()); } numDocsAndDocFreq.f0 += 1L; double[] values; - if (vector instanceof SparseVector) { - values = ((SparseVector) vector).values; + if (vector instanceof SparseIntDoubleVector) { + values = ((SparseIntDoubleVector) vector).values; } else { - values = ((DenseVector) vector).values; + values = ((DenseIntDoubleVector) vector).values; } for (int i = 0; i < values.length; i++) { values[i] = values[i] > 0 ? 
1 : 0; } - BLAS.axpy(1, vector, numDocsAndDocFreq.f1); + BLAS.axpy(1, intDoubleVector, numDocsAndDocFreq.f1); return numDocsAndDocFreq; } @Override - public IDFModelData getResult(Tuple2 numDocsAndDocFreq) { + public IDFModelData getResult(Tuple2 numDocsAndDocFreq) { long numDocs = numDocsAndDocFreq.f0; - DenseVector docFreq = numDocsAndDocFreq.f1; + DenseIntDoubleVector docFreq = numDocsAndDocFreq.f1; Preconditions.checkState(numDocs > 0, "The training set is empty."); long[] filteredDocFreq = new long[docFreq.size()]; @@ -152,9 +154,9 @@ public IDFModelData getResult(Tuple2 numDocsAndDocFreq) { } @Override - public Tuple2 merge( - Tuple2 numDocsAndDocFreq1, - Tuple2 numDocsAndDocFreq2) { + public Tuple2 merge( + Tuple2 numDocsAndDocFreq1, + Tuple2 numDocsAndDocFreq2) { if (numDocsAndDocFreq1.f0 == 0) { return numDocsAndDocFreq2; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java index 87a2f254a..42fee1e84 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -125,7 +125,7 @@ public static IDFModel load(StreamTableEnvironment tEnv, String path) throws IOE private static class ComputeTfIdfFunction extends RichMapFunction { private final String inputCol; private final String broadcastKey; - private DenseVector idf; + private DenseIntDoubleVector idf; public ComputeTfIdfFunction(String broadcastKey, String inputCol) { this.broadcastKey = broadcastKey; @@ -141,7 +141,8 @@ public Row map(Row row) { idf = idfModelDataData.idf; } - Vector outputVec = ((Vector) Objects.requireNonNull(row.getField(inputCol))).clone(); + IntDoubleVector outputVec = + ((IntDoubleVector) Objects.requireNonNull(row.getField(inputCol))).clone(); BLAS.hDot(idf, outputVec); return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java index a808454e8..3c495968d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java @@ -29,8 +29,8 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -48,7 +48,7 @@ */ public class IDFModelData { /** Inverse document frequency for all terms. 
*/ - public DenseVector idf; + public DenseIntDoubleVector idf; /** Document frequency for all terms after filtering out infrequent terms. */ public long[] docFreq; /** Number of docs in the training set. */ @@ -56,7 +56,7 @@ public class IDFModelData { public IDFModelData() {} - public IDFModelData(DenseVector idf, long[] docFreq, long numDocs) { + public IDFModelData(DenseIntDoubleVector idf, long[] docFreq, long numDocs) { this.idf = idf; this.docFreq = docFreq; this.numDocs = numDocs; @@ -77,7 +77,8 @@ public static DataStream getModelDataStream(Table modelDataTable) /** Encoder for {@link IDFModelData}. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer denseVectorSerializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(IDFModelData modelData, OutputStream outputStream) throws IOException { @@ -93,14 +94,14 @@ public static class ModelDataDecoder extends SimpleStreamFormat { @Override public Reader createReader(Configuration config, FSDataInputStream stream) { return new Reader() { - private final DenseVectorSerializer denseVectorSerializer = - new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer denseVectorSerializer = + new DenseIntDoubleVectorSerializer(); @Override public IDFModelData read() throws IOException { DataInputView source = new DataInputViewStreamWrapper(stream); try { - DenseVector idf = denseVectorSerializer.deserialize(source); + DenseIntDoubleVector idf = denseVectorSerializer.deserialize(source); long[] docFreq = LongPrimitiveArraySerializer.INSTANCE.deserialize(source); long numDocs = LongSerializer.INSTANCE.deserialize(source); return new IDFModelData(idf, docFreq, numDocs); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java index 76e878981..b01910c50 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java @@ -22,9 +22,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; @@ -104,8 +104,8 @@ public Row map(Row value) { return Row.join(value, Row.of((Object) null)); } - if (obj instanceof DenseVector) { - featureSize[i] = ((Vector) obj).size(); + if (obj instanceof DenseIntDoubleVector) { + featureSize[i] = ((IntDoubleVector) obj).size(); if (featureIndices[i] == null || featureIndices[i].length != featureSize[i]) { featureIndices[i] = new int[featureSize[i]]; for (int j = 0; j < featureSize[i]; ++j) { @@ -113,13 +113,13 @@ public Row map(Row value) { } } - featureValues[i] = ((DenseVector) obj).values; + featureValues[i] = ((DenseIntDoubleVector) obj).values; nnz *= featureSize[i]; - } else if (obj instanceof SparseVector) { - featureSize[i] = ((Vector) obj).size(); 
- featureIndices[i] = ((SparseVector) obj).indices; - featureValues[i] = ((SparseVector) obj).values; - nnz *= ((SparseVector) obj).values.length; + } else if (obj instanceof SparseIntDoubleVector) { + featureSize[i] = ((IntDoubleVector) obj).size(); + featureIndices[i] = ((SparseIntDoubleVector) obj).indices; + featureValues[i] = ((SparseIntDoubleVector) obj).values; + nnz *= ((SparseIntDoubleVector) obj).values.length; hasSparse = true; } else { featureSize[i] = 1; @@ -128,7 +128,7 @@ public Row map(Row value) { } } - Vector ret; + IntDoubleVector ret; int featureIter = inputCols.length - 1; if (hasSparse) { int[] indices = new int[nnz]; @@ -170,7 +170,7 @@ public Row map(Row value) { } idxOffset *= prevValues.length; } - ret = new DenseVector(values); + ret = new DenseIntDoubleVector(values); } return Row.join(value, Row.of(ret)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java index 763a0df22..64b868072 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java @@ -23,8 +23,8 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler.MinMaxReduceFunctionOperator; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -93,13 +93,15 @@ public KBinsDiscretizerModel fit(Table... inputs) { String strategy = getStrategy(); int numBins = getNumBins(); - DataStream inputData = + DataStream inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction) - value -> ((Vector) value.getField(inputCol)).toDense()); + (MapFunction) + value -> + ((IntDoubleVector) value.getField(inputCol)) + .toDense()); - DataStream preprocessedData; + DataStream preprocessedData; if (strategy.equals(UNIFORM)) { preprocessedData = inputData @@ -121,12 +123,13 @@ public KBinsDiscretizerModel fit(Table... 
inputs) { DataStream modelData = DataStreamUtils.mapPartition( preprocessedData, - new MapPartitionFunction() { + new MapPartitionFunction< + DenseIntDoubleVector, KBinsDiscretizerModelData>() { @Override public void mapPartition( - Iterable iterable, + Iterable iterable, Collector collector) { - List list = new ArrayList<>(); + List list = new ArrayList<>(); iterable.iterator().forEachRemaining(list::add); if (list.size() == 0) { @@ -180,9 +183,9 @@ public static KBinsDiscretizer load(StreamTableEnvironment tEnv, String path) } private static double[][] findBinEdgesWithUniformStrategy( - List input, int numBins) { - DenseVector minVector = input.get(0); - DenseVector maxVector = input.get(1); + List input, int numBins) { + DenseIntDoubleVector minVector = input.get(0); + DenseIntDoubleVector maxVector = input.get(1); int numColumns = minVector.size(); double[][] binEdges = new double[numColumns][]; @@ -206,7 +209,7 @@ private static double[][] findBinEdgesWithUniformStrategy( } private static double[][] findBinEdgesWithQuantileStrategy( - List input, int numBins) { + List input, int numBins) { int numColumns = input.get(0).size(); int numData = input.size(); double[][] binEdges = new double[numColumns][]; @@ -271,7 +274,8 @@ private static double[][] findBinEdgesWithQuantileStrategy( return binEdges; } - private static double[][] findBinEdgesWithKMeansStrategy(List input, int numBins) { + private static double[][] findBinEdgesWithKMeansStrategy( + List input, int numBins) { int numColumns = input.get(0).size(); int numData = input.size(); double[][] binEdges = new double[numColumns][numBins + 1]; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java index 7053e313c..e42b57540 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -80,7 +80,7 @@ public Table[] transform(Table... 
inputs) { new RowTypeInfo( ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), - TypeInformation.of(DenseVector.class)), + TypeInformation.of(DenseIntDoubleVector.class)), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream output = @@ -149,8 +149,8 @@ public Row map(Row row) { getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); binEdges = modelData.binEdges; } - DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense(); - DenseVector outputVec = inputVec.clone(); + DenseIntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol)).toDense(); + DenseIntDoubleVector outputVec = inputVec.clone(); for (int i = 0; i < inputVec.size(); i++) { double targetFeature = inputVec.get(i); int index = Arrays.binarySearch(binEdges[i], targetFeature); @@ -164,7 +164,7 @@ public Row map(Row row) { index = Math.min(index, (binEdges[i].length - 2)); index = Math.max(index, 0); - outputVec.set(i, index); + outputVec.set(i, (double) index); } return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java index 80df8c4d6..60493bc91 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -80,7 +80,7 @@ private static DataStream getVectorSize(DataStream input, String v DataStream vectorSizes = input.map( d -> { - Vector vec = d.getFieldAs(vectorCol); + IntDoubleVector vec = d.getFieldAs(vectorCol); return vec.size(); }); return DataStreamUtils.reduce( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java index 66c88ea40..e72141d21 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java @@ -33,9 +33,9 @@ import org.apache.flink.ml.common.datastream.EndOfStreamWindows; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.typeinfo.PriorityQueueTypeInfo; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -108,7 +108,7 @@ public Table[] transform(Table... inputs) { tEnv.toDataStream(modelDataTable, modelDataClass); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); - TypeInformation outputType = TypeInformation.of(DenseVector[].class); + TypeInformation outputType = TypeInformation.of(DenseIntDoubleVector[].class); RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType), @@ -138,7 +138,7 @@ public Table[] transform(Table... 
inputs) { * @return A dataset containing at most k items closest to the key with a column named `distCol` * appended. */ - public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) { + public Table approxNearestNeighbors(Table dataset, IntDoubleVector key, int k, String distCol) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment(); Table transformedTable = @@ -189,7 +189,7 @@ public Table approxNearestNeighbors(Table dataset, Vector key, int k, String dis * An overloaded version of `approxNearestNeighbors` with "distCol" as default value of * `distCol`. */ - public Table approxNearestNeighbors(Table dataset, Vector key, int k) { + public Table approxNearestNeighbors(Table dataset, IntDoubleVector key, int k) { return approxNearestNeighbors(dataset, key, k, "distCol"); } @@ -308,7 +308,7 @@ private RowTypeInfo getOutputType(Table dataTable, String idCol) { idColType, VectorTypeInfo.INSTANCE, Types.INT, - DenseVectorTypeInfo.INSTANCE + DenseIntDoubleVectorTypeInfo.INSTANCE }, new String[] {idCol, getInputCol(), indexCol, hashValueCol}); return outputTypeInfo; @@ -330,7 +330,7 @@ public Row map(Row value) throws Exception { (LSHModelData) getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0); } - Vector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol)); + IntDoubleVector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol)); return Row.join(value, Row.of((Object) hashValues)); } } @@ -338,11 +338,11 @@ public Row map(Row value) throws Exception { private static class FilterByBucketFunction extends RichFlatMapFunction { private final String inputCol; private final String outputCol; - private final Vector key; + private final IntDoubleVector key; private LSHModelData modelData; - private DenseVector[] keyHashes; + private DenseIntDoubleVector[] keyHashes; - public FilterByBucketFunction(String inputCol, String outputCol, Vector key) { + public FilterByBucketFunction(String inputCol, String outputCol, IntDoubleVector key) { this.inputCol = inputCol; this.outputCol = outputCol; this.key = key; @@ -356,7 +356,7 @@ public void flatMap(Row value, Collector out) throws Exception { getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0); keyHashes = modelData.hashFunction(key); } - DenseVector[] hashes = value.getFieldAs(outputCol); + DenseIntDoubleVector[] hashes = value.getFieldAs(outputCol); boolean sameBucket = false; for (int i = 0; i < keyHashes.length; i += 1) { if (keyHashes[i].equals(hashes[i])) { @@ -367,7 +367,7 @@ public void flatMap(Row value, Collector out) throws Exception { if (!sameBucket) { return; } - Vector vec = value.getFieldAs(inputCol); + IntDoubleVector vec = value.getFieldAs(inputCol); double dist = modelData.keyDistance(key, vec); out.collect(Row.join(value, Row.of(dist))); } @@ -447,7 +447,7 @@ public ExplodeHashValuesFunction(String idCol, String inputCol, String outputCol @Override public void flatMap(Row value, Collector out) throws Exception { Row kept = Row.of(value.getField(idCol), value.getField(inputCol)); - DenseVector[] hashValues = value.getFieldAs(outputCol); + DenseIntDoubleVector[] hashValues = value.getFieldAs(outputCol); for (int i = 0; i < hashValues.length; i += 1) { out.collect(Row.join(kept, Row.of(i, hashValues[i]))); } @@ -455,10 +455,10 @@ public void flatMap(Row value, Collector out) throws Exception { } private static class IndexHashValueKeySelector - implements KeySelector> { + implements KeySelector> { @Override - public 
Tuple2 getKey(Row value) throws Exception { + public Tuple2 getKey(Row value) throws Exception { return Tuple2.of(value.getFieldAs(2), value.getFieldAs(3)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java index 1b95e1906..a69e5087e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java @@ -18,8 +18,8 @@ package org.apache.flink.ml.feature.lsh; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** * Base class for LSH model data. A concrete class extending this base class should implement how to @@ -33,7 +33,7 @@ abstract class LSHModelData { * @param vec input vector. * @return the mapping of LSH functions. */ - public abstract DenseVector[] hashFunction(Vector vec); + public abstract DenseIntDoubleVector[] hashFunction(IntDoubleVector vec); /** * Calculates the distance between two different feature vectors using the corresponding @@ -43,5 +43,5 @@ abstract class LSHModelData { * @param y One input vector in the metric space. * @return The distance between x and y. */ - public abstract double keyDistance(Vector x, Vector y); + public abstract double keyDistance(IntDoubleVector x, IntDoubleVector y); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java index 7e4e7e012..6217de24c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java @@ -28,8 +28,8 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.util.Preconditions; import java.io.EOFException; @@ -122,7 +122,7 @@ public TypeInformation getProducedType() { } @Override - public DenseVector[] hashFunction(Vector vec) { + public DenseIntDoubleVector[] hashFunction(IntDoubleVector vec) { int[] indices = vec.toSparse().indices; Preconditions.checkArgument(indices.length > 0, "Must have at least 1 non zero entry."); double[][] hashValues = new double[numHashTables][numHashFunctionsPerTable]; @@ -139,11 +139,13 @@ public DenseVector[] hashFunction(Vector vec) { hashValues[i][j] = minv; } } - return Arrays.stream(hashValues).map(DenseVector::new).toArray(DenseVector[]::new); + return Arrays.stream(hashValues) + .map(DenseIntDoubleVector::new) + .toArray(DenseIntDoubleVector[]::new); } @Override - public double keyDistance(Vector x, Vector y) { + public double keyDistance(IntDoubleVector x, IntDoubleVector y) { int[] xIndices = x.toSparse().indices; int[] yIndices = y.toSparse().indices; Preconditions.checkArgument( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java index 26aa0caa9..d754ea89c 100644 --- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java @@ -23,10 +23,11 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -90,7 +91,7 @@ public MaxAbsScalerModel fit(Table... inputs) { DataStream modelData = maxAbsValues.map( (MapFunction) - vector -> new MaxAbsScalerModelData((DenseVector) vector)); + vector -> new MaxAbsScalerModelData((DenseIntDoubleVector) vector)); MaxAbsScalerModel model = new MaxAbsScalerModel().setModelData(tEnv.fromDataStream(modelData)); @@ -104,8 +105,8 @@ public MaxAbsScalerModel fit(Table... inputs) { */ private static class MaxAbsReduceFunctionOperator extends AbstractStreamOperator implements OneInputStreamOperator, BoundedOneInput { - private ListState maxAbsState; - private DenseVector maxAbsVector; + private ListState maxAbsState; + private DenseIntDoubleVector maxAbsVector; @Override public void endInput() { @@ -116,22 +117,24 @@ public void endInput() { @Override public void processElement(StreamRecord streamRecord) { - Vector currentValue = streamRecord.getValue(); + IntDoubleVector currentValue = (IntDoubleVector) streamRecord.getValue(); maxAbsVector = - maxAbsVector == null ? new DenseVector(currentValue.size()) : maxAbsVector; + maxAbsVector == null + ? 
new DenseIntDoubleVector(currentValue.size()) + : maxAbsVector; Preconditions.checkArgument( currentValue.size() == maxAbsVector.size(), "The training data should all have same dimensions."); - if (currentValue instanceof DenseVector) { - double[] values = ((DenseVector) currentValue).values; + if (currentValue instanceof DenseIntDoubleVector) { + double[] values = ((DenseIntDoubleVector) currentValue).values; for (int i = 0; i < currentValue.size(); ++i) { maxAbsVector.values[i] = Math.max(maxAbsVector.values[i], Math.abs(values[i])); } - } else if (currentValue instanceof SparseVector) { - int[] indices = ((SparseVector) currentValue).indices; - double[] values = ((SparseVector) currentValue).values; + } else if (currentValue instanceof SparseIntDoubleVector) { + int[] indices = ((SparseIntDoubleVector) currentValue).indices; + double[] values = ((SparseIntDoubleVector) currentValue).values; for (int i = 0; i < indices.length; ++i) { maxAbsVector.values[indices[i]] = @@ -147,7 +150,7 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "maxAbsState", DenseVectorTypeInfo.INSTANCE)); + "maxAbsState", DenseIntDoubleVectorTypeInfo.INSTANCE)); OperatorStateUtils.getUniqueElement(maxAbsState, "maxAbsState") .ifPresent(x -> maxAbsVector = x); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java index 5f5d7e4c8..036f2423d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -132,7 +132,7 @@ public static MaxAbsScalerModel load(StreamTableEnvironment tEnv, String path) private static class PredictOutputFunction extends RichMapFunction { private final String inputCol; private final String broadcastKey; - private DenseVector scaleVector; + private DenseIntDoubleVector scaleVector; public PredictOutputFunction(String broadcastKey, String inputCol) { this.broadcastKey = broadcastKey; @@ -156,8 +156,8 @@ public Row map(Row row) { } } - Vector inputVec = row.getFieldAs(inputCol); - Vector outputVec = inputVec.clone(); + IntDoubleVector inputVec = row.getFieldAs(inputCol); + IntDoubleVector outputVec = inputVec.clone(); BLAS.hDot(scaleVector, outputVec); return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java index 4c1e76db5..bac6b3b9f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java @@ -27,8 +27,8 @@ import 
org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -45,11 +45,11 @@ * classes to save/load model data. */ public class MaxAbsScalerModelData { - public DenseVector maxVector; + public DenseIntDoubleVector maxVector; public MaxAbsScalerModelData() {} - public MaxAbsScalerModelData(DenseVector maxVector) { + public MaxAbsScalerModelData(DenseIntDoubleVector maxVector) { this.maxVector = maxVector; } @@ -63,12 +63,13 @@ public static DataStream getModelDataStream(Table modelDa StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); return tEnv.toDataStream(modelDataTable) - .map(x -> new MaxAbsScalerModelData((DenseVector) x.getField(0))); + .map(x -> new MaxAbsScalerModelData((DenseIntDoubleVector) x.getField(0))); } /** Encoder for {@link MaxAbsScalerModelData}. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(MaxAbsScalerModelData modelData, OutputStream outputStream) @@ -84,13 +85,14 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration config, FSDataInputStream stream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public MaxAbsScalerModelData read() throws IOException { DataInputView source = new DataInputViewStreamWrapper(stream); try { - DenseVector maxVector = serializer.deserialize(source); + DenseIntDoubleVector maxVector = serializer.deserialize(source); return new MaxAbsScalerModelData(maxVector); } catch (EOFException e) { return null; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java index d21fd9a79..dec5f44ec 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java @@ -26,8 +26,8 @@ import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -78,12 +78,14 @@ public MinMaxScalerModel fit(Table... 
inputs) { final String inputCol = getInputCol(); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream inputData = + DataStream inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction) - value -> ((Vector) value.getField(inputCol)).toDense()); - DataStream minMaxValues = + (MapFunction) + value -> + ((IntDoubleVector) value.getField(inputCol)) + .toDense()); + DataStream minMaxValues = inputData .transform( "reduceInEachPartition", @@ -97,14 +99,15 @@ public MinMaxScalerModel fit(Table... inputs) { DataStream modelData = DataStreamUtils.mapPartition( minMaxValues, - new RichMapPartitionFunction() { + new RichMapPartitionFunction< + DenseIntDoubleVector, MinMaxScalerModelData>() { @Override public void mapPartition( - Iterable values, + Iterable values, Collector out) { - Iterator iter = values.iterator(); - DenseVector minVector = iter.next(); - DenseVector maxVector = iter.next(); + Iterator iter = values.iterator(); + DenseIntDoubleVector minVector = iter.next(); + DenseIntDoubleVector maxVector = iter.next(); out.collect(new MinMaxScalerModelData(minVector, maxVector)); } }); @@ -119,13 +122,15 @@ public void mapPartition( * A stream operator to compute the min and max values in each partition of the input bounded * data stream. */ - public static class MinMaxReduceFunctionOperator extends AbstractStreamOperator - implements OneInputStreamOperator, BoundedOneInput { - private ListState minState; - private ListState maxState; + public static class MinMaxReduceFunctionOperator + extends AbstractStreamOperator + implements OneInputStreamOperator, + BoundedOneInput { + private ListState minState; + private ListState maxState; - private DenseVector minVector; - private DenseVector maxVector; + private DenseIntDoubleVector minVector; + private DenseIntDoubleVector maxVector; @Override public void endInput() { @@ -136,12 +141,12 @@ public void endInput() { } @Override - public void processElement(StreamRecord streamRecord) { - DenseVector currentValue = streamRecord.getValue(); + public void processElement(StreamRecord streamRecord) { + DenseIntDoubleVector currentValue = streamRecord.getValue(); if (minVector == null) { int vecSize = currentValue.size(); - minVector = new DenseVector(vecSize); - maxVector = new DenseVector(vecSize); + minVector = new DenseIntDoubleVector(vecSize); + maxVector = new DenseIntDoubleVector(vecSize); System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize); System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize); } else { @@ -163,12 +168,14 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "minState", TypeInformation.of(DenseVector.class))); + "minState", + TypeInformation.of(DenseIntDoubleVector.class))); maxState = context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "maxState", TypeInformation.of(DenseVector.class))); + "maxState", + TypeInformation.of(DenseIntDoubleVector.class))); OperatorStateUtils.getUniqueElement(minState, "minState").ifPresent(x -> minVector = x); OperatorStateUtils.getUniqueElement(maxState, "maxState").ifPresent(x -> maxVector = x); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java index 858c0f426..78d818db0 100644 --- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
@@ -24,8 +24,8 @@
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -79,7 +79,7 @@ public Table[] transform(Table... inputs) {
                         new RowTypeInfo(
                                 ArrayUtils.addAll(
                                         inputTypeInfo.getFieldTypes(),
-                                        TypeInformation.of(DenseVector.class)),
+                                        TypeInformation.of(DenseIntDoubleVector.class)),
                                 ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
         DataStream<Row> output =
                 BroadcastUtils.withBroadcastStream(
@@ -131,8 +131,8 @@ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
         private final String broadcastKey;
         private final double upperBound;
         private final double lowerBound;
-        private DenseVector scaleVector;
-        private DenseVector offsetVector;
+        private DenseIntDoubleVector scaleVector;
+        private DenseIntDoubleVector offsetVector;

         public PredictOutputFunction(
                 String broadcastKey, double upperBound, double lowerBound, String inputCol) {
@@ -148,10 +148,10 @@ public Row map(Row row) {
             MinMaxScalerModelData minMaxScalerModelData =
                     (MinMaxScalerModelData)
                             getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
-            DenseVector minVector = minMaxScalerModelData.minVector;
-            DenseVector maxVector = minMaxScalerModelData.maxVector;
-            scaleVector = new DenseVector(minVector.size());
-            offsetVector = new DenseVector(minVector.size());
+            DenseIntDoubleVector minVector = minMaxScalerModelData.minVector;
+            DenseIntDoubleVector maxVector = minMaxScalerModelData.maxVector;
+            scaleVector = new DenseIntDoubleVector(minVector.size());
+            offsetVector = new DenseIntDoubleVector(minVector.size());
             for (int i = 0; i < maxVector.size(); ++i) {
                 if (Math.abs(minVector.values[i] - maxVector.values[i]) < 1.0e-5) {
                     scaleVector.values[i] = 0.0;
@@ -165,8 +165,8 @@ public Row map(Row row) {
                 }
             }
         }
-        DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense();
-        DenseVector outputVec = new DenseVector(scaleVector.size());
+        DenseIntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol)).toDense();
+        DenseIntDoubleVector outputVec = new DenseIntDoubleVector(scaleVector.size());
         for (int i = 0; i < scaleVector.size(); ++i) {
             outputVec.values[i] =
                     inputVec.values[i] * scaleVector.values[i] + offsetVector.values[i];
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
index 451406e46..3e18b76f5 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
@@ -27,8 +27,8 @@
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -45,13 +45,13 @@
  * classes to save/load model data.
  */
 public class MinMaxScalerModelData {
-    public DenseVector minVector;
+    public DenseIntDoubleVector minVector;

-    public DenseVector maxVector;
+    public DenseIntDoubleVector maxVector;

     public MinMaxScalerModelData() {}

-    public MinMaxScalerModelData(DenseVector minVector, DenseVector maxVector) {
+    public MinMaxScalerModelData(DenseIntDoubleVector minVector, DenseIntDoubleVector maxVector) {
         this.minVector = minVector;
         this.maxVector = maxVector;
     }
@@ -69,12 +69,14 @@ public static DataStream getModelDataStream(Table modelDa
                 .map(
                         x ->
                                 new MinMaxScalerModelData(
-                                        (DenseVector) x.getField(0), (DenseVector) x.getField(1)));
+                                        (DenseIntDoubleVector) x.getField(0),
+                                        (DenseIntDoubleVector) x.getField(1)));
     }

     /** Encoder for {@link MinMaxScalerModelData}. */
     public static class ModelDataEncoder implements Encoder<MinMaxScalerModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();

         @Override
         public void encode(MinMaxScalerModelData modelData, OutputStream outputStream)
@@ -91,14 +93,15 @@ public static class ModelDataDecoder extends SimpleStreamFormat<MinMaxScalerModelData> {
         public Reader<MinMaxScalerModelData> createReader(
                 Configuration config, FSDataInputStream stream) {
             return new Reader<MinMaxScalerModelData>() {
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();

                 @Override
                 public MinMaxScalerModelData read() throws IOException {
                     DataInputView source = new DataInputViewStreamWrapper(stream);
                     try {
-                        DenseVector minVector = serializer.deserialize(source);
-                        DenseVector maxVector = serializer.deserialize(source);
+                        DenseIntDoubleVector minVector = serializer.deserialize(source);
+                        DenseIntDoubleVector maxVector = serializer.deserialize(source);
                         return new MinMaxScalerModelData(minVector, maxVector);
                     } catch (EOFException e) {
                         return null;
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java
index a9034ae3a..c3ed94fe4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java
@@ -23,7 +23,7 @@
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -95,8 +95,8 @@ public NormalizationFunction(double p, String inputCol) {

         @Override
         public Row map(Row row) throws Exception {
-            Vector inputVec = row.getFieldAs(inputCol);
-            Vector outputVec = inputVec.clone();
+            IntDoubleVector inputVec = row.getFieldAs(inputCol);
+            IntDoubleVector outputVec = inputVec.clone();
             double norm = BLAS.norm(inputVec, p);
             BLAS.scal(1.0 / norm, outputVec);
             return Row.join(row, Row.of(outputVec));
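The Normalizer change above is type-only; the p-norm arithmetic itself is untouched. For reference, a minimal plain-Java sketch of that arithmetic (array-based, not the Flink ML vector API; all names here are illustrative):

```java
// Illustrative only: plain arrays instead of Flink ML vector classes.
public class PNormSketch {
    /** Scales values in place so that their p-norm becomes 1 (assumes p >= 1). */
    static void normalize(double[] values, double p) {
        double sum = 0.0;
        for (double v : values) {
            sum += Math.pow(Math.abs(v), p);
        }
        double norm = Math.pow(sum, 1.0 / p);
        for (int i = 0; i < values.length; i++) {
            values[i] /= norm; // mirrors BLAS.scal(1.0 / norm, outputVec)
        }
    }

    public static void main(String[] args) {
        double[] v = {3.0, 4.0};
        normalize(v, 2.0); // L2 norm 5.0, so v becomes {0.6, 0.8}
        System.out.println(v[0] + ", " + v[1]);
    }
}
```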
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
index aab391bc1..d2aefb2a2 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java
@@ -27,7 +27,7 @@
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -80,7 +80,8 @@ public Table[] transform(Table... inputs) {
                         ArrayUtils.addAll(
                                 inputTypeInfo.getFieldTypes(),
                                 Collections.nCopies(
-                                                outputCols.length, SparseVectorTypeInfo.INSTANCE)
+                                                outputCols.length,
+                                                SparseIntDoubleVectorTypeInfo.INSTANCE)
                                         .toArray(new TypeInformation[0])),
                         ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java
index 86fd2eaad..68080787f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java
@@ -23,9 +23,9 @@
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -126,18 +126,19 @@ public PolynomialExpansionFunction(int degree, String inputCol) {

         @Override
         public Row map(Row row) throws Exception {
-            Vector vec = row.getFieldAs(inputCol);
+            IntDoubleVector vec = row.getFieldAs(inputCol);
             if (vec == null) {
                 throw new IllegalArgumentException("The vector must not be null.");
             }

-            Vector outputVec;
-            if (vec instanceof DenseVector) {
+            IntDoubleVector outputVec;
+            if (vec instanceof DenseIntDoubleVector) {
                 int size = vec.size();
                 double[] retVals = new double[getResultVectorSize(size, degree) - 1];
-                expandDenseVector(((DenseVector) vec).values, size - 1, degree, 1.0, retVals, -1);
-                outputVec = new DenseVector(retVals);
-            } else if (vec instanceof SparseVector) {
-                SparseVector sparseVec = (SparseVector) vec;
+                expandDenseVector(
+                        ((DenseIntDoubleVector) vec).values, size - 1, degree, 1.0, retVals, -1);
+                outputVec = new DenseIntDoubleVector(retVals);
+            } else if (vec instanceof SparseIntDoubleVector) {
+                SparseIntDoubleVector sparseVec = (SparseIntDoubleVector) vec;
                 int[] indices = sparseVec.indices;
                 double[] values = sparseVec.values;
                 int size = sparseVec.size();
@@ -158,7 +159,7 @@ public Row map(Row row) throws Exception {
                                 -1);

                 outputVec =
-                        new SparseVector(
+                        new SparseIntDoubleVector(
                                 getResultVectorSize(size, degree) - 1,
                                 polyIndices.f1,
                                 polyValues.f1);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
index cb6890811..8fc28e796 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
@@ -23,8 +23,8 @@
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.util.QuantileSummary;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -73,11 +73,13 @@ public RobustScalerModel fit(Table... inputs) {
         final String inputCol = getInputCol();
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
-        DataStream<DenseVector> inputData =
+        DataStream<DenseIntDoubleVector> inputData =
                 tEnv.toDataStream(inputs[0])
                         .map(
-                                (MapFunction<Row, DenseVector>)
-                                        value -> ((Vector) value.getField(inputCol)).toDense());
+                                (MapFunction<Row, DenseIntDoubleVector>)
+                                        value ->
+                                                ((IntDoubleVector) value.getField(inputCol))
+                                                        .toDense());
         DataStream<RobustScalerModelData> modelData =
                 DataStreamUtils.aggregate(
                         inputData,
@@ -93,7 +95,8 @@ public RobustScalerModel fit(Table... inputs) {
      * RobustScalerModelData}.
      */
     private static class QuantileAggregator
-            implements AggregateFunction<DenseVector, QuantileSummary[], RobustScalerModelData> {
+            implements AggregateFunction<
+                    DenseIntDoubleVector, QuantileSummary[], RobustScalerModelData> {

         private final double relativeError;
         private final double lower;
@@ -111,7 +114,8 @@ public QuantileSummary[] createAccumulator() {
         }

         @Override
-        public QuantileSummary[] add(DenseVector denseVector, QuantileSummary[] quantileSummaries) {
+        public QuantileSummary[] add(
+                DenseIntDoubleVector denseVector, QuantileSummary[] quantileSummaries) {
             if (quantileSummaries.length == 0) {
                 quantileSummaries = new QuantileSummary[denseVector.size()];
                 for (int i = 0; i < denseVector.size(); i++) {
@@ -136,8 +140,8 @@ public QuantileSummary[] add(DenseVector denseVector, QuantileSummary[] quantile
         @Override
         public RobustScalerModelData getResult(QuantileSummary[] quantileSummaries) {
             Preconditions.checkState(quantileSummaries.length != 0, "The training set is empty.");
-            DenseVector medianVector = new DenseVector(quantileSummaries.length);
-            DenseVector rangeVector = new DenseVector(quantileSummaries.length);
+            DenseIntDoubleVector medianVector = new DenseIntDoubleVector(quantileSummaries.length);
+            DenseIntDoubleVector rangeVector = new DenseIntDoubleVector(quantileSummaries.length);

             for (int i = 0; i < quantileSummaries.length; i++) {
                 QuantileSummary compressed = quantileSummaries[i].compress();
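RobustScaler's model data are per-feature medians and quantile ranges. A minimal sketch of the transform the fitted model applies, using plain arrays and the same zero-range guard as the hunk below (the centering/scaling toggles of the real operator are omitted for brevity):

```java
// Illustrative only: robust scaling with plain arrays; medians/ranges would
// come from the fitted RobustScalerModelData in Flink ML.
public class RobustScaleSketch {
    /** Applies (x - median) * (1 / range) per feature; a zero range maps to 0. */
    static double[] scale(double[] point, double[] medians, double[] ranges) {
        double[] out = new double[point.length];
        for (int i = 0; i < point.length; i++) {
            double inv = ranges[i] == 0 ? 0 : 1 / ranges[i]; // mirrors the diff's guard
            out[i] = (point[i] - medians[i]) * inv;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] out =
                scale(new double[] {9.0, 5.0}, new double[] {5.0, 5.0}, new double[] {4.0, 0.0});
        System.out.println(out[0] + ", " + out[1]); // 1.0, 0.0
    }
}
```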
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java
index deda6e339..b6239527d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java
@@ -24,8 +24,8 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -98,8 +98,8 @@ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
         private final boolean withCentering;
         private final boolean withScaling;
-        private DenseVector medians;
-        private DenseVector scales;
+        private DenseIntDoubleVector medians;
+        private DenseIntDoubleVector scales;

         public PredictOutputFunction(
                 String broadcastModelKey,
@@ -120,12 +120,13 @@ public Row map(Row row) throws Exception {
                                 getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                 medians = modelData.medians;
                 scales =
-                        new DenseVector(
+                        new DenseIntDoubleVector(
                                 Arrays.stream(modelData.ranges.values)
                                         .map(range -> range == 0 ? 0 : 1 / range)
                                         .toArray());
             }
-            DenseVector outputVec = ((Vector) row.getField(inputCol)).clone().toDense();
+            DenseIntDoubleVector outputVec =
+                    ((IntDoubleVector) row.getField(inputCol)).clone().toDense();
             Preconditions.checkState(
                     medians.size() == outputVec.size(),
                     "Number of features must be %s but got %s.",
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java
index 807fe2497..253dd8f1b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java
@@ -25,8 +25,8 @@
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -43,13 +43,13 @@
  * classes to save/load model data.
  */
 public class RobustScalerModelData {
-    public DenseVector medians;
+    public DenseIntDoubleVector medians;

-    public DenseVector ranges;
+    public DenseIntDoubleVector ranges;

     public RobustScalerModelData() {}

-    public RobustScalerModelData(DenseVector medians, DenseVector ranges) {
+    public RobustScalerModelData(DenseIntDoubleVector medians, DenseIntDoubleVector ranges) {
         this.medians = medians;
         this.ranges = ranges;
     }
@@ -67,13 +67,14 @@ public static DataStream getModelDataStream(Table modelDa
                 .map(
                         x ->
                                 new RobustScalerModelData(
-                                        (DenseVector) x.getField("medians"),
-                                        (DenseVector) x.getField("ranges")));
+                                        (DenseIntDoubleVector) x.getField("medians"),
+                                        (DenseIntDoubleVector) x.getField("ranges")));
     }

     /** Data encoder for the {@link RobustScalerModel} model data. */
     public static class ModelDataEncoder implements Encoder<RobustScalerModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();

         @Override
         public void encode(RobustScalerModelData modelData, OutputStream outputStream)
@@ -92,15 +93,18 @@ public static class ModelDataDecoder extends SimpleStreamFormat<RobustScalerModelData> {
         public Reader<RobustScalerModelData> createReader(
                 Configuration configuration, FSDataInputStream inputStream) throws IOException {
             return new Reader<RobustScalerModelData>() {
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();

                 @Override
                 public RobustScalerModelData read() throws IOException {
                     DataInputViewStreamWrapper inputViewStreamWrapper =
                             new DataInputViewStreamWrapper(inputStream);
                     try {
-                        DenseVector medians = serializer.deserialize(inputViewStreamWrapper);
-                        DenseVector ranges = serializer.deserialize(inputViewStreamWrapper);
+                        DenseIntDoubleVector medians =
+                                serializer.deserialize(inputViewStreamWrapper);
+                        DenseIntDoubleVector ranges =
+                                serializer.deserialize(inputViewStreamWrapper);
                         return new RobustScalerModelData(medians, ranges);
                     } catch (EOFException e) {
                         return null;
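The ModelDataEncoder/ModelDataDecoder pairs in these hunks only swap the serializer class; the write-then-read-in-the-same-order pattern stays the same. A sketch of that round trip with plain java.io streams (the length-prefixed layout here is an assumption for illustration, not the real serializer's wire format):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Illustrative only: encode two vectors, then decode them back in order.
public class ModelDataRoundTrip {
    static void writeVector(DataOutputStream out, double[] v) throws IOException {
        out.writeInt(v.length);
        for (double d : v) {
            out.writeDouble(d);
        }
    }

    static double[] readVector(DataInputStream in) throws IOException {
        double[] v = new double[in.readInt()];
        for (int i = 0; i < v.length; i++) {
            v[i] = in.readDouble();
        }
        return v;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(buf);
        writeVector(out, new double[] {1.5, 2.5}); // e.g. medians
        writeVector(out, new double[] {3.0, 4.0}); // e.g. ranges
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray()));
        double[] medians = readVector(in);
        double[] ranges = readVector(in);
        System.out.println(medians[1] + ", " + ranges[0]); // 2.5, 3.0
    }
}
```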
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
index 5c7d44748..c4ff95f0f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java
@@ -28,10 +28,10 @@
 import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
 import org.apache.flink.ml.common.window.Windows;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -125,16 +125,17 @@ public void process(
                 Iterable<Row> iterable,
                 Collector<StandardScalerModelData> collector)
                 throws Exception {
-            ListState<DenseVector> sumState =
+            ListState<DenseIntDoubleVector> sumState =
                     context.globalState()
                             .getListState(
                                     new ListStateDescriptor<>(
-                                            "sumState", DenseVectorTypeInfo.INSTANCE));
-            ListState<DenseVector> squaredSumState =
+                                            "sumState", DenseIntDoubleVectorTypeInfo.INSTANCE));
+            ListState<DenseIntDoubleVector> squaredSumState =
                     context.globalState()
                             .getListState(
                                     new ListStateDescriptor<>(
-                                            "squaredSumState", DenseVectorTypeInfo.INSTANCE));
+                                            "squaredSumState",
+                                            DenseIntDoubleVectorTypeInfo.INSTANCE));
             ListState<Long> numElementsState =
                     context.globalState()
                             .getListState(
@@ -143,9 +144,9 @@ public void process(
                     context.globalState()
                             .getListState(
                                     new ListStateDescriptor<>("modelVersionState", Types.LONG));
-            DenseVector sum =
+            DenseIntDoubleVector sum =
                     OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
-            DenseVector squaredSum =
+            DenseIntDoubleVector squaredSum =
                     OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
                             .orElse(null);
             long numElements =
@@ -157,11 +158,12 @@ public void process(
             long numElementsBefore = numElements;
             for (Row element : iterable) {
-                Vector inputVec =
-                        ((Vector) Objects.requireNonNull(element.getField(inputCol))).clone();
+                IntDoubleVector inputVec =
+                        ((IntDoubleVector) Objects.requireNonNull(element.getField(inputCol)))
+                                .clone();
                 if (numElements == 0) {
-                    sum = new DenseVector(inputVec.size());
-                    squaredSum = new DenseVector(inputVec.size());
+                    sum = new DenseIntDoubleVector(inputVec.size());
+                    squaredSum = new DenseIntDoubleVector(inputVec.size());
                 }
                 BLAS.axpy(1, inputVec, sum);
                 BLAS.hDot(inputVec, inputVec);
@@ -190,8 +192,8 @@ public void process(
         private static StandardScalerModelData buildModelData(
                 long numElements,
-                DenseVector sum,
-                DenseVector squaredSum,
+                DenseIntDoubleVector sum,
+                DenseIntDoubleVector squaredSum,
                 long modelVersion,
                 long currentTimeStamp) {
             BLAS.scal(1.0 / numElements, sum);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java
index a491e6866..0fae000c6 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java
@@ -30,8 +30,8 @@
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.metrics.MLMetrics;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -135,10 +135,10 @@ private static class PredictionOperator extends AbstractStreamOperator<Row>
         /** Model data for inference. */
         private StandardScalerModelData modelData;

-        private DenseVector mean;
+        private DenseIntDoubleVector mean;

         /** Inverse of standard deviation. */
-        private DenseVector scale;
+        private DenseIntDoubleVector scale;

         private long modelVersion;
@@ -249,7 +249,7 @@ private void initializeModelData(StandardScalerModelData modelData) {
             modelTimeStamp = modelData.timestamp;
             modelVersion = modelData.version;
             mean = modelData.mean;
-            DenseVector std = modelData.std;
+            DenseIntDoubleVector std = modelData.std;

             if (withStd) {
                 scale = std;
@@ -263,11 +263,12 @@ private void initializeModelData(StandardScalerModelData modelData) {

         private void doPrediction(StreamRecord<Row> streamRecord) {
             Row dataPoint = streamRecord.getValue();
-            Vector outputVec =
-                    ((Vector) (Objects.requireNonNull(dataPoint.getField(inputCol)))).clone();
+            IntDoubleVector outputVec =
+                    ((IntDoubleVector) (Objects.requireNonNull(dataPoint.getField(inputCol))))
+                            .clone();
             if (withMean) {
                 outputVec = outputVec.toDense();
-                BLAS.axpy(-1, mean, (DenseVector) outputVec);
+                BLAS.axpy(-1, mean, (DenseIntDoubleVector) outputVec);
             }
             if (withStd) {
                 BLAS.hDot(scale, outputVec);
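doPrediction above applies optional centering and scaling, where `scale` already holds the inverse standard deviations. The same transform on plain arrays, with the BLAS calls replaced by explicit loops:

```java
// Illustrative only: standardization with plain arrays instead of BLAS calls.
public class StandardizeSketch {
    static double[] standardize(
            double[] x, double[] mean, double[] scale, boolean withMean, boolean withStd) {
        double[] out = x.clone();
        for (int i = 0; i < out.length; i++) {
            if (withMean) {
                out[i] -= mean[i]; // mirrors BLAS.axpy(-1, mean, outputVec)
            }
            if (withStd) {
                out[i] *= scale[i]; // mirrors BLAS.hDot(scale, outputVec)
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] out =
                standardize(new double[] {10.0}, new double[] {4.0}, new double[] {0.5},
                        true, true);
        System.out.println(out[0]); // (10 - 4) * 0.5 = 3.0
    }
}
```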
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
index 59f519f28..e4ddc338f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
@@ -27,8 +27,8 @@
 import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -76,15 +76,16 @@ public StandardScalerModel fit(Table... inputs) {
         Preconditions.checkArgument(inputs.length == 1);
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
-        DataStream<Tuple3<DenseVector, DenseVector, Long>> sumAndSquaredSumAndWeight =
-                tEnv.toDataStream(inputs[0])
-                        .transform(
-                                "computeMeta",
-                                new TupleTypeInfo<>(
-                                        TypeInformation.of(DenseVector.class),
-                                        TypeInformation.of(DenseVector.class),
-                                        BasicTypeInfo.LONG_TYPE_INFO),
-                                new ComputeMetaOperator(getInputCol()));
+        DataStream<Tuple3<DenseIntDoubleVector, DenseIntDoubleVector, Long>>
+                sumAndSquaredSumAndWeight =
+                        tEnv.toDataStream(inputs[0])
+                                .transform(
+                                        "computeMeta",
+                                        new TupleTypeInfo<>(
+                                                TypeInformation.of(DenseIntDoubleVector.class),
+                                                TypeInformation.of(DenseIntDoubleVector.class),
+                                                BasicTypeInfo.LONG_TYPE_INFO),
+                                        new ComputeMetaOperator(getInputCol()));
         DataStream<StandardScalerModelData> modelData =
                 sumAndSquaredSumAndWeight
@@ -105,13 +106,14 @@ public StandardScalerModel fit(Table... inputs) {
      */
     private static class BuildModelOperator extends AbstractStreamOperator<StandardScalerModelData>
             implements OneInputStreamOperator<
-                    Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
+                    Tuple3<DenseIntDoubleVector, DenseIntDoubleVector, Long>,
+                    StandardScalerModelData>,
                     BoundedOneInput {
-        private ListState<DenseVector> sumState;
-        private ListState<DenseVector> squaredSumState;
+        private ListState<DenseIntDoubleVector> sumState;
+        private ListState<DenseIntDoubleVector> squaredSumState;
         private ListState<Long> numElementsState;
-        private DenseVector sum;
-        private DenseVector squaredSum;
+        private DenseIntDoubleVector sum;
+        private DenseIntDoubleVector squaredSum;
         private long numElements;

         @Override
@@ -141,8 +143,9 @@ public void endInput() {
         }

         @Override
-        public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> element) {
-            Tuple3<DenseVector, DenseVector, Long> value = element.getValue();
+        public void processElement(
+                StreamRecord<Tuple3<DenseIntDoubleVector, DenseIntDoubleVector, Long>> element) {
+            Tuple3<DenseIntDoubleVector, DenseIntDoubleVector, Long> value = element.getValue();
             if (numElements == 0) {
                 sum = value.f0;
                 squaredSum = value.f1;
@@ -161,13 +164,14 @@ public void initializeState(StateInitializationContext context) throws Exception
                     context.getOperatorStateStore()
                             .getListState(
                                     new ListStateDescriptor<>(
-                                            "sumState", TypeInformation.of(DenseVector.class)));
+                                            "sumState",
+                                            TypeInformation.of(DenseIntDoubleVector.class)));
             squaredSumState =
                     context.getOperatorStateStore()
                             .getListState(
                                     new ListStateDescriptor<>(
                                             "squaredSumState",
-                                            TypeInformation.of(DenseVector.class)));
+                                            TypeInformation.of(DenseIntDoubleVector.class)));
             numElementsState =
                     context.getOperatorStateStore()
                             .getListState(
@@ -196,14 +200,15 @@ public void snapshotState(StateSnapshotContext context) throws Exception {

     /** Computes sum, squared sum and number of elements in each partition. */
     private static class ComputeMetaOperator
-            extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
-            implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>,
+            extends AbstractStreamOperator<Tuple3<DenseIntDoubleVector, DenseIntDoubleVector, Long>>
+            implements OneInputStreamOperator<
+                            Row, Tuple3<DenseIntDoubleVector, DenseIntDoubleVector, Long>>,
                     BoundedOneInput {
-        private ListState<DenseVector> sumState;
-        private ListState<DenseVector> squaredSumState;
+        private ListState<DenseIntDoubleVector> sumState;
+        private ListState<DenseIntDoubleVector> squaredSumState;
         private ListState<Long> numElementsState;
-        private DenseVector sum;
-        private DenseVector squaredSum;
+        private DenseIntDoubleVector sum;
+        private DenseIntDoubleVector squaredSum;
         private long numElements;

         private final String inputCol;
@@ -221,10 +226,10 @@ public void endInput() {

         @Override
         public void processElement(StreamRecord<Row> element) {
-            Vector inputVec = (Vector) element.getValue().getField(inputCol);
+            IntDoubleVector inputVec = (IntDoubleVector) element.getValue().getField(inputCol);
             if (numElements == 0) {
-                sum = new DenseVector(inputVec.size());
-                squaredSum = new DenseVector(inputVec.size());
+                sum = new DenseIntDoubleVector(inputVec.size());
+                squaredSum = new DenseIntDoubleVector(inputVec.size());
             }
             BLAS.axpy(1, inputVec, sum);
             BLAS.hDot(inputVec, inputVec);
@@ -239,13 +244,14 @@ public void initializeState(StateInitializationContext context) throws Exception
                     context.getOperatorStateStore()
                             .getListState(
                                     new ListStateDescriptor<>(
-                                            "sumState", TypeInformation.of(DenseVector.class)));
+                                            "sumState",
+                                            TypeInformation.of(DenseIntDoubleVector.class)));
             squaredSumState =
                     context.getOperatorStateStore()
                             .getListState(
                                     new ListStateDescriptor<>(
                                             "squaredSumState",
-                                            TypeInformation.of(DenseVector.class)));
+                                            TypeInformation.of(DenseIntDoubleVector.class)));
             numElementsState =
                     context.getOperatorStateStore()
                             .getListState(
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java
index c3d31c83f..69dd42790 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java
@@ -24,8 +24,8 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -96,8 +96,8 @@ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
         private final String inputCol;
         private final boolean withMean;
         private final boolean withStd;
-        private DenseVector mean;
-        private DenseVector scale;
+        private DenseIntDoubleVector mean;
+        private DenseIntDoubleVector scale;

         public PredictOutputFunction(
                 String broadcastModelKey, String inputCol, boolean withMean, boolean withStd) {
@@ -114,7 +114,7 @@ public Row map(Row dataPoint) {
                         (StandardScalerModelData)
                                 getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                 mean = modelData.mean;
-                DenseVector std = modelData.std;
+                DenseIntDoubleVector std = modelData.std;

                 if (withStd) {
                     scale = std;
@@ -125,10 +125,10 @@ public Row map(Row dataPoint) {
                 }
             }

-            Vector outputVec = ((Vector) (dataPoint.getField(inputCol))).clone();
+            IntDoubleVector outputVec = ((IntDoubleVector) (dataPoint.getField(inputCol))).clone();
             if (withMean) {
                 outputVec = outputVec.toDense();
-                BLAS.axpy(-1, mean, (DenseVector) outputVec);
+                BLAS.axpy(-1, mean, (DenseIntDoubleVector) outputVec);
             }
             if (withStd) {
                 BLAS.hDot(scale, outputVec);
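BuildModelOperator above folds per-partition (sum, squaredSum, count) triples into one model. Recovering mean and standard deviation from such sums is plain algebra; a sketch of it follows (the exact bias correction the library uses is not visible in this diff, so the sample-variance rescaling below is an assumption):

```java
// Illustrative only: mean/std from a running sum and squared sum.
public class MeanStdFromSums {
    public static void main(String[] args) {
        long n = 3;                  // three rows with values 2, 3, 4
        double sum = 9.0;            // 2 + 3 + 4
        double squaredSum = 29.0;    // 4 + 9 + 16
        double mean = sum / n;
        // E[x^2] - mean^2 is the population variance; rescale by n / (n - 1)
        // to get the sample variance.
        double variance = (squaredSum / n - mean * mean) * n / (n - 1);
        System.out.println(mean + " " + Math.sqrt(variance)); // 3.0 1.0
    }
}
```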
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java
index fcaa2986d..01085b2dd 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java
@@ -27,8 +27,8 @@
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -47,9 +47,9 @@
  */
 public class StandardScalerModelData {
     /** Mean of each dimension. */
-    public DenseVector mean;
+    public DenseIntDoubleVector mean;
     /** Standard deviation of each dimension. */
-    public DenseVector std;
+    public DenseIntDoubleVector std;
     /** Model version. */
     public long version;
     /** Model timestamp. */
@@ -57,12 +57,12 @@ public class StandardScalerModelData {

     public StandardScalerModelData() {}

-    public StandardScalerModelData(DenseVector mean, DenseVector std) {
+    public StandardScalerModelData(DenseIntDoubleVector mean, DenseIntDoubleVector std) {
         this(mean, std, 0, Long.MAX_VALUE);
     }

     public StandardScalerModelData(
-            DenseVector mean, DenseVector std, long version, long timestamp) {
+            DenseIntDoubleVector mean, DenseIntDoubleVector std, long version, long timestamp) {
         this.mean = mean;
         this.std = std;
         this.version = version;
@@ -93,7 +93,8 @@ public static DataStream getModelDataStream(Table model
     /** Data encoder for the {@link StandardScalerModel} model data. */
     public static class ModelDataEncoder implements Encoder<StandardScalerModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();

         @Override
         public void encode(StandardScalerModelData modelData, OutputStream outputStream)
@@ -114,7 +115,8 @@ public static class ModelDataDecoder extends SimpleStreamFormat<StandardScalerModelData> {
         public Reader<StandardScalerModelData> createReader(
                 Configuration configuration, FSDataInputStream inputStream) {
             return new Reader<StandardScalerModelData>() {
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();

                 @Override
                 public StandardScalerModelData read() throws IOException {
@@ -122,8 +124,8 @@ public StandardScalerModelData read() throws IOException {
                             new DataInputViewStreamWrapper(inputStream);
                     try {
-                        DenseVector mean = serializer.deserialize(inputViewStreamWrapper);
-                        DenseVector std = serializer.deserialize(inputViewStreamWrapper);
+                        DenseIntDoubleVector mean = serializer.deserialize(inputViewStreamWrapper);
+                        DenseIntDoubleVector std = serializer.deserialize(inputViewStreamWrapper);
                         long version = LongSerializer.INSTANCE.deserialize(inputViewStreamWrapper);
                         long timestamp =
                                 LongSerializer.INSTANCE.deserialize(inputViewStreamWrapper);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java
index f5acf7a05..a0637a157 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java
@@ -24,7 +24,7 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.util.VectorUtils;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
@@ -127,14 +127,14 @@ public Row map(Row row) {
             if (indices.length == 0) {
                 return Row.join(row, Row.of(Vectors.dense()));
             } else {
-                Vector inputVec = ((Vector) row.getField(inputCol));
+                IntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol));
                 Preconditions.checkArgument(
                         inputVec.size() > indices[indices.length - 1],
                         "Input %s features, but UnivariateFeatureSelector is "
                                 + "expecting at least %s features as input.",
                         inputVec.size(),
                         indices[indices.length - 1] + 1);
-                Vector outputVec = VectorUtils.selectByIndices(inputVec, indices);
+                IntDoubleVector outputVec = VectorUtils.selectByIndices(inputVec, indices);
                 return Row.join(row, Row.of(outputVec));
             }
         }
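Both selector models in this patch delegate to VectorUtils.selectByIndices. Conceptually, for a dense input this is a gather by sorted indices; a sketch with plain arrays (the helper name mirrors the real one, but the body here is only illustrative):

```java
// Illustrative only: gathering selected feature positions into a new array.
public class SelectByIndicesSketch {
    /** Assumes indices are sorted and within range of the input. */
    static double[] selectByIndices(double[] input, int[] indices) {
        double[] out = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            out[i] = input[indices[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] out = selectByIndices(new double[] {5.0, 6.0, 7.0, 8.0}, new int[] {1, 3});
        System.out.println(out[0] + ", " + out[1]); // 6.0, 8.0
    }
}
```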
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
index 621415f07..e7b83a550 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
@@ -26,9 +26,10 @@
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -80,8 +81,8 @@ public VarianceThresholdSelectorModel fit(Table... inputs) {
                         new VarianceThresholdSelectorAggregator(getVarianceThreshold()),
                         Types.TUPLE(
                                 Types.LONG,
-                                DenseVectorTypeInfo.INSTANCE,
-                                DenseVectorTypeInfo.INSTANCE),
+                                DenseIntDoubleVectorTypeInfo.INSTANCE,
+                                DenseIntDoubleVectorTypeInfo.INSTANCE),
                         TypeInformation.of(VarianceThresholdSelectorModelData.class));

         VarianceThresholdSelectorModel model =
@@ -97,7 +98,7 @@ public VarianceThresholdSelectorModel fit(Table... inputs) {
     private static class VarianceThresholdSelectorAggregator
             implements AggregateFunction<
                     Vector,
-                    Tuple3<Long, DenseVector, DenseVector>,
+                    Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector>,
                     VarianceThresholdSelectorModelData> {

         private final double varianceThreshold;
@@ -107,31 +108,36 @@ public VarianceThresholdSelectorAggregator(double varianceThreshold) {
         }

         @Override
-        public Tuple3<Long, DenseVector, DenseVector> createAccumulator() {
-            return Tuple3.of(0L, new DenseVector(new double[0]), new DenseVector(new double[0]));
+        public Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> createAccumulator() {
+            return Tuple3.of(
+                    0L,
+                    new DenseIntDoubleVector(new double[0]),
+                    new DenseIntDoubleVector(new double[0]));
         }

         @Override
-        public Tuple3<Long, DenseVector, DenseVector> add(
-                Vector vector, Tuple3<Long, DenseVector, DenseVector> numAndSums) {
+        public Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> add(
+                Vector vector,
+                Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> numAndSums) {
+            IntDoubleVector intDoubleVector = (IntDoubleVector) vector;
             if (numAndSums.f0 == 0) {
-                numAndSums.f1 = new DenseVector(vector.size());
-                numAndSums.f2 = new DenseVector(vector.size());
+                numAndSums.f1 = new DenseIntDoubleVector(intDoubleVector.size());
+                numAndSums.f2 = new DenseIntDoubleVector(intDoubleVector.size());
             }
             numAndSums.f0 += 1L;
-            BLAS.axpy(1.0, vector, numAndSums.f1);
-            for (int i = 0; i < vector.size(); i++) {
-                numAndSums.f2.values[i] += vector.get(i) * vector.get(i);
+            BLAS.axpy(1.0, intDoubleVector, numAndSums.f1);
+            for (int i = 0; i < intDoubleVector.size(); i++) {
+                numAndSums.f2.values[i] += intDoubleVector.get(i) * intDoubleVector.get(i);
             }
             return numAndSums;
         }

         @Override
         public VarianceThresholdSelectorModelData getResult(
-                Tuple3<Long, DenseVector, DenseVector> numAndSums) {
+                Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> numAndSums) {
             long numRows = numAndSums.f0;
-            DenseVector sumVector = numAndSums.f1;
-            DenseVector squareSumVector = numAndSums.f2;
+            DenseIntDoubleVector sumVector = numAndSums.f1;
+            DenseIntDoubleVector squareSumVector = numAndSums.f2;
             Preconditions.checkState(numRows > 0, "The training set is empty.");

             int[] indices =
@@ -149,9 +155,9 @@ public VarianceThresholdSelectorModelData getResult(
         }

         @Override
-        public Tuple3<Long, DenseVector, DenseVector> merge(
-                Tuple3<Long, DenseVector, DenseVector> numAndSums1,
-                Tuple3<Long, DenseVector, DenseVector> acc) {
+        public Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> merge(
+                Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> numAndSums1,
+                Tuple3<Long, DenseIntDoubleVector, DenseIntDoubleVector> acc) {
             if (numAndSums1.f0 == 0) {
                 return acc;
             }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
index f042c9f7d..6654547ce 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
@@ -24,7 +24,7 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.util.VectorUtils;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
@@ -152,7 +152,7 @@ public Row map(Row row) {
                         .toArray();
             }

-            Vector inputVec = ((Vector) row.getField(inputCol));
+            IntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol));
             Preconditions.checkArgument(
                     inputVec.size() == expectedNumOfFeatures,
                     "%s has %s features, but VarianceThresholdSelector is expecting %s features as input.",
@@ -162,7 +162,7 @@ public Row map(Row row) {
             if (indices.length == 0) {
                 return Row.join(row, Row.of(Vectors.dense()));
             } else {
-                Vector outputVec = VectorUtils.selectByIndices(inputVec, indices);
+                IntDoubleVector outputVec = VectorUtils.selectByIndices(inputVec, indices);
                 return Row.join(row, Row.of(outputVec));
             }
         }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
index e951f80e9..159cf8433 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -24,9 +24,9 @@
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
@@ -112,7 +112,7 @@ public void flatMap(Row value, Collector<Row> out) {
                 Tuple2<Integer, Integer> vectorSizeAndNnz = computeVectorSizeAndNnz(value);
                 int vectorSize = vectorSizeAndNnz.f0;
                 int nnz = vectorSizeAndNnz.f1;
-                Vector assembledVec =
+                IntDoubleVector assembledVec =
                         nnz * RATIO > vectorSize
                                 ? assembleDense(inputCols, value, vectorSize)
                                 : assembleSparse(inputCols, value, vectorSize, nnz);
@@ -139,16 +139,16 @@ private Tuple2<Integer, Integer> computeVectorSizeAndNnz(Row value) {
                     }
                     vectorSize += 1;
                     nnz += 1;
-                } else if (object instanceof SparseVector) {
-                    int localSize = ((SparseVector) object).size();
+                } else if (object instanceof SparseIntDoubleVector) {
+                    int localSize = ((SparseIntDoubleVector) object).size();
                     checkSize(inputSizes[i], localSize);
-                    nnz += ((SparseVector) object).indices.length;
+                    nnz += ((SparseIntDoubleVector) object).indices.length;
                     vectorSize += localSize;
-                } else if (object instanceof DenseVector) {
-                    int localSize = ((DenseVector) object).size();
+                } else if (object instanceof DenseIntDoubleVector) {
+                    int localSize = ((DenseIntDoubleVector) object).size();
                     checkSize(inputSizes[i], localSize);
                     vectorSize += localSize;
-                    nnz += ((DenseVector) object).size();
+                    nnz += ((DenseIntDoubleVector) object).size();
                 } else {
                     throw new IllegalArgumentException(
                             String.format(
@@ -160,7 +160,7 @@ private Tuple2<Integer, Integer> computeVectorSizeAndNnz(Row value) {
                     nnz += inputSizes[i];
                     if (keepInvalid) {
                         if (inputSizes[i] > 1) {
-                            DenseVector tmpVec = new DenseVector(inputSizes[i]);
+                            DenseIntDoubleVector tmpVec = new DenseIntDoubleVector(inputSizes[i]);
                             for (int j = 0; j < inputSizes[i]; ++j) {
                                 tmpVec.values[j] = Double.NaN;
                             }
@@ -205,7 +205,7 @@ public Map<Param<?>, Object> getParamMap() {
     }

     /** Assembles the input columns into a dense vector. */
-    private static Vector assembleDense(String[] inputCols, Row inputRow, int vectorSize) {
+    private static IntDoubleVector assembleDense(String[] inputCols, Row inputRow, int vectorSize) {
         double[] values = new double[vectorSize];
         int currentOffset = 0;
@@ -213,15 +213,15 @@ private static Vector assembleDense(String[] inputCols, Row inputRow, int vector
             Object object = inputRow.getField(inputCol);
             if (object instanceof Number) {
                 values[currentOffset++] = ((Number) object).doubleValue();
-            } else if (object instanceof SparseVector) {
-                SparseVector sparseVector = (SparseVector) object;
+            } else if (object instanceof SparseIntDoubleVector) {
+                SparseIntDoubleVector sparseVector = (SparseIntDoubleVector) object;
                 for (int i = 0; i < sparseVector.indices.length; i++) {
                     values[currentOffset + sparseVector.indices[i]] = sparseVector.values[i];
                 }
                 currentOffset += sparseVector.size();
             } else {
-                DenseVector denseVector = (DenseVector) object;
+                DenseIntDoubleVector denseVector = (DenseIntDoubleVector) object;
                 System.arraycopy(denseVector.values, 0, values, currentOffset,
                         denseVector.size());
                 currentOffset += denseVector.size();
@@ -231,7 +231,7 @@ private static Vector assembleDense(String[] inputCols, Row inputRow, int vector
     }

     /** Assembles the input columns into a sparse vector. */
-    private static Vector assembleSparse(
+    private static IntDoubleVector assembleSparse(
             String[] inputCols, Row inputRow, int vectorSize, int nnz) {
         int[] indices = new int[nnz];
         double[] values = new double[nnz];
@@ -246,8 +246,8 @@ private static Vector assembleSparse(
                 values[currentOffset] = ((Number) object).doubleValue();
                 currentOffset++;
                 currentIndex++;
-            } else if (object instanceof SparseVector) {
-                SparseVector sparseVector = (SparseVector) object;
+            } else if (object instanceof SparseIntDoubleVector) {
+                SparseIntDoubleVector sparseVector = (SparseIntDoubleVector) object;
                 for (int i = 0; i < sparseVector.indices.length; i++) {
                     indices[currentOffset + i] = sparseVector.indices[i] + currentIndex;
                 }
@@ -256,7 +256,7 @@ private static Vector assembleSparse(
                 currentIndex += sparseVector.size();
                 currentOffset += sparseVector.indices.length;
             } else {
-                DenseVector denseVector = (DenseVector) object;
+                DenseIntDoubleVector denseVector = (DenseIntDoubleVector) object;
                 for (int i = 0; i < denseVector.size(); ++i) {
                     indices[currentOffset + i] = i + currentIndex;
                 }
@@ -266,6 +266,6 @@ private static Vector assembleSparse(
                 currentOffset += denseVector.size();
             }
         }
-        return new SparseVector(vectorSize, indices, values);
+        return new SparseIntDoubleVector(vectorSize, indices, values);
     }
 }
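VarianceThresholdSelector's aggregator keeps (count, sum, squared sum) and derives per-feature variance only at getResult time. A sketch of that selection rule on plain arrays (population variance is assumed here; the library's exact estimator is not visible in the hunks above):

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Illustrative only: which feature indices survive the variance threshold.
public class VarianceThresholdSketch {
    static int[] selectedIndices(long n, double[] sum, double[] squareSum, double threshold) {
        return IntStream.range(0, sum.length)
                .filter(
                        i -> {
                            double mean = sum[i] / n;
                            double variance = squareSum[i] / n - mean * mean;
                            return variance > threshold;
                        })
                .toArray();
    }

    public static void main(String[] args) {
        // Two rows: feature 0 is constant {2, 2}, feature 1 varies {1, 2}.
        int[] kept =
                selectedIndices(2, new double[] {4.0, 3.0}, new double[] {8.0, 5.0}, 0.1);
        System.out.println(Arrays.toString(kept)); // [1]
    }
}
```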
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
index 0bf4c7e57..3486adc26 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
@@ -27,7 +27,7 @@
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -185,14 +185,14 @@ public void endInput() {
         public void processElement(StreamRecord<Row> element) {
             if (doublesByColumn == null) {
                 // First record.
-                Vector vector = (Vector) element.getValue().getField(inputCol);
+                IntDoubleVector vector = (IntDoubleVector) element.getValue().getField(inputCol);
                 doublesByColumn = new HashSet[vector.size()];
                 for (int i = 0; i < doublesByColumn.length; i++) {
                     doublesByColumn[i] = new HashSet<>();
                 }
             }

-            Vector vector = (Vector) element.getValue().getField(inputCol);
+            IntDoubleVector vector = (IntDoubleVector) element.getValue().getField(inputCol);
             Preconditions.checkState(
                     vector.size() == doublesByColumn.length,
                     "The size of the all input vectors should be the same.");
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
index 3becd230d..f5ed91819 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
@@ -24,7 +24,7 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -149,7 +149,7 @@ public void flatMap(Row input, Collector<Row> out) {
                 categoryMaps = modelData.categoryMaps;
             }

-            Vector outputVector = ((Vector) input.getField(inputCol)).clone();
+            IntDoubleVector outputVector = ((IntDoubleVector) input.getField(inputCol)).clone();
             for (Map.Entry<Integer, Map<Double, Integer>> entry : categoryMaps.entrySet()) {
                 int columnId = entry.getKey();
                 Map<Double, Integer> mapping = entry.getValue();
@@ -158,7 +158,7 @@ public void flatMap(Row input, Collector<Row> out) {
                 if (categoricalFeature == null) {
                     return;
                 } else {
-                    outputVector.set(columnId, categoricalFeature);
+                    outputVector.set(columnId, (double) categoricalFeature);
                 }
             }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java
index 2abca8910..417e0af1f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java
@@ -22,9 +22,9 @@
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -108,8 +108,8 @@ public VectorSliceFunction(Integer[] indices, String inputCol) {

         @Override
         public Row map(Row row) throws Exception {
-            Vector inputVec = row.getFieldAs(inputCol);
-            Vector outputVec;
+            IntDoubleVector inputVec = row.getFieldAs(inputCol);
+            IntDoubleVector outputVec;
             if (maxIndex >= inputVec.size()) {
                 throw new IllegalArgumentException(
                         "Index value "
@@ -117,15 +117,15 @@ public Row map(Row row) throws Exception {
                                 + maxIndex
                                 + " is greater than vector size:"
                                 + inputVec.size());
             }
-            if (inputVec instanceof DenseVector) {
+            if (inputVec instanceof DenseIntDoubleVector) {
                 double[] values = new double[indices.length];
                 for (int i = 0; i < indices.length; ++i) {
-                    values[i] = ((DenseVector) inputVec).values[indices[i]];
+                    values[i] = ((DenseIntDoubleVector) inputVec).values[indices[i]];
                 }
-                outputVec = new DenseVector(values);
+                outputVec = new DenseIntDoubleVector(values);
             } else {
                 int nnz = 0;
-                SparseVector vec = (SparseVector) inputVec;
+                SparseIntDoubleVector vec = (SparseIntDoubleVector) inputVec;
                 int[] outputIndices = new int[indices.length];
                 double[] outputValues = new double[indices.length];
                 for (int i = 0; i < indices.length; i++) {
@@ -140,7 +140,7 @@ public Row map(Row row) throws Exception {
                     outputIndices = Arrays.copyOf(outputIndices, nnz);
                     outputValues = Arrays.copyOf(outputValues, nnz);
                 }
-                outputVec = new SparseVector(indices.length, outputIndices, outputValues);
+                outputVec = new SparseIntDoubleVector(indices.length, outputIndices, outputValues);
             }
             return Row.join(row, Row.of(outputVec));
         }
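VectorSlicer's sparse branch keeps only the requested indices and re-numbers them within the sliced vector. A self-contained sketch of that logic (the binary search over the sorted sparse indices is an assumed detail; the hunks above do not show the lookup itself):

```java
import java.util.Arrays;

// Illustrative only: slicing a sparse vector of size 6 down to indices {1, 3, 5}.
public class SparseSliceSketch {
    public static void main(String[] args) {
        int[] vecIndices = {0, 3, 5};           // non-zero positions
        double[] vecValues = {1.0, 2.0, 3.0};   // their values
        int[] slice = {1, 3, 5};                // positions to keep

        int nnz = 0;
        int[] outIndices = new int[slice.length];
        double[] outValues = new double[slice.length];
        for (int i = 0; i < slice.length; i++) {
            int pos = Arrays.binarySearch(vecIndices, slice[i]);
            if (pos >= 0) {
                outIndices[nnz] = i; // position within the sliced vector
                outValues[nnz] = vecValues[pos];
                nnz++;
            }
        }
        outIndices = Arrays.copyOf(outIndices, nnz);
        outValues = Arrays.copyOf(outValues, nnz);
        // Sliced vector of size 3: indices [1, 2], values [2.0, 3.0].
        System.out.println(Arrays.toString(outIndices) + " " + Arrays.toString(outValues));
    }
}
```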
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
index d977eb0eb..de48b817d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
@@ -25,8 +25,8 @@
 import org.apache.flink.ml.common.lossfunc.LeastSquareLoss;
 import org.apache.flink.ml.common.optimizer.Optimizer;
 import org.apache.flink.ml.common.optimizer.SGD;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -73,15 +73,15 @@ public LinearRegressionModel fit(Table... inputs) {
                             double label =
                                     ((Number) dataPoint.getField(getLabelCol()))
                                             .doubleValue();
-                            DenseVector features =
-                                    ((Vector) dataPoint.getField(getFeaturesCol()))
+                            DenseIntDoubleVector features =
+                                    ((IntDoubleVector) dataPoint.getField(getFeaturesCol()))
                                             .toDense();
                             return new LabeledPointWithWeight(features, label, weight);
                         });
-        DataStream<DenseVector> initModelData =
+        DataStream<DenseIntDoubleVector> initModelData =
                 DataStreamUtils.reduce(
-                                trainData.map(x -> x.getFeatures().size()),
+                                trainData.map(x -> (Integer) x.features.size()),
                                 (ReduceFunction<Integer>)
                                         (t0, t1) -> {
                                             Preconditions.checkState(
@@ -89,7 +89,7 @@ public LinearRegressionModel fit(Table... inputs) {
                                                     "The training data should all have same dimensions.");
                                             return t0;
                                         })
-                        .map(DenseVector::new);
+                        .map(DenseIntDoubleVector::new);

         Optimizer optimizer =
                 new SGD(
@@ -99,7 +99,7 @@ public LinearRegressionModel fit(Table... inputs) {
                         getTol(),
                         getReg(),
                         getElasticNet());
-        DataStream<DenseVector> rawModelData =
+        DataStream<DenseIntDoubleVector> rawModelData =
                 optimizer.optimize(initModelData, trainData, LeastSquareLoss.INSTANCE);

         DataStream<LinearRegressionModelData> modelData =
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java
index dec2f4365..2e5290b9b 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java
@@ -25,8 +25,8 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -127,7 +127,7 @@ private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
         private final String featuresCol;

-        private DenseVector coefficient;
+        private DenseIntDoubleVector coefficient;

         public PredictLabelFunction(String broadcastModelKey, String featuresCol) {
             this.broadcastModelKey = broadcastModelKey;
@@ -142,7 +142,8 @@ public Row map(Row dataPoint) {
                                 getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                 coefficient = modelData.coefficient;
             }
-            DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense();
+            DenseIntDoubleVector features =
+                    ((IntDoubleVector) dataPoint.getField(featuresCol)).toDense();
             Row predictionResult = predictOneDataPoint(features, coefficient);
             return Row.join(dataPoint, predictionResult);
         }
@@ -155,7 +156,8 @@ public Row map(Row dataPoint) {
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
+    private static Row predictOneDataPoint(
+            DenseIntDoubleVector feature, DenseIntDoubleVector coefficient) {
         return Row.of(BLAS.dot(feature, coefficient));
     }
 }
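predictOneDataPoint in the hunk above reduces to a dot product between the feature vector and the fitted coefficients. The equivalent in plain Java:

```java
// Illustrative only: linear-regression prediction as an explicit dot product.
public class DotPredictSketch {
    static double predict(double[] features, double[] coefficient) {
        double sum = 0.0;
        for (int i = 0; i < features.length; i++) {
            sum += features[i] * coefficient[i]; // mirrors BLAS.dot(feature, coefficient)
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(predict(new double[] {1.0, 2.0}, new double[] {0.5, 0.25})); // 1.0
    }
}
```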
*/ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(LinearRegressionModelData modelData, OutputStream outputStream) @@ -84,12 +85,13 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration configuration, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public LinearRegressionModelData read() throws IOException { try { - DenseVector coefficient = + DenseIntDoubleVector coefficient = serializer.deserialize(new DataInputViewStreamWrapper(inputStream)); return new LinearRegressionModelData(coefficient); } catch (EOFException e) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java index d928774ac..64b3b37b0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java @@ -27,8 +27,8 @@ import org.apache.flink.ml.api.AlgoOperator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.param.HasFlatten; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -93,16 +93,16 @@ public Table[] transform(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream> inputData = + DataStream> inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction>) + (MapFunction>) row -> { Number number = (Number) row.getField(labelCol); Preconditions.checkNotNull( number, "Input data must contain label value."); return new Tuple2<>( - ((Vector) row.getField(featuresCol)), + ((IntDoubleVector) row.getField(featuresCol)), number.doubleValue()); }, Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE)); @@ -125,7 +125,7 @@ public Table[] transform(Table... 
inputs) { @SuppressWarnings("unchecked") private static class ANOVAAggregator implements AggregateFunction< - Tuple2<Vector, Double>, + Tuple2<IntDoubleVector, Double>, Tuple3>>[], List<Row>> { @Override @@ -135,9 +135,9 @@ public Tuple3>>[] createAcc @Override public Tuple3>>[] add( - Tuple2<Vector, Double> featuresAndLabel, + Tuple2<IntDoubleVector, Double> featuresAndLabel, Tuple3>>[] acc) { - Vector features = featuresAndLabel.f0; + IntDoubleVector features = featuresAndLabel.f0; double label = featuresAndLabel.f1; int numOfFeatures = features.size(); if (acc.length == 0) { @@ -257,15 +257,19 @@ private Table convertToTable( return tEnv.fromDataStream(output) .as("featureIndex", "pValue", "degreeOfFreedom", "fValue"); } else { - DataStream<Tuple3<DenseVector, long[], DenseVector>> output = + DataStream<Tuple3<DenseIntDoubleVector, long[], DenseIntDoubleVector>> output = datastream.map( - new MapFunction<List<Row>, Tuple3<DenseVector, long[], DenseVector>>() { + new MapFunction< + List<Row>, + Tuple3<DenseIntDoubleVector, long[], DenseIntDoubleVector>>() { @Override - public Tuple3<DenseVector, long[], DenseVector> map( - List<Row> rows) { + public Tuple3<DenseIntDoubleVector, long[], DenseIntDoubleVector> + map(List<Row> rows) { int numOfFeatures = rows.size(); - DenseVector pValues = new DenseVector(numOfFeatures); - DenseVector fValues = new DenseVector(numOfFeatures); + DenseIntDoubleVector pValues = + new DenseIntDoubleVector(numOfFeatures); + DenseIntDoubleVector fValues = + new DenseIntDoubleVector(numOfFeatures); long[] degrees = new long[numOfFeatures]; for (int i = 0; i < numOfFeatures; i++) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java index 01cac587d..c2f6dad97 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java @@ -32,9 +32,9 @@ import org.apache.flink.ml.api.AlgoOperator; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.param.HasFlatten; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -191,9 +191,9 @@ public Table[] transform(Table...
inputs) { outputTypeInfo = new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, + DenseIntDoubleVectorTypeInfo.INSTANCE, Types.PRIMITIVE_ARRAY(Types.INT), - DenseVectorTypeInfo.INSTANCE + DenseIntDoubleVectorTypeInfo.INSTANCE }, new String[] {"pValues", "degreesOfFreedom", "statistics"}); } @@ -236,7 +236,7 @@ public void flatMap(Row row, Collector<Tuple3<Integer, Double, Double>> collecto Double label = ((Number) row.getFieldAs(labelCol)).doubleValue(); - Vector features = row.getFieldAs(featuresCol); + IntDoubleVector features = row.getFieldAs(featuresCol); for (int i = 0; i < features.size(); i++) { collector.collect(Tuple3.of(i, features.get(i), label)); } @@ -650,8 +650,8 @@ private void endInputWithFlatten() { private void endInputWithoutFlatten() { int size = index2Statistic.size(); - Vector pValueScaledVector = new DenseVector(size); - Vector statisticScaledVector = new DenseVector(size); + IntDoubleVector pValueScaledVector = new DenseIntDoubleVector(size); + IntDoubleVector statisticScaledVector = new DenseIntDoubleVector(size); int[] dofArray = new int[size]; for (Map.Entry> entry : index2Statistic.entrySet()) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java index aaba3d46b..1924be69b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java @@ -33,8 +33,8 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.param.HasFlatten; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -99,24 +99,24 @@ public Table[] transform(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream<Tuple2<Vector, Double>> inputData = + DataStream<Tuple2<IntDoubleVector, Double>> inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction<Row, Tuple2<Vector, Double>>) + (MapFunction<Row, Tuple2<IntDoubleVector, Double>>) row -> { Number number = (Number) row.getField(labelCol); Preconditions.checkNotNull( number, "Input data must contain label value."); return new Tuple2<>( - ((Vector) row.getField(featuresCol)), + ((IntDoubleVector) row.getField(featuresCol)), number.doubleValue()); }) .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE)); - DataStream<Tuple5<Long, Double, Double, DenseVector, DenseVector>> summaries = - DataStreamUtils.aggregate(inputData, new SummaryAggregator()); + DataStream<Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector>> + summaries = DataStreamUtils.aggregate(inputData, new SummaryAggregator()); - DataStream<DenseVector> covarianceInEachPartition = + DataStream<DenseIntDoubleVector> covarianceInEachPartition = BroadcastUtils.withBroadcastStream( Collections.singletonList(inputData), Collections.singletonMap(broadcastSummaryKey, summaries), @@ -126,10 +126,10 @@ public Table[] transform(Table...
inputs) { input, new CalCovarianceOperator(broadcastSummaryKey)); }); - DataStream<DenseVector> reducedCovariance = + DataStream<DenseIntDoubleVector> reducedCovariance = DataStreamUtils.reduce( covarianceInEachPartition, - (ReduceFunction<DenseVector>) + (ReduceFunction<DenseIntDoubleVector>) (sums1, sums2) -> { BLAS.axpy(1.0, sums1, sums2); return sums2; @@ -156,24 +156,30 @@ private Table convertToTable( return tEnv.fromDataStream(dataStream) .as("featureIndex", "pValue", "degreeOfFreedom", "fValue"); } else { - DataStream<Tuple3<DenseVector, long[], DenseVector>> output = + DataStream<Tuple3<DenseIntDoubleVector, long[], DenseIntDoubleVector>> output = DataStreamUtils.mapPartition( dataStream, new MapPartitionFunction< Tuple4<Integer, Double, Long, Double>, - Tuple3<DenseVector, long[], DenseVector>>() { + Tuple3<DenseIntDoubleVector, long[], DenseIntDoubleVector>>() { @Override public void mapPartition( Iterable<Tuple4<Integer, Double, Long, Double>> iterable, - Collector<Tuple3<DenseVector, long[], DenseVector>> + Collector< + Tuple3< + DenseIntDoubleVector, + long[], + DenseIntDoubleVector>> collector) { List<Tuple4<Integer, Double, Long, Double>> rows = IteratorUtils.toList(iterable.iterator()); int numOfFeatures = rows.size(); - DenseVector pValues = new DenseVector(numOfFeatures); + DenseIntDoubleVector pValues = + new DenseIntDoubleVector(numOfFeatures); long[] degrees = new long[numOfFeatures]; - DenseVector fValues = new DenseVector(numOfFeatures); + DenseIntDoubleVector fValues = + new DenseIntDoubleVector(numOfFeatures); for (int i = 0; i < numOfFeatures; i++) { Tuple4<Integer, Double, Long, Double> tuple = rows.get(i); @@ -190,7 +196,8 @@ public void mapPartition( /** Computes the covariance of each feature on each partition. */ private static class CalCovarianceOperator - extends RichMapPartitionFunction<Tuple2<Vector, Double>, DenseVector> { + extends RichMapPartitionFunction< + Tuple2<IntDoubleVector, Double>, DenseIntDoubleVector> { private final String broadcastKey; @@ -200,14 +207,15 @@ private CalCovarianceOperator(String broadcastKey) { @Override public void mapPartition( - Iterable<Tuple2<Vector, Double>> iterable, Collector<DenseVector> collector) { - Tuple5<Long, Double, Double, DenseVector, DenseVector> summaries = - (Tuple5<Long, Double, Double, DenseVector, DenseVector>) + Iterable<Tuple2<IntDoubleVector, Double>> iterable, + Collector<DenseIntDoubleVector> collector) { + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> summaries = + (Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector>) getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); int expectedNumOfFeatures = summaries.f3.size(); - DenseVector sumVector = new DenseVector(expectedNumOfFeatures); - for (Tuple2<Vector, Double> featuresAndLabel : iterable) { + DenseIntDoubleVector sumVector = new DenseIntDoubleVector(expectedNumOfFeatures); + for (Tuple2<IntDoubleVector, Double> featuresAndLabel : iterable) { Preconditions.checkArgument( featuresAndLabel.f0.size() == expectedNumOfFeatures, "Input %s features, but FValueTest is expecting %s features.", @@ -229,10 +237,11 @@ public void mapPartition( /** Computes the p-value, fValues and the number of degrees of freedom of input features. */ private static class CalFValueOperator - extends RichMapPartitionFunction<DenseVector, Tuple4<Integer, Double, Long, Double>> { + extends RichMapPartitionFunction< + DenseIntDoubleVector, Tuple4<Integer, Double, Long, Double>> { private final String broadcastKey; - private DenseVector sumVector; + private DenseIntDoubleVector sumVector; private CalFValueOperator(String broadcastKey) { this.broadcastKey = broadcastKey; @@ -240,10 +249,10 @@ private CalFValueOperator(String broadcastKey) { @Override public void mapPartition( - Iterable<DenseVector> iterable, + Iterable<DenseIntDoubleVector> iterable, Collector<Tuple4<Integer, Double, Long, Double>> collector) { - Tuple5<Long, Double, Double, DenseVector, DenseVector> summaries = - (Tuple5<Long, Double, Double, DenseVector, DenseVector>) + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> summaries = + (Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector>) getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); int expectedNumOfFeatures = summaries.f4.size(); @@ -273,26 +282,31 @@ public void mapPartition( /** Computes the num, mean, and standard deviation of the input label and features.
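* The Tuple5 accumulator tracks the row count, the label's running sum and squared sum, and two per-feature running-sum vectors from which getResult derives the means and sample standard deviations.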
*/ private static class SummaryAggregator implements AggregateFunction< - Tuple2<Vector, Double>, - Tuple5<Long, Double, Double, DenseVector, DenseVector>, - Tuple5<Long, Double, Double, DenseVector, DenseVector>> { + Tuple2<IntDoubleVector, Double>, + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector>, + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector>> { @Override - public Tuple5<Long, Double, Double, DenseVector, DenseVector> createAccumulator() { + public Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> + createAccumulator() { return Tuple5.of( - 0L, 0.0, 0.0, new DenseVector(new double[0]), new DenseVector(new double[0])); + 0L, + 0.0, + 0.0, + new DenseIntDoubleVector(new double[0]), + new DenseIntDoubleVector(new double[0])); } @Override - public Tuple5<Long, Double, Double, DenseVector, DenseVector> add( - Tuple2<Vector, Double> featuresAndLabel, - Tuple5<Long, Double, Double, DenseVector, DenseVector> summary) { - Vector features = featuresAndLabel.f0; + public Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> add( + Tuple2<IntDoubleVector, Double> featuresAndLabel, + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> summary) { + IntDoubleVector features = featuresAndLabel.f0; double label = featuresAndLabel.f1; if (summary.f0 == 0) { - summary.f3 = new DenseVector(features.size()); - summary.f4 = new DenseVector(features.size()); + summary.f3 = new DenseIntDoubleVector(features.size()); + summary.f4 = new DenseIntDoubleVector(features.size()); } summary.f0 += 1L; summary.f1 += label; @@ -306,14 +320,14 @@ public Tuple5<Long, Double, Double, DenseVector, DenseVector> add( } @Override - public Tuple5<Long, Double, Double, DenseVector, DenseVector> getResult( - Tuple5<Long, Double, Double, DenseVector, DenseVector> summary) { + public Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> getResult( + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> summary) { final long numRows = summary.f0; Preconditions.checkState(numRows > 0, "The training set is empty."); int numOfFeatures = summary.f3.size(); double labelMean = summary.f1 / numRows; - Tuple5<Long, Double, Double, DenseVector, DenseVector> result = + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> result = Tuple5.of( numRows, labelMean, @@ -321,8 +335,8 @@ public Tuple5<Long, Double, Double, DenseVector, DenseVector> getResult( (summary.f2 / numRows - labelMean * labelMean) * numRows / (numRows - 1)), - new DenseVector(numOfFeatures), - new DenseVector(numOfFeatures)); + new DenseIntDoubleVector(numOfFeatures), + new DenseIntDoubleVector(numOfFeatures)); for (int i = 0; i < summary.f3.size(); i++) { double mean = summary.f3.get(i) / numRows; result.f3.values[i] = mean; @@ -336,9 +350,9 @@ public Tuple5<Long, Double, Double, DenseVector, DenseVector> getResult( } @Override - public Tuple5<Long, Double, Double, DenseVector, DenseVector> merge( - Tuple5<Long, Double, Double, DenseVector, DenseVector> summary1, - Tuple5<Long, Double, Double, DenseVector, DenseVector> summary2) { + public Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> merge( + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> summary1, + Tuple5<Long, Double, Double, DenseIntDoubleVector, DenseIntDoubleVector> summary2) { if (summary1.f0 == 0) { return summary2; } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java index c5ddfc4dc..d0472a00d 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java @@ -19,8 +19,8 @@ package org.apache.flink.ml; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -60,10 +60,10 @@ public class FunctionsTest extends AbstractTestBase { private static final List<long[]> longArrays = Arrays.asList(new long[] {0, 0}, new long[] {0, 1}); - private static final List<DenseVector> denseVectors = + private static final List<DenseIntDoubleVector> denseVectors = Arrays.asList(Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 1.0)); - private static final List<SparseVector> sparseVectors = + private static final List<SparseIntDoubleVector> sparseVectors = Arrays.asList( Vectors.sparse(2, new int[0], new double[0]), Vectors.sparse(2, new int[] {1}, new double[] {1.0})); @@ -126,7 +126,7 @@ private void testArrayToVector(List array) { assertEquals(outputValues.size(), denseVectors.size()); for (int i = 0; i < denseVectors.size(); i++) { - DenseVector vector = outputValues.get(i).getFieldAs("vector"); +
DenseIntDoubleVector vector = outputValues.get(i).getFieldAs("vector"); assertEquals(denseVectors.get(i), vector); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java index 483dc34d8..ea25e0e2e 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java @@ -23,9 +23,9 @@ import org.apache.flink.ml.classification.knn.Knn; import org.apache.flink.ml.classification.knn.KnnModel; import org.apache.flink.ml.classification.knn.KnnModelData; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.DenseMatrix; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; @@ -96,7 +96,7 @@ public void before() { tEnv = StreamTableEnvironment.create(env); Schema schema = Schema.newBuilder() - .column("f0", DataTypes.of(DenseVector.class)) + .column("f0", DataTypes.of(DenseIntDoubleVector.class)) .column("f1", DataTypes.DOUBLE()) .build(); DataStream dataStream = env.fromCollection(trainRows); @@ -179,10 +179,10 @@ public void testInputTypeConversion() throws Exception { trainData = TestUtils.convertDataTypesToSparseInt(tEnv, trainData); predictData = TestUtils.convertDataTypesToSparseInt(tEnv, predictData); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class}, TestUtils.getColumnDataTypes(trainData)); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class}, TestUtils.getColumnDataTypes(predictData)); Knn knn = new Knn(); @@ -232,12 +232,12 @@ public void testGetModelData() throws Exception { KnnModelData data = new KnnModelData( (DenseMatrix) modelRows.get(0).getField(0), - (DenseVector) modelRows.get(0).getField(1), - (DenseVector) modelRows.get(0).getField(2)); + (DenseIntDoubleVector) modelRows.get(0).getField(1), + (DenseIntDoubleVector) modelRows.get(0).getField(2)); Assert.assertNotNull(data); assertEquals(2, data.packedFeatures.numRows()); - assertEquals(data.packedFeatures.numCols(), data.labels.size()); - assertEquals(data.featureNormSquares.size(), data.labels.size()); + assertEquals(data.packedFeatures.numCols(), data.labels.size().intValue()); + assertEquals(data.featureNormSquares.size().intValue(), data.labels.size().intValue()); } @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java index b2a20eb22..e4e7d5bf9 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java @@ -24,11 +24,11 @@ import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import 
org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -93,7 +93,9 @@ public void before() { trainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); } @@ -104,9 +106,11 @@ private void verifyPredictionResult( throws Exception { List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); for (Row predictionRow : predResult) { - DenseVector feature = ((Vector) predictionRow.getField(featuresCol)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.getField(featuresCol)).toDense(); double prediction = (Double) predictionRow.getField(predictionCol); - DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.getField(rawPredictionCol); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) < 0); @@ -196,7 +200,7 @@ public void testFitAndPredict() throws Exception { public void testInputTypeConversion() throws Exception { trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class}, TestUtils.getColumnDataTypes(trainDataTable)); LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); @@ -274,7 +278,9 @@ public void testMoreSubtaskThanData() throws Exception { trainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index f899c281e..739d41973 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -23,14 +23,15 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import 
org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.types.BasicType; import org.apache.flink.ml.servable.types.DataTypes; @@ -55,6 +56,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import static org.apache.flink.ml.util.TestUtils.saveAndLoadServable; import static org.junit.Assert.assertArrayEquals; @@ -104,6 +106,7 @@ public class LogisticRegressionTest extends AbstractTestBase { private static final double TOLERANCE = 1e-7; private Table binomialDataTable; + private Table binomialSparseDataTable; private Table multinomialDataTable; @@ -120,16 +123,43 @@ public void before() { binomialTrainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); + + List binomialSparseTrainData = + binomialTrainData.stream() + .map( + r -> { + DenseIntDoubleVector features = r.getFieldAs(0); + double label = r.getFieldAs(1); + double weight = r.getFieldAs(2); + return Row.of(features.toSparse(), label, weight); + }) + .collect(Collectors.toList()); + binomialSparseDataTable = + tEnv.fromDataStream( + env.fromCollection( + binomialSparseTrainData, + new RowTypeInfo( + new TypeInformation[] { + SparseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + multinomialDataTable = tEnv.fromDataStream( env.fromCollection( multinomialTrainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); binomialDataDataFrame = @@ -149,9 +179,11 @@ private void verifyPredictionResult( throws Exception { List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); for (Row predictionRow : predResult) { - DenseVector feature = ((Vector) predictionRow.getField(featuresCol)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.getField(featuresCol)).toDense(); double prediction = (double) predictionRow.getField(predictionCol); - DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.getField(rawPredictionCol); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) > 0.5); @@ -169,9 +201,11 @@ private void verifyPredictionResult( int rawPredictionColIndex = output.getIndex(rawPredictionCol); for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) { - DenseVector feature = ((Vector) predictionRow.get(featuresColIndex)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.get(featuresColIndex)).toDense(); double prediction = (double) predictionRow.get(predictionColIndex); - DenseVector rawPrediction = (DenseVector) predictionRow.get(rawPredictionColIndex); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) 
predictionRow.get(rawPredictionColIndex); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) > 0.5); @@ -261,7 +295,7 @@ public void testFitAndPredict() throws Exception { public void testInputTypeConversion() throws Exception { binomialDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, binomialDataTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class}, TestUtils.getColumnDataTypes(binomialDataTable)); LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); @@ -290,7 +324,7 @@ public void testSaveLoadAndPredict() throws Exception { tempFolder.newFolder().getAbsolutePath(), LogisticRegressionModel::load); assertEquals( - Arrays.asList("coefficient", "modelVersion"), + Arrays.asList("coefficient", "startIndex", "endIndex", "modelVersion"), model.getModelData()[0].getResolvedSchema().getColumnNames()); Table output = model.transform(binomialDataTable)[0]; verifyPredictionResult( @@ -305,7 +339,20 @@ public void testSaveLoadAndPredict() throws Exception { public void testGetModelData() throws Exception { LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); - List modelData = + List modelData = + IteratorUtils.toList( + LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + assertEquals(1, modelData.size()); + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); + } + + @Test + @SuppressWarnings("unchecked") + public void testGetModelDataFromSparseInput() throws Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialSparseDataTable); + List modelData = IteratorUtils.toList( LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); @@ -396,7 +443,9 @@ public void testMoreSubtaskThanData() throws Exception { binomialTrainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); @@ -427,7 +476,7 @@ private void checkRegularization(double reg, double elasticNet, double[] expecte .setReg(reg) .setElasticNet(elasticNet) .fit(binomialDataTable); - List modelData = + List modelData = IteratorUtils.toList( LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java new file mode 100644 index 000000000..02e3579b3 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionWithFtrl; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseLongDoubleVectorTypeInfo; +import org.apache.flink.ml.servable.api.DataFrame; +import org.apache.flink.ml.servable.types.BasicType; +import org.apache.flink.ml.servable.types.DataTypes; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.ByteArrayInputStream; +import java.io.SequenceInputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.ml.util.TestUtils.saveAndLoadServable; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link LogisticRegressionWithFtrl}. 
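Covers parameter handling, fit/transform on sparse long-indexed features, model data stored as range-partitioned segments, and save/load of both the model and its servable.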
*/ +public class LogisticRegressionWithFtrlTest { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private final double[] expectedCoefficient = new double[] {0.52, -1.21, -1.07, -0.85}; + private static final int MAX_ITER = 100; + private static final int NUM_SERVERS = 2; + private static final double TOLERANCE = 1e-7; + + private static final List trainRows = + Arrays.asList( + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 1}, new double[] {1, 2}), + 0., + 1.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {2, 3}), + 0., + 2.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 3}, new double[] {3, 4}), + 0., + 3.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {4, 4}), + 0., + 4.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 1}, new double[] {5, 4}), + 0., + 5.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {11, 3}), + 1., + 1.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 3}, new double[] {12, 4}), + 1., + 2.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 1}, new double[] {13, 2}), + 1., + 3.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 3}, new double[] {14, 4}), + 1., + 4.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {15, 4}), + 1., + 5.)); + + private static final List testRows = + Arrays.asList( + Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {1, 2}), 0., 1.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {2, 3}), 0., 2.), + Row.of(Vectors.sparse(4, new int[] {0, 3}, new double[] {3, 4}), 0., 3.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {4, 4}), 0., 4.), + Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {5, 4}), 0., 5.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {11, 3}), 1., 1.), + Row.of(Vectors.sparse(4, new int[] {0, 3}, new double[] {12, 4}), 1., 2.), + Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {13, 2}), 1., 3.), + Row.of(Vectors.sparse(4, new int[] {0, 3}, new double[] {14, 4}), 1., 4.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {15, 4}), 1., 5.)); + + private StreamTableEnvironment tEnv; + private StreamExecutionEnvironment env; + private Table trainTable; + private Table testTable; + private DataFrame testDataFrame; + + @Before + public void before() { + env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + + trainTable = + tEnv.fromDataStream( + env.fromCollection( + trainRows, + new RowTypeInfo( + new TypeInformation[] { + SparseLongDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + testTable = + tEnv.fromDataStream( + env.fromCollection( + testRows, + new RowTypeInfo( + new TypeInformation[] { + SparseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + + testDataFrame = + TestUtils.constructDataFrame( + new ArrayList<>(Arrays.asList("features", "label", "weight")), + new ArrayList<>( + Arrays.asList( + DataTypes.VECTOR(BasicType.DOUBLE), + DataTypes.DOUBLE, + DataTypes.DOUBLE)), + testRows); + } + + @Test + public void testParam() { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = new LogisticRegressionWithFtrl(); + assertEquals("features", logisticRegressionWithFtrl.getFeaturesCol()); + assertEquals("label", logisticRegressionWithFtrl.getLabelCol()); + 
assertNull(logisticRegressionWithFtrl.getWeightCol()); + assertEquals(20, logisticRegressionWithFtrl.getMaxIter()); + assertEquals(1e-6, logisticRegressionWithFtrl.getTol(), TOLERANCE); + assertEquals(32, logisticRegressionWithFtrl.getGlobalBatchSize()); + assertEquals(0, logisticRegressionWithFtrl.getReg(), TOLERANCE); + assertEquals(0, logisticRegressionWithFtrl.getElasticNet(), TOLERANCE); + assertEquals("auto", logisticRegressionWithFtrl.getMultiClass()); + assertEquals("prediction", logisticRegressionWithFtrl.getPredictionCol()); + assertEquals("rawPrediction", logisticRegressionWithFtrl.getRawPredictionCol()); + + assertEquals(0.1, logisticRegressionWithFtrl.getAlpha(), TOLERANCE); + assertEquals(0.1, logisticRegressionWithFtrl.getBeta(), TOLERANCE); + assertEquals(0L, logisticRegressionWithFtrl.getModelDim()); + assertEquals(1, logisticRegressionWithFtrl.getNumServers()); + + logisticRegressionWithFtrl + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setMaxIter(1000) + .setTol(0.001) + .setGlobalBatchSize(1000) + .setReg(0.1) + .setElasticNet(0.5) + .setMultiClass("binomial") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol") + .setAlpha(0.2) + .setBeta(0.2) + .setModelDim(10000000L) + .setNumServers(4); + assertEquals("test_features", logisticRegressionWithFtrl.getFeaturesCol()); + assertEquals("test_label", logisticRegressionWithFtrl.getLabelCol()); + assertEquals("test_weight", logisticRegressionWithFtrl.getWeightCol()); + assertEquals(1000, logisticRegressionWithFtrl.getMaxIter()); + assertEquals(0.001, logisticRegressionWithFtrl.getTol(), TOLERANCE); + assertEquals(1000, logisticRegressionWithFtrl.getGlobalBatchSize()); + assertEquals(0.1, logisticRegressionWithFtrl.getReg(), TOLERANCE); + assertEquals(0.5, logisticRegressionWithFtrl.getElasticNet(), TOLERANCE); + assertEquals("binomial", logisticRegressionWithFtrl.getMultiClass()); + assertEquals("test_predictionCol", logisticRegressionWithFtrl.getPredictionCol()); + assertEquals("test_rawPredictionCol", logisticRegressionWithFtrl.getRawPredictionCol()); + + assertEquals(0.2, logisticRegressionWithFtrl.getAlpha(), TOLERANCE); + assertEquals(0.2, logisticRegressionWithFtrl.getBeta(), TOLERANCE); + assertEquals(10000000L, logisticRegressionWithFtrl.getModelDim()); + assertEquals(4, logisticRegressionWithFtrl.getNumServers()); + } + + @Test + public void testOutputSchema() { + Table tempTable = trainTable.as("test_features", "test_label", "test_weight"); + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl() + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + Table output = logisticRegressionWithFtrl.fit(trainTable).transform(tempTable)[0]; + assertEquals( + Arrays.asList( + "test_features", + "test_label", + "test_weight", + "test_predictionCol", + "test_rawPredictionCol"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + @SuppressWarnings("unchecked") + public void testGetModelData() throws Exception { + // Fix the parallelism as one for stability tests. 
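+ // The coefficient comes back as one segment per server, each covering a [startIndex, endIndex) range; a fixed parallelism keeps that layout deterministic so the segments can be reassembled by startIndex below.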
+ env.setParallelism(1); + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + List modelData = + IteratorUtils.toList( + LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + + assertEquals(NUM_SERVERS, modelData.size()); + + modelData.sort(Comparator.comparingLong(o -> o.startIndex)); + + double[] collectedCoefficient = new double[4]; + for (LogisticRegressionModelDataSegment modelSegment : modelData) { + int startIndex = (int) modelSegment.startIndex; + double[] segment = modelSegment.coefficient.values; + System.arraycopy(segment, 0, collectedCoefficient, startIndex, segment.length); + } + assertArrayEquals(expectedCoefficient, collectedCoefficient, 0.1); + } + + @Test + public void testFitAndPredict() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); + Table output = logisticRegressionWithFtrl.fit(trainTable).transform(testTable)[0]; + verifyPredictionResult( + output, + logisticRegressionWithFtrl.getFeaturesCol(), + logisticRegressionWithFtrl.getPredictionCol(), + logisticRegressionWithFtrl.getRawPredictionCol()); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); + logisticRegressionWithFtrl = + TestUtils.saveAndReload( + tEnv, + logisticRegressionWithFtrl, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionWithFtrl::load); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + model = + TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionModel::load); + assertEquals( + Arrays.asList("coefficient", "startIndex", "endIndex", "modelVersion"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = model.transform(testTable)[0]; + verifyPredictionResult( + output, + logisticRegressionWithFtrl.getFeaturesCol(), + logisticRegressionWithFtrl.getPredictionCol(), + logisticRegressionWithFtrl.getRawPredictionCol()); + } + + @Test + public void testSetModelData() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + + LogisticRegressionModel newModel = new LogisticRegressionModel(); + ParamUtils.updateExistingParams(newModel, model.getParamMap()); + newModel.setModelData(model.getModelData()); + Table output = newModel.transform(testTable)[0]; + verifyPredictionResult( + output, + logisticRegressionWithFtrl.getFeaturesCol(), + logisticRegressionWithFtrl.getPredictionCol(), + logisticRegressionWithFtrl.getRawPredictionCol()); + } + + @Test + public void testSaveLoadServableAndPredict() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + + LogisticRegressionModelServable servable = + saveAndLoadServable( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionModel::loadServable); + + DataFrame output = 
servable.transform(testDataFrame); + verifyPredictionResult( + output, + servable.getFeaturesCol(), + servable.getPredictionCol(), + servable.getRawPredictionCol()); + } + + @Test + public void testSetModelDataToServable() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + List serializedModelData = + IteratorUtils.toList( + LogisticRegressionModelDataUtil.getModelDataByteStream( + model.getModelData()[0]) + .executeAndCollect()); + + LogisticRegressionModelServable servable = new LogisticRegressionModelServable(); + ParamUtils.updateExistingParams(servable, model.getParamMap()); + + List modelStreams = + serializedModelData.stream() + .map(ByteArrayInputStream::new) + .collect(Collectors.toList()); + servable.setModelData(new SequenceInputStream(Collections.enumeration(modelStreams))); + DataFrame output = servable.transform(testDataFrame); + verifyPredictionResult( + output, + servable.getFeaturesCol(), + servable.getPredictionCol(), + servable.getRawPredictionCol()); + } + + private void verifyPredictionResult( + Table output, String featuresCol, String predictionCol, String rawPredictionCol) + throws Exception { + List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predResult) { + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.getField(featuresCol)).toDense(); + double prediction = (double) predictionRow.getField(predictionCol); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.getField(rawPredictionCol); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } + + private void verifyPredictionResult( + DataFrame output, String featuresCol, String predictionCol, String rawPredictionCol) { + int featuresColIndex = output.getIndex(featuresCol); + int predictionColIndex = output.getIndex(predictionCol); + int rawPredictionColIndex = output.getIndex(rawPredictionCol); + + for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) { + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.get(featuresColIndex)).toDense(); + double prediction = (double) predictionRow.get(predictionColIndex); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.get(rawPredictionColIndex); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java index cac6e4c91..f02a04482 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java @@ -21,8 +21,8 @@ import org.apache.flink.ml.classification.naivebayes.NaiveBayes; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; 
+import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; @@ -56,7 +56,7 @@ public class NaiveBayesTest extends AbstractTestBase { private StreamTableEnvironment tEnv; private Table trainTable; private Table predictTable; - private Map<DenseVector, Double> expectedOutput; + private Map<DenseIntDoubleVector, Double> expectedOutput; private NaiveBayes estimator; @Before @@ -82,7 +82,7 @@ public void before() { predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); expectedOutput = - new HashMap<DenseVector, Double>() { + new HashMap<DenseIntDoubleVector, Double>() { { put(Vectors.dense(0, 1.), 11.0); put(Vectors.dense(0, 0.), 11.0); @@ -109,13 +109,13 @@ public void before() { * @param predictionCol Name of the column in the table that contains the prediction result * @return A map containing the collected results */ - private static Map<DenseVector, Double> executeAndCollect( + private static Map<DenseIntDoubleVector, Double> executeAndCollect( Table table, String featuresCol, String predictionCol) { - Map<DenseVector, Double> map = new HashMap<>(); + Map<DenseIntDoubleVector, Double> map = new HashMap<>(); for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { Row row = it.next(); map.put( - ((Vector) row.getField(featuresCol)).toDense(), + ((IntDoubleVector) row.getField(featuresCol)).toDense(), (Double) row.getField(predictionCol)); } @@ -159,7 +159,7 @@ public void testParam() { public void testFitAndPredict() { NaiveBayesModel model = estimator.fit(trainTable); Table outputTable = model.transform(predictTable)[0]; - Map<DenseVector, Double> actualOutput = + Map<DenseIntDoubleVector, Double> actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -170,14 +170,15 @@ public void testInputTypeConversion() { predictTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class}, TestUtils.getColumnDataTypes(trainTable)); assertArrayEquals( - new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(predictTable)); + new Class[] {SparseIntDoubleVector.class}, + TestUtils.getColumnDataTypes(predictTable)); NaiveBayesModel model = estimator.fit(trainTable); Table outputTable = model.transform(predictTable)[0]; - Map<DenseVector, Double> actualOutput = + Map<DenseIntDoubleVector, Double> actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -194,7 +195,7 @@ public void testOutputSchema() { NaiveBayesModel model = estimator.fit(trainTable); Table outputTable = model.transform(predictTable)[0]; - Map<DenseVector, Double> actualOutput = + Map<DenseIntDoubleVector, Double> actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -282,7 +283,7 @@ public void testSaveLoad() throws Exception { Table outputTable = model.transform(predictTable)[0]; - Map<DenseVector, Double> actualOutput = + Map<DenseIntDoubleVector, Double> actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -319,7 +320,7 @@ public void testSetModelData() { Table outputTable = modelB.transform(predictTable)[0]; - Map<DenseVector, Double> actualOutput = + Map<DenseIntDoubleVector, Double> actualOutput = executeAndCollect(outputTable, modelB.getFeaturesCol(), modelB.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java index cac9473c3..9c8e37605 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java @@ -29,12 +29,12 @@ import org.apache.flink.metrics.Gauge; import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.InMemorySinkFunction; import org.apache.flink.ml.util.InMemorySourceFunction; @@ -153,7 +153,7 @@ public class OnlineLogisticRegressionTest extends TestLogger { private InMemorySourceFunction trainSparseSource; private InMemorySourceFunction predictSparseSource; private InMemorySinkFunction outputSink; - private InMemorySinkFunction modelDataSink; + private InMemorySinkFunction modelDataSink; private static InMemoryReporter reporter; private static MiniCluster miniCluster; @@ -212,7 +212,8 @@ public void before() throws Exception { trainDenseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(DenseVector.class), Types.DOUBLE + TypeInformation.of(DenseIntDoubleVector.class), + Types.DOUBLE }, new String[] {"features", "label"}))); @@ -222,7 +223,8 @@ public void before() throws Exception { predictDenseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(DenseVector.class), Types.DOUBLE + TypeInformation.of(DenseIntDoubleVector.class), + Types.DOUBLE }, new String[] {"features", "label"}))); @@ -232,7 +234,7 @@ public void before() throws Exception { trainSparseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(SparseVector.class), + TypeInformation.of(SparseIntDoubleVector.class), Types.DOUBLE, Types.DOUBLE }, @@ -244,7 +246,8 @@ public void before() throws Exception { predictSparseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(SparseVector.class), Types.DOUBLE + TypeInformation.of(SparseIntDoubleVector.class), + Types.DOUBLE }, new String[] {"features", "label"}))); @@ -252,20 +255,24 @@ public void before() throws Exception { tEnv.fromDataStream( env.fromElements( Row.of( - new DenseVector( + new DenseIntDoubleVector( new double[] { 0.41233679404769874, -0.18088118293232122 }), + 0L, + 2L, 0L))); initSparseModel = tEnv.fromDataStream( env.fromElements( Row.of( - new DenseVector( + new DenseIntDoubleVector( new double[] { 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01 }), + 0L, + 10L, 0L))); } @@ -330,7 +337,7 @@ private void waitModelDataUpdate(JobID jobID) throws InterruptedException { * * @param expectedRawInfo A list containing sets of expected result RawInfo. 
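* @param isSparse Whether to feed the sparse or the dense predict source.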
*/ - private void predictAndAssert(List expectedRawInfo, boolean isSparse) + private void predictAndAssert(List expectedRawInfo, boolean isSparse) throws Exception { if (isSparse) { predictSparseSource.addAll(PREDICT_SPARSE_ROWS); @@ -339,7 +346,7 @@ private void predictAndAssert(List expectedRawInfo, boolean isSpars } List rawResult = outputSink.poll(isSparse ? PREDICT_SPARSE_ROWS.length : PREDICT_DENSE_ROWS.length); - List resultDetail = new ArrayList<>(rawResult.size()); + List resultDetail = new ArrayList<>(rawResult.size()); for (Row row : rawResult) { resultDetail.add(row.getFieldAs(3)); } @@ -412,14 +419,18 @@ public void testParam() { @Test public void testDenseFitAndPredict() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}), - new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.04481034155642882, 0.9551896584435712}), + new DenseIntDoubleVector( + new double[] {0.5353966697318491, 0.4646033302681509})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}), - new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307})); + new DenseIntDoubleVector( + new double[] {0.013104324065967066, 0.9868956759340329}), + new DenseIntDoubleVector( + new double[] {0.5095144380001769, 0.49048556199982307})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -446,14 +457,18 @@ public void testDenseFitAndPredict() throws Exception { @Test public void testSparseFitAndPredict() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.4452309884735286, 0.5547690115264714}), - new DenseVector(new double[] {0.5105551725414953, 0.4894448274585047})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.4452309884735286, 0.5547690115264714}), + new DenseIntDoubleVector( + new double[] {0.5105551725414953, 0.4894448274585047})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.40310431554310666, 0.5968956844568933}), - new DenseVector(new double[] {0.5249618837373886, 0.4750381162626114})); + new DenseIntDoubleVector( + new double[] {0.40310431554310666, 0.5968956844568933}), + new DenseIntDoubleVector( + new double[] {0.5249618837373886, 0.4750381162626114})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -479,14 +494,18 @@ public void testSparseFitAndPredict() throws Exception { @Test public void testFitAndPredictWithWeightCol() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.452491993753382, 0.547508006246618}), - new DenseVector(new double[] {0.5069192929506545, 0.4930807070493455})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.452491993753382, 0.547508006246618}), + new DenseIntDoubleVector( + new double[] {0.5069192929506545, 0.4930807070493455})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.41108882806164193, 0.5889111719383581}), - new DenseVector(new double[] {0.5247727600974581, 0.4752272399025419})); + new DenseIntDoubleVector( + new double[] {0.41108882806164193, 0.5889111719383581}), + 
new DenseIntDoubleVector( + new double[] {0.5247727600974581, 0.4752272399025419})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -517,20 +536,24 @@ public void testGenerateRandomModelData() throws Exception { LogisticRegressionModelDataUtil.generateRandomModelData(tEnv, 2, 2022); DataStream modelData = tEnv.toDataStream(modelDataTable); Row modelRow = (Row) IteratorUtils.toList(modelData.executeAndCollect()).get(0); - Assert.assertEquals(2, ((DenseVector) modelRow.getField(0)).size()); + Assert.assertEquals(2, ((DenseIntDoubleVector) modelRow.getField(0)).size().intValue()); Assert.assertEquals(0L, modelRow.getField(1)); } @Test public void testInitWithLogisticRegression() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.037327343811250024, 0.96267265618875}), - new DenseVector(new double[] {0.5684728224189707, 0.4315271775810293})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.037327343811250024, 0.96267265618875}), + new DenseIntDoubleVector( + new double[] {0.5684728224189707, 0.4315271775810293})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.007758574555505882, 0.9922414254444941}), - new DenseVector(new double[] {0.5257216567388069, 0.4742783432611931})); + new DenseIntDoubleVector( + new double[] {0.007758574555505882, 0.9922414254444941}), + new DenseIntDoubleVector( + new double[] {0.5257216567388069, 0.4742783432611931})); LogisticRegression logisticRegression = new LogisticRegression() .setLabelCol("label") @@ -586,14 +609,18 @@ public void testBatchSizeLessThanParallelism() { @Test public void testSaveAndReload() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}), - new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.04481034155642882, 0.9551896584435712}), + new DenseIntDoubleVector( + new double[] {0.5353966697318491, 0.4646033302681509})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}), - new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307})); + new DenseIntDoubleVector( + new double[] {0.013104324065967066, 0.9868956759340329}), + new DenseIntDoubleVector( + new double[] {0.5095144380001769, 0.49048556199982307})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -644,11 +671,12 @@ public void testGetModelData() throws Exception { submitJob(env.getStreamGraph().getJobGraph()); trainDenseSource.addAll(TRAIN_DENSE_ROWS_1); - LogisticRegressionModelData actualModelData = modelDataSink.poll(); + LogisticRegressionModelDataSegment actualModelData = modelDataSink.poll(); - LogisticRegressionModelData expectedModelData = - new LogisticRegressionModelData( - new DenseVector(new double[] {0.2994527071464283, -0.1412541067743284}), + LogisticRegressionModelDataSegment expectedModelData = + new LogisticRegressionModelDataSegment( + new DenseIntDoubleVector( + new double[] {0.2994527071464283, -0.1412541067743284}), 1L); Assert.assertArrayEquals( expectedModelData.coefficient.values, actualModelData.coefficient.values, 1e-5); @@ -657,28 +685,34 @@ public void testGetModelData() throws Exception { 
@Test public void testSetModelData() throws Exception { - LogisticRegressionModelData modelData1 = - new LogisticRegressionModelData(new DenseVector(new double[] {0.085, -0.22}), 1L); + LogisticRegressionModelDataSegment modelData1 = + new LogisticRegressionModelDataSegment( + new DenseIntDoubleVector(new double[] {0.085, -0.22}), 1L); - LogisticRegressionModelData modelData2 = - new LogisticRegressionModelData(new DenseVector(new double[] {0.075, -0.28}), 2L); + LogisticRegressionModelDataSegment modelData2 = + new LogisticRegressionModelDataSegment( + new DenseIntDoubleVector(new double[] {0.075, -0.28}), 2L); - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.6285496932692606, 0.3714503067307394}), - new DenseVector(new double[] {0.7588710471221473, 0.24112895287785274})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.6285496932692606, 0.3714503067307394}), + new DenseIntDoubleVector( + new double[] {0.7588710471221473, 0.24112895287785274})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.6673003248270917, 0.3326996751729083}), - new DenseVector(new double[] {0.8779865510655934, 0.12201344893440658})); + new DenseIntDoubleVector( + new double[] {0.6673003248270917, 0.3326996751729083}), + new DenseIntDoubleVector( + new double[] {0.8779865510655934, 0.12201344893440658})); - InMemorySourceFunction modelDataSource = + InMemorySourceFunction modelDataSource = new InMemorySourceFunction<>(); Table modelDataTable = tEnv.fromDataStream( env.addSource( modelDataSource, - TypeInformation.of(LogisticRegressionModelData.class))); + TypeInformation.of(LogisticRegressionModelDataSegment.class))); OnlineLogisticRegressionModel onlineModel = new OnlineLogisticRegressionModel() diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java index 292e793ed..0f5143250 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java @@ -31,9 +31,9 @@ import org.apache.flink.ml.common.window.EventTimeTumblingWindows; import org.apache.flink.ml.common.window.GlobalWindows; import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -71,7 +71,7 @@ public class AgglomerativeClusteringTest extends AbstractTestBase { private StreamExecutionEnvironment env; private Table inputDataTable; - private static final List INPUT_DATA = + private static final List INPUT_DATA = Arrays.asList( Vectors.dense(1, 1), Vectors.dense(1, 4), @@ -97,7 +97,7 @@ public class AgglomerativeClusteringTest extends AbstractTestBase { private static final double[] EUCLIDEAN_COMPLETE_MERGE_DISTANCES = new double[] {1, 1.5, 3, 3.3541019, 5}; - private static final List> EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT = + private static final List> 
EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT = Arrays.asList( new HashSet<>( Arrays.asList( @@ -107,38 +107,42 @@ public class AgglomerativeClusteringTest extends AbstractTestBase { Vectors.dense(4, 0))), new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4)))); - private static final List> EUCLIDEAN_WARD_THRESHOLD_AS_TWO_RESULT = + private static final List> EUCLIDEAN_WARD_THRESHOLD_AS_TWO_RESULT = Arrays.asList( new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), new HashSet<>(Collections.singletonList(Vectors.dense(1, 4))), new HashSet<>(Collections.singletonList(Vectors.dense(4, 4))), new HashSet<>(Arrays.asList(Vectors.dense(4, 1.5), Vectors.dense(4, 0)))); - private static final List> EUCLIDEAN_WARD_COUNT_FIVE_WINDOW_AS_TWO_RESULT = - Arrays.asList( - new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), - new HashSet<>( - Arrays.asList( - Vectors.dense(1, 4), - Vectors.dense(4, 4), - Vectors.dense(4, 1.5)))); - - private static final List> EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT = - Arrays.asList( - new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), - new HashSet<>(Collections.singletonList(Vectors.dense(1, 4))), - new HashSet<>(Arrays.asList(Vectors.dense(4, 0), Vectors.dense(4, 1.5))), - new HashSet<>(Collections.singletonList(Vectors.dense(4, 4)))); - - private static final List> EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT = - Arrays.asList( - new HashSet<>( - Arrays.asList( - Vectors.dense(1, 1), - Vectors.dense(1, 0), - Vectors.dense(4, 1.5), - Vectors.dense(4, 0))), - new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4)))); + private static final List> + EUCLIDEAN_WARD_COUNT_FIVE_WINDOW_AS_TWO_RESULT = + Arrays.asList( + new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), + new HashSet<>( + Arrays.asList( + Vectors.dense(1, 4), + Vectors.dense(4, 4), + Vectors.dense(4, 1.5)))); + + private static final List> + EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT = + Arrays.asList( + new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), + new HashSet<>(Collections.singletonList(Vectors.dense(1, 4))), + new HashSet<>( + Arrays.asList(Vectors.dense(4, 0), Vectors.dense(4, 1.5))), + new HashSet<>(Collections.singletonList(Vectors.dense(4, 4)))); + + private static final List> + EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(1, 1), + Vectors.dense(1, 0), + Vectors.dense(4, 1.5), + Vectors.dense(4, 0))), + new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4)))); private static final double TOLERANCE = 1e-7; @@ -303,18 +307,22 @@ public void testTransformWithCountTumblingWindows() throws Exception { public void testTransformWithEventTimeTumblingWindows() throws Exception { RowTypeInfo outputTypeInfo = new RowTypeInfo( - new TypeInformation[] {DenseVectorTypeInfo.INSTANCE, Types.INSTANT}, + new TypeInformation[] { + DenseIntDoubleVectorTypeInfo.INSTANCE, Types.INSTANT + }, new String[] {"features", "ts"}); Instant baseTime = Instant.now(); DataStream inputDataStream = env.fromCollection(INPUT_DATA) .setParallelism(1) - .map(x -> Row.of(x, baseTime.plusSeconds((long) x.get(0))), outputTypeInfo); + .map( + x -> Row.of(x, baseTime.plusSeconds((long) x.get(0).doubleValue())), + outputTypeInfo); Schema schema = Schema.newBuilder() - .column("features", DataTypes.of(DenseVectorTypeInfo.INSTANCE)) + .column("features", DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)) .column("ts", 
DataTypes.TIMESTAMP_LTZ(3)) .watermark("ts", "ts - INTERVAL '5' SECOND") .build(); @@ -331,16 +339,17 @@ public void testTransformWithEventTimeTumblingWindows() throws Exception { Table[] outputs = agglomerativeClustering.transform(inputDataTable); List output = IteratorUtils.toList(tEnv.toDataStream(outputs[0]).executeAndCollect()); - List> actualGroups = + List> actualGroups = KMeansTest.groupFeaturesByPrediction( output, agglomerativeClustering.getFeaturesCol(), agglomerativeClustering.getPredictionCol()); boolean isAllSubSet = true; - for (Set expectedSet : EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT) { + for (Set expectedSet : + EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT) { boolean isSubset = false; - for (Set actualSet : actualGroups) { + for (Set actualSet : actualGroups) { if (actualSet.containsAll(expectedSet)) { isSubset = true; break; @@ -440,13 +449,13 @@ private void verifyMergeInfo(double[] expectedDistances, Table mergeInfoTable) @SuppressWarnings("unchecked") public void verifyClusteringResult( - List> expected, + List> expected, Table outputTable, String featureCol, String predictionCol) throws Exception { List output = IteratorUtils.toList(tEnv.toDataStream(outputTable).executeAndCollect()); - List> actualGroups = + List> actualGroups = KMeansTest.groupFeaturesByPrediction(output, featureCol, predictionCol); assertTrue(CollectionUtils.isEqualCollection(expected, actualGroups)); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java index 8f1961726..e6fbd704b 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java @@ -22,9 +22,9 @@ import org.apache.flink.ml.clustering.kmeans.KMeansModel; import org.apache.flink.ml.clustering.kmeans.KMeansModelData; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; @@ -60,7 +60,7 @@ public class KMeansTest extends AbstractTestBase { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); - private static final List DATA = + private static final List DATA = Arrays.asList( Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 0.3), @@ -70,7 +70,7 @@ public class KMeansTest extends AbstractTestBase { Vectors.dense(9.6, 0.0)); private StreamExecutionEnvironment env; private StreamTableEnvironment tEnv; - private static final List> expectedGroups = + private static final List> expectedGroups = Arrays.asList( new HashSet<>( Arrays.asList( @@ -100,11 +100,11 @@ public void before() { * @param predictionCol Name of the column in the table that contains the prediction result * @return A map containing the collected results */ - protected static List> groupFeaturesByPrediction( + protected static List> groupFeaturesByPrediction( List rows, String featuresCol, String predictionCol) { - Map> map = new HashMap<>(); + Map> map = new HashMap<>(); for (Row row : rows) { - DenseVector vector = ((Vector) row.getField(featuresCol)).toDense(); + DenseIntDoubleVector vector = 
((IntDoubleVector) row.getField(featuresCol)).toDense(); int predict = (Integer) row.getField(predictionCol); map.putIfAbsent(predict, new HashSet<>()); map.get(predict).add(vector); @@ -152,7 +152,7 @@ public void testOutputSchema() { @Test public void testFewerDistinctPointsThanCluster() { - List data = + List data = Arrays.asList( Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1)); @@ -162,10 +162,10 @@ public void testFewerDistinctPointsThanCluster() { KMeansModel model = kmeans.fit(input); Table output = model.transform(input)[0]; - List> expectedGroups = + List> expectedGroups = Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1))); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -181,7 +181,7 @@ public void testFitAndPredict() { Arrays.asList("features", "prediction"), output.getResolvedSchema().getColumnNames()); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -191,7 +191,8 @@ public void testFitAndPredict() { public void testInputTypeConversion() { dataTable = TestUtils.convertDataTypesToSparseInt(tEnv, dataTable); assertArrayEquals( - new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(dataTable)); + new Class[] {SparseIntDoubleVector.class}, + TestUtils.getColumnDataTypes(dataTable)); KMeans kmeans = new KMeans().setMaxIter(2).setK(2); KMeansModel model = kmeans.fit(dataTable); @@ -201,7 +202,7 @@ public void testInputTypeConversion() { Arrays.asList("features", "prediction"), output.getResolvedSchema().getColumnNames()); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -226,7 +227,7 @@ public void testSaveLoadAndPredict() throws Exception { output.getResolvedSchema().getColumnNames()); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -245,7 +246,7 @@ public void testGetModelData() throws Exception { List collectedModelData = IteratorUtils.toList(modelData.executeAndCollect()); assertEquals(1, collectedModelData.size()); - DenseVector[] centroids = collectedModelData.get(0).centroids; + DenseIntDoubleVector[] centroids = collectedModelData.get(0).centroids; assertEquals(2, centroids.length); Arrays.sort(centroids, Comparator.comparingDouble(vector -> vector.get(0))); assertArrayEquals(centroids[0].values, new double[] {0.1, 0.1}, 1e-5); @@ -261,7 +262,7 @@ public void testSetModelData() { Table output = modelB.transform(dataTable)[0]; List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java index d85ab3510..0591403b2 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java @@ -31,10 +31,10 @@ import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.InMemorySinkFunction; import org.apache.flink.ml.util.InMemorySourceFunction; import org.apache.flink.ml.util.TestUtils; @@ -77,8 +77,8 @@ public class OnlineKMeansTest extends TestLogger { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); - private static final DenseVector[] trainData1 = - new DenseVector[] { + private static final DenseIntDoubleVector[] trainData1 = + new DenseIntDoubleVector[] { Vectors.dense(10.0, 0.0), Vectors.dense(10.0, 0.3), Vectors.dense(10.3, 0.0), @@ -86,8 +86,8 @@ public class OnlineKMeansTest extends TestLogger { Vectors.dense(-10.0, 0.6), Vectors.dense(-10.6, 0.0) }; - private static final DenseVector[] trainData2 = - new DenseVector[] { + private static final DenseIntDoubleVector[] trainData2 = + new DenseIntDoubleVector[] { Vectors.dense(10.0, 100.0), Vectors.dense(10.0, 100.3), Vectors.dense(10.3, 100.0), @@ -95,8 +95,8 @@ public class OnlineKMeansTest extends TestLogger { Vectors.dense(-10.0, -100.6), Vectors.dense(-10.6, -100.0) }; - private static final DenseVector[] predictData = - new DenseVector[] { + private static final DenseIntDoubleVector[] predictData = + new DenseIntDoubleVector[] { Vectors.dense(10.0, 10.0), Vectors.dense(10.3, 10.0), Vectors.dense(10.0, 10.3), @@ -104,7 +104,7 @@ public class OnlineKMeansTest extends TestLogger { Vectors.dense(-10.3, 10.0), Vectors.dense(-10.0, 10.3) }; - private static final List> expectedGroups1 = + private static final List> expectedGroups1 = Arrays.asList( new HashSet<>( Arrays.asList( @@ -116,7 +116,7 @@ public class OnlineKMeansTest extends TestLogger { Vectors.dense(-10.0, 10.0), Vectors.dense(-10.3, 10.0), Vectors.dense(-10.0, 10.3)))); - private static final List> expectedGroups2 = + private static final List> expectedGroups2 = Collections.singletonList( new HashSet<>( Arrays.asList( @@ -133,8 +133,8 @@ public class OnlineKMeansTest extends TestLogger { private int currentModelDataVersion; - private InMemorySourceFunction trainSource; - private InMemorySourceFunction predictSource; + private InMemorySourceFunction trainSource; + private InMemorySourceFunction predictSource; private InMemorySinkFunction outputSink; private InMemorySinkFunction modelDataSink; @@ -179,10 +179,12 @@ public void before() throws Exception { offlineTrainTable = tEnv.fromDataStream(env.fromElements(trainData1)).as("features"); onlineTrainTable = - tEnv.fromDataStream(env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE)) + tEnv.fromDataStream( + env.addSource(trainSource, DenseIntDoubleVectorTypeInfo.INSTANCE)) 
.as("features"); onlinePredictTable = - tEnv.fromDataStream(env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE)) + tEnv.fromDataStream( + env.addSource(predictSource, DenseIntDoubleVectorTypeInfo.INSTANCE)) .as("features"); } @@ -247,11 +249,13 @@ private void waitModelDataUpdate(JobID jobID) throws InterruptedException { * @param predictionCol Name of the column in the table that contains the prediction result */ private void predictAndAssert( - List> expectedGroups, String featuresCol, String predictionCol) + List> expectedGroups, + String featuresCol, + String predictionCol) throws Exception { predictSource.addAll(OnlineKMeansTest.predictData); List rawResult = outputSink.poll(OnlineKMeansTest.predictData.length); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction(rawResult, featuresCol, predictionCol); Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); } @@ -324,13 +328,13 @@ public void testInputTypeConversion() throws Exception { onlinePredictTable = TestUtils.convertDataTypesToSparseInt(tEnv, onlinePredictTable); assertArrayEquals( - new Class[] {SparseVector.class}, + new Class[] {SparseIntDoubleVector.class}, TestUtils.getColumnDataTypes(offlineTrainTable)); assertArrayEquals( - new Class[] {SparseVector.class}, + new Class[] {SparseIntDoubleVector.class}, TestUtils.getColumnDataTypes(onlineTrainTable)); assertArrayEquals( - new Class[] {SparseVector.class}, + new Class[] {SparseIntDoubleVector.class}, TestUtils.getColumnDataTypes(onlinePredictTable)); OnlineKMeans onlineKMeans = @@ -406,7 +410,7 @@ public void testDecayFactor() throws Exception { KMeansModelData expectedModelData = new KMeansModelData( - new DenseVector[] { + new DenseIntDoubleVector[] { Vectors.dense(-10.2, -200.2 / 3), Vectors.dense(10.1, 200.3 / 3) }, Vectors.dense(4.5, 4.5)); @@ -502,7 +506,9 @@ public void testGetModelData() throws Exception { KMeansModelData expectedModelData = new KMeansModelData( - new DenseVector[] {Vectors.dense(-10.2, 0.2), Vectors.dense(10.1, 0.1)}, + new DenseIntDoubleVector[] { + Vectors.dense(-10.2, 0.2), Vectors.dense(10.1, 0.1) + }, Vectors.dense(3, 3)); assertArrayEquals(expectedModelData.weights.values, actualModelData.weights.values, 1e-5); @@ -520,12 +526,14 @@ public void testGetModelData() throws Exception { public void testSetModelData() throws Exception { KMeansModelData modelData1 = new KMeansModelData( - new DenseVector[] {Vectors.dense(10.1, 0.1), Vectors.dense(-10.2, 0.2)}, + new DenseIntDoubleVector[] { + Vectors.dense(10.1, 0.1), Vectors.dense(-10.2, 0.2) + }, Vectors.dense(0.0, 0.0)); KMeansModelData modelData2 = new KMeansModelData( - new DenseVector[] { + new DenseIntDoubleVector[] { Vectors.dense(10.1, 100.1), Vectors.dense(-10.2, -100.2) }, Vectors.dense(0.0, 0.0)); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java index 22ce2deba..19d403548 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.common.lossfunc; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.junit.Test; @@ -31,8 +31,8 @@ public 
class BinaryLogisticLossTest { private static final LabeledPointWithWeight dataPoint = new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0); - private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); - private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final DenseIntDoubleVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseIntDoubleVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); private static final double TOLERANCE = 1e-7; @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java index 1cd165ecf..a759962c7 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.common.lossfunc; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.junit.Test; @@ -33,8 +33,8 @@ public class HingeLossTest { new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, -1.0), 1.0, 2.0); private static final LabeledPointWithWeight dataPoint2 = new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, 1.0), 1.0, 2.0); - private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); - private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final DenseIntDoubleVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseIntDoubleVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); private static final double TOLERANCE = 1e-7; @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java index ee2d03021..e420662db 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.common.lossfunc; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.junit.Test; @@ -31,8 +31,8 @@ public class LeastSquareLossTest { private static final LabeledPointWithWeight dataPoint = new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0); - private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); - private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final DenseIntDoubleVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseIntDoubleVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); private static final double TOLERANCE = 1e-7; @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java index 83d688337..2ce653110 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java +++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.ml.common.optimizer;
 
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 
 import org.apache.commons.lang3.RandomUtils;
 import org.junit.Test;
@@ -29,7 +29,8 @@ public class RegularizationUtilsTest {
     private static final double learningRate = 0.1;
     private static final double TOLERANCE = 1e-7;
 
-    private static final DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0});
+    private static final DenseIntDoubleVector coefficient =
+            new DenseIntDoubleVector(new double[] {1.0, -2.0, 0});
 
     @Test
     public void testRegularization() {
@@ -40,7 +41,7 @@ public void testRegularization() {
     }
 
     private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) {
-        DenseVector clonedCoefficient = coefficient.clone();
+        DenseIntDoubleVector clonedCoefficient = coefficient.clone();
         RegularizationUtils.regularize(clonedCoefficient, reg, elasticNet, learningRate);
         assertArrayEquals(expectedCoefficient, clonedCoefficient.values, TOLERANCE);
     }
 }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java
new file mode 100644
index 000000000..a3dd5c105
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.util.Bits;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+
+/** Benchmark for PS-related serialization primitives.
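+ *
+ * <p>Compares writing a {@code double[]} with {@link Bits#putDoubleArray} against serializing it
+ * element by element through a {@link DataOutputViewStreamWrapper}. The inline timings below were
+ * observed on the original author's machine and are indicative only.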
+ */
+public class PSBench {
+    @Test
+    public void benchBits() {
+        double[] result = new double[100000];
+        int warmUp = 500;
+        int numTries = 1000;
+        for (int i = 0; i < result.length; i++) {
+            result[i] = i;
+        }
+        byte[] bytes = new byte[Bits.getDoubleArraySizeInBytes(result)];
+
+        for (int i = 0; i < warmUp; i++) {
+            Bits.putDoubleArray(result, bytes, 0);
+        }
+
+        long start = System.currentTimeMillis();
+        for (int i = 0; i < numTries; i++) {
+            Bits.putDoubleArray(result, bytes, 0);
+        }
+        long end = System.currentTimeMillis();
+        System.out.println(end - start); // ~600ms
+    }
+
+    @Test
+    public void benchTypeSerializer() throws IOException {
+        double[] result = new double[100000];
+        int warmUp = 500;
+        int numTries = 1000;
+        for (int i = 0; i < result.length; i++) {
+            result[i] = i;
+        }
+        byte[] bytes = new byte[Bits.getDoubleArraySizeInBytes(result)];
+
+        for (int i = 0; i < warmUp; i++) {
+            bytes = serializeDoubleArray(result);
+        }
+
+        long start = System.currentTimeMillis();
+        for (int i = 0; i < numTries; i++) {
+            bytes = serializeDoubleArray(result);
+        }
+        long end = System.currentTimeMillis();
+        System.out.println(end - start); // 2000ms
+        Assert.assertEquals(Bits.getDoubleArraySizeInBytes(result), bytes.length);
+    }
+
+    private byte[] serializeDoubleArray(double[] result) throws IOException {
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                new DataOutputViewStreamWrapper(byteArrayOutputStream);
+        dataOutputViewStreamWrapper.writeInt(result.length);
+
+        for (double value : result) {
+            dataOutputViewStreamWrapper.writeDouble(value);
+        }
+        return byteArrayOutputStream.toByteArray();
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java
new file mode 100644
index 000000000..a2ca98236
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.ml.common.ps.updater.ModelUpdater;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.test.util.TestBaseUtils;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/** Tests {@link TrainingUtils}.
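+ *
+ * <p>Runs a small parameter-server job with {@code NUM_WORKERS} workers and {@code NUM_SERVERS}
+ * servers: each iteration computes pull indices, pulls model values, all-reduces a {@link
+ * MockPojo} array, and pushes updates back, terminating once {@code MAX_ITER} iterations have
+ * run.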
+ */
+public class TrainingUtilsTest {
+    private static final int NUM_WORKERS = 2;
+    private static final int NUM_SERVERS = 2;
+    private static final int MAX_ITER = 3;
+    private static final int NUM_COLUMNS_PER_KEY = 2;
+
+    private DataStream<Long> maxKey;
+    private DataStream<DenseIntDoubleVector> inputData;
+
+    @Before
+    public void before() {
+        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
+        env.setParallelism(NUM_WORKERS);
+        maxKey = env.fromElements(3L);
+        inputData =
+                env.fromCollection(
+                                Arrays.asList(
+                                        Vectors.dense(1, 1, 1, 1),
+                                        Vectors.dense(2, 2, 2, 2),
+                                        Vectors.dense(3, 3, 3, 3),
+                                        Vectors.dense(4, 4, 4, 4)))
+                        .map(x -> x, DenseIntDoubleVectorTypeInfo.INSTANCE);
+    }
+
+    @Test
+    public void test() throws Exception {
+        ExecutionConfig config = maxKey.getExecutionEnvironment().getConfig();
+
+        TypeSerializer<MockPojo> pojoDemoTypeSerializer =
+                Types.POJO(MockPojo.class).createSerializer(config);
+
+        MockSession mockSession =
+                new MockSession(
+                        DenseIntDoubleVectorTypeInfo.INSTANCE,
+                        Collections.singletonList(
+                                new OutputTag<>("AllReduceOutputTag", Types.POJO(MockPojo.class))));
+
+        IterationStageList<MockSession> stageList =
+                new IterationStageList<>(mockSession)
+                        .addStage(new MockComputePullIndicesStage())
+                        .addStage(
+                                new PullStage(
+                                        (SerializableSupplier<long[]>)
+                                                () -> mockSession.pullIndices,
+                                        (SerializableConsumer<double[]>)
+                                                x -> mockSession.pulledValues = x))
+                        .addStage(
+                                new AllReduceStage<>(
+                                        (SerializableSupplier<MockPojo[]>)
+                                                () -> mockSession.toAllReduce,
+                                        (SerializableConsumer<MockPojo[]>)
+                                                x -> mockSession.toAllReduce = x,
+                                        (ReduceFunction<MockPojo[]>) TrainingUtilsTest::sumPojo,
+                                        pojoDemoTypeSerializer))
+                        .addStage(new MockComputePushValuesStage(NUM_COLUMNS_PER_KEY))
+                        .addStage(
+                                new PushStage(
+                                        (SerializableSupplier<long[]>)
+                                                () -> mockSession.pushIndices,
+                                        (SerializableSupplier<double[]>)
+                                                () -> mockSession.pushValues))
+                        .setTerminationCriteria(context -> context.iterationId >= MAX_ITER);
+
+        DataStreamList resultList =
+                TrainingUtils.train(
+                        inputData,
+                        stageList,
+                        maxKey,
+                        new TupleTypeInfo<>(
+                                Types.LONG,
+                                Types.LONG,
+                                PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+                        new MockModelUpdater(NUM_COLUMNS_PER_KEY),
+                        NUM_SERVERS);
+
+        // Verifies the model data.
+        DataStream<Tuple3<Long, Long, double[]>> modelSegments = resultList.get(0);
+        List<Tuple3<Long, Long, double[]>> collectedModelPieces =
+                IteratorUtils.toList(modelSegments.executeAndCollect());
+        Assert.assertEquals(NUM_SERVERS, collectedModelPieces.size());
+        collectedModelPieces.sort(Comparator.comparing(o -> o.f0));
+
+        double[] result = new double[4 * NUM_COLUMNS_PER_KEY];
+        double[] expectedResult = new double[4 * NUM_COLUMNS_PER_KEY];
+        Arrays.fill(expectedResult, 0, NUM_COLUMNS_PER_KEY, 35.0);
+        Arrays.fill(expectedResult, 3 * NUM_COLUMNS_PER_KEY, 4 * NUM_COLUMNS_PER_KEY, 35.0);
+        for (Tuple3<Long, Long, double[]> modelPiece : collectedModelPieces) {
+            int startIndex = (int) (long) modelPiece.f0 * NUM_COLUMNS_PER_KEY;
+            double[] pieceCoeff = modelPiece.f2;
+            System.arraycopy(pieceCoeff, 0, result, startIndex, pieceCoeff.length);
+        }
+        Assert.assertArrayEquals(expectedResult, result, 1e-7);
+
+        // Verifies the all reduce result from worker output.
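+        // Worker 0 emits toAllReduce[0] once per iteration before the reduce runs, and the
+        // all-reduce then sums the copies from both workers, so with MAX_ITER = 3 the side
+        // output is expected to hold i = 1, 2 and 4.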
+        DataStream<MockPojo> allReduceResult = resultList.get(1);
+        allReduceResult.getTransformation().setParallelism(1);
+        List<MockPojo> collectedPojo = IteratorUtils.toList(allReduceResult.executeAndCollect());
+        List<MockPojo> expectedPojo =
+                Arrays.asList(new MockPojo(1, 0), new MockPojo(2, 0), new MockPojo(4, 0));
+        TestBaseUtils.compareResultCollections(
+                expectedPojo,
+                collectedPojo,
+                new Comparator<MockPojo>() {
+                    @Override
+                    public int compare(MockPojo o1, MockPojo o2) {
+                        return Integer.compare(o1.i, o2.i);
+                    }
+                });
+    }
+
+    private static MockPojo[] sumPojo(MockPojo[] d1, MockPojo[] d2) {
+        Preconditions.checkArgument(d1.length == d2.length);
+        for (int i = 0; i < d1.length; i++) {
+            d2[i].i += d1[i].i;
+            d2[i].j += d1[i].j;
+        }
+        return d2;
+    }
+
+    private static class MockSession extends MiniBatchMLSession<DenseIntDoubleVector> {
+
+        public MockPojo[] toAllReduce;
+        private ProxySideOutput output;
+
+        @Override
+        public void setOutput(ProxySideOutput collector) {
+            this.output = collector;
+        }
+
+        public MockSession(
+                TypeInformation<DenseIntDoubleVector> typeInformation,
+                List<OutputTag<?>> outputTags) {
+            super(0, typeInformation, outputTags);
+        }
+    }
+
+    /** Pulls the 0th and 3rd dimensions of the model from the servers. */
+    private static class MockComputePullIndicesStage extends ProcessStage<MockSession> {
+
+        @Override
+        public void process(MockSession context) {
+            context.pullIndices = new long[] {0, 3};
+            if (context.toAllReduce == null) {
+                context.toAllReduce = new MockPojo[1];
+                context.toAllReduce[0] = new MockPojo(1, 0);
+            }
+            if (context.workerId == 0) {
+                context.output.output(
+                        (OutputTag<MockPojo>) context.getOutputTags().get(0),
+                        new StreamRecord<>(context.toAllReduce[0]));
+            }
+        }
+    }
+
+    /**
+     * Adds the 0th and 3rd dimensions of all training data to the model and pushes them to the
+     * servers.
+     */
+    private static class MockComputePushValuesStage extends ProcessStage<MockSession> {
+        private final int numCols;
+
+        public MockComputePushValuesStage(int numCols) {
+            this.numCols = numCols;
+        }
+
+        @Override
+        public void process(MockSession context) throws Exception {
+            long[] indices = context.pullIndices;
+            double[] values = context.pulledValues;
+            ResettableIterator<DenseIntDoubleVector> data = context.inputData;
+            while (data.hasNext()) {
+                double[] d = data.next().values;
+                for (int i = 0; i < indices.length; i++) {
+                    double v = d[(int) indices[i]];
+                    for (int j = 0; j < numCols; j++) {
+                        values[i * numCols + j] += v;
+                    }
+                }
+            }
+            data.reset();
+
+            BLAS.scal(1.0 / context.numWorkers, new DenseIntDoubleVector(values));
+
+            context.pushIndices = indices;
+            context.pushValues = values;
+        }
+    }
+
+    /** The logic on servers.
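+     *
+     * <p>Each server instance owns the key range [startKeyIndex, endKeyIndex) and stores {@code
+     * numDoublesPerKey} doubles per key: {@code update} accumulates pushed values into the model
+     * and {@code get} reads the current values back out, while the two list states make the
+     * shard restorable from a checkpoint.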
+     */
+    private static class MockModelUpdater implements ModelUpdater<Tuple3<Long, Long, double[]>> {
+        private final int numDoublesPerKey;
+        private long startIndex;
+        private long endIndex;
+        private double[] model;
+
+        private ListState<Long> boundaryState;
+        private ListState<double[]> modelDataState;
+
+        public MockModelUpdater(int numDoublesPerKey) {
+            this.numDoublesPerKey = numDoublesPerKey;
+        }
+
+        @Override
+        public void open(long startKeyIndex, long endKeyIndex) {
+            this.startIndex = startKeyIndex;
+            this.endIndex = endKeyIndex;
+            this.model = new double[(int) (endKeyIndex - startKeyIndex) * numDoublesPerKey];
+        }
+
+        @Override
+        public void update(long[] keys, double[] values) {
+            Preconditions.checkState(keys.length * numDoublesPerKey == values.length);
+            for (int i = 0; i < keys.length; i++) {
+                int index = (int) (keys[i] - startIndex);
+                for (int j = 0; j < numDoublesPerKey; j++) {
+                    model[index * numDoublesPerKey + j] += values[i * numDoublesPerKey + j];
+                }
+            }
+        }
+
+        @Override
+        public double[] get(long[] keys) {
+            double[] values = new double[keys.length * numDoublesPerKey];
+            for (int i = 0; i < keys.length; i++) {
+                int index = (int) (keys[i] - startIndex);
+                for (int j = 0; j < numDoublesPerKey; j++) {
+                    values[i * numDoublesPerKey + j] += model[index * numDoublesPerKey + j];
+                }
+            }
+            return values;
+        }
+
+        @Override
+        public Iterator<Tuple3<Long, Long, double[]>> getModelSegments() {
+            return Collections.singleton(Tuple3.of(startIndex, endIndex, model)).iterator();
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            boundaryState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("BoundaryState", Types.LONG));
+
+            Iterator<Long> iterator = boundaryState.get().iterator();
+            if (iterator.hasNext()) {
+                startIndex = iterator.next();
+                endIndex = iterator.next();
+            }
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "modelDataState",
+                                            PrimitiveArrayTypeInfo
+                                                    .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
+            Iterator<double[]> modelData = modelDataState.get().iterator();
+            if (modelData.hasNext()) {
+                model = modelData.next();
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            if (model != null) {
+                boundaryState.clear();
+                boundaryState.add(startIndex);
+                boundaryState.add(endIndex);
+
+                modelDataState.clear();
+                modelDataState.add(model);
+            }
+        }
+    }
+
+    /** Mock pojo class to test all reduce.
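+     *
+     * <p>The public fields and no-arg constructor keep this class a valid POJO for {@code
+     * Types.POJO(MockPojo.class)}.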
+     */
+    public static class MockPojo {
+        public int i;
+        public int j;
+
+        public MockPojo(int i, int j) {
+            this.i = i;
+            this.j = j;
+        }
+
+        public MockPojo() {}
+
+        @Override
+        public String toString() {
+            return i + "-" + j;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (obj instanceof MockPojo) {
+                MockPojo other = (MockPojo) obj;
+                return i == other.i && j == other.j;
+            }
+            return false;
+        }
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
index 027be970c..521f67b21 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
@@ -18,8 +18,8 @@
 
 package org.apache.flink.ml.common.util;
 
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 
 import org.junit.Test;
@@ -33,13 +33,13 @@ public class VectorUtilsTest {
 
     @Test
     public void testSelectByIndices() {
-        DenseVector denseVector = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0);
+        DenseIntDoubleVector denseVector = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0);
         assertArrayEquals(
                 Vectors.dense(2.0, 4.0).toArray(),
                 VectorUtils.selectByIndices(denseVector, new int[] {1, 3}).toArray(),
                 EPS);
 
-        SparseVector sparseVector =
+        SparseIntDoubleVector sparseVector =
                 Vectors.sparse(5, new int[] {1, 2, 3}, new double[] {2.0, 3.0, 4.0});
         assertArrayEquals(
                 Vectors.sparse(3, new int[] {1, 2}, new double[] {2.0, 4.0}).toArray(),
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
index 0c146a3d2..7a2f90a10 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
 import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluatorParams;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -191,7 +191,7 @@ public void testEvaluate() {
     public void testInputTypeConversion() {
         inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
         assertArrayEquals(
-                new Class[] {Integer.class, SparseVector.class},
+                new Class[] {Integer.class, SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(inputDataTable));
 
         BinaryClassificationEvaluator eval =
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
index eb06b0c6e..17a502972 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
@@ -19,7 +19,7 @@
 
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.binarizer.Binarizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -65,11 +65,11 @@ public class BinarizerTest extends AbstractTestBase { private static final Double[] EXPECTED_VALUE_OUTPUT = new Double[] {0.0, 1.0, 1.0}; - private static final List EXPECTED_DENSE_OUTPUT = + private static final List EXPECTED_DENSE_OUTPUT = Arrays.asList( Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0), Vectors.dense(1.0, 1.0)); - private static final List EXPECTED_SPARSE_OUTPUT = + private static final List EXPECTED_SPARSE_OUTPUT = Arrays.asList( Vectors.sparse(17, new int[] {9}, new double[] {1.0}), Vectors.sparse(17, new int[] {0, 2}, new double[] {1.0, 1.0}), @@ -90,8 +90,8 @@ private void verifyOutputResult(Table output, String[] outputCols) throws Except List results = IteratorUtils.toList(stream.executeAndCollect()); List doubleValues = new ArrayList<>(results.size()); - List sparseVectorValues = new ArrayList<>(results.size()); - List denseVectorValues = new ArrayList<>(results.size()); + List sparseVectorValues = new ArrayList<>(results.size()); + List denseVectorValues = new ArrayList<>(results.size()); for (Row row : results) { doubleValues.add(row.getFieldAs(outputCols[0])); denseVectorValues.add(row.getFieldAs(outputCols[1])); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java index 32d58ccc4..c751e604b 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.ml.feature.countvectorizer.CountVectorizer; import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -68,7 +68,7 @@ public class CountVectorizerTest extends AbstractTestBase { Row.of((Object) new String[] {"e", "f"}), Row.of((Object) new String[] {"a", "c", "a"}))); - private static final List EXPECTED_OUTPUT = + private static final List EXPECTED_OUTPUT = new ArrayList<>( Arrays.asList( Vectors.sparse( @@ -101,15 +101,15 @@ public void before() { } private static void verifyPredictionResult( - Table output, String outputCol, List expected) throws Exception { + Table output, String outputCol, List expected) throws Exception { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); - DataStream stream = + DataStream stream = tEnv.toDataStream(output) .map( - (MapFunction) - row -> (SparseVector) row.getField(outputCol)); - List result = IteratorUtils.toList(stream.executeAndCollect()); + (MapFunction) + row -> (SparseIntDoubleVector) row.getField(outputCol)); + List result = IteratorUtils.toList(stream.executeAndCollect()); compareResultCollections(expected, result, TestUtils::compare); } @@ -244,7 +244,7 @@ public void testFitOnEmptyData() { @Test public void testMinMaxDF() throws Exception { - List expectedOutput = + List expectedOutput = new ArrayList<>( Arrays.asList( Vectors.sparse( @@ -280,7 +280,7 @@ public void testMinMaxDF() throws Exception { 
@Test public void testMinTF() throws Exception { - List expectedOutput = + List expectedOutput = new ArrayList<>( Arrays.asList( Vectors.sparse( @@ -305,7 +305,7 @@ public void testMinTF() throws Exception { @Test public void testBinary() throws Exception { - List expectedOutput = + List expectedOutput = new ArrayList<>( Arrays.asList( Vectors.sparse( @@ -336,7 +336,7 @@ public void testBinary() throws Exception { @Test public void testVocabularySize() throws Exception { - List expectedOutput = + List expectedOutput = new ArrayList<>( Arrays.asList( Vectors.sparse( diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java index 36baea68d..7c1f71d0b 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java @@ -19,8 +19,8 @@ package org.apache.flink.ml.feature; import org.apache.flink.ml.feature.dct.DCT; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -48,7 +48,7 @@ public class DCTTest extends AbstractTestBase { private StreamExecutionEnvironment env; private StreamTableEnvironment tEnv; - private static final List inputData = + private static final List inputData = Arrays.asList(Vectors.dense(1.0, 1.0, 1.0, 1.0), Vectors.dense(1.0, 0.0, -1.0, 0.0)); private static final List expectedForwardOutputData = @@ -126,7 +126,8 @@ public void testTransformInverse() { public void testInputTypeConversion() throws Exception { inputTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputTable); assertArrayEquals( - new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(inputTable)); + new Class[] {SparseIntDoubleVector.class}, + TestUtils.getColumnDataTypes(inputTable)); DCT dct = new DCT(); Table outputTable = dct.transform(inputTable)[0]; @@ -159,21 +160,21 @@ private static void verifyTransformResult( actualOutputData.sort( Comparator.comparingLong( x -> - ((Vector) Objects.requireNonNull(x.getField(inputCol))) + ((IntDoubleVector) Objects.requireNonNull(x.getField(inputCol))) .toDense() .hashCode())); expectedOutputData.sort( Comparator.comparingLong( x -> - ((Vector) Objects.requireNonNull(x.getField(0))) + ((IntDoubleVector) Objects.requireNonNull(x.getField(0))) .toDense() .hashCode())); assertEquals(actualOutputData.size(), expectedOutputData.size()); for (int i = 0; i < actualOutputData.size(); i++) { - Vector actualVector = actualOutputData.get(i).getFieldAs(outputCol); - Vector expectedVector = expectedOutputData.get(i).getFieldAs(1); + IntDoubleVector actualVector = actualOutputData.get(i).getFieldAs(outputCol); + IntDoubleVector expectedVector = expectedOutputData.get(i).getFieldAs(1); assertArrayEquals(expectedVector.toArray(), actualVector.toArray(), 1e-3); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java index 4b55f0dde..894b510e4 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java @@ -19,8 +19,8 @@ package 
org.apache.flink.ml.feature; import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -90,28 +90,30 @@ private void verifyOutputResult(Table output, String outputCol, boolean isSparse for (Row result : results) { if (result.getField(0) == (Object) 0) { if (isSparse) { - SparseVector sparseVector = (SparseVector) result.getField(outputCol); - assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_1, sparseVector.size()); + SparseIntDoubleVector sparseVector = + (SparseIntDoubleVector) result.getField(outputCol); + assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_1, sparseVector.size().intValue()); assertArrayEquals(EXPECTED_OUTPUT_SPARSE_VEC_INDICES_1, sparseVector.indices); assertArrayEquals( EXPECTED_OUTPUT_SPARSE_VEC_VALUES_1, sparseVector.values, 1.0e-5); } else { assertArrayEquals( EXPECTED_OUTPUT_DENSE_VEC_ARRAY_1, - ((DenseVector) result.getField(outputCol)).values, + ((DenseIntDoubleVector) result.getField(outputCol)).values, 1.0e-5); } } else if (result.getField(0) == (Object) 1) { if (isSparse) { - SparseVector sparseVector = (SparseVector) result.getField(outputCol); - assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_2, sparseVector.size()); + SparseIntDoubleVector sparseVector = + (SparseIntDoubleVector) result.getField(outputCol); + assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_2, sparseVector.size().intValue()); assertArrayEquals(EXPECTED_OUTPUT_SPARSE_VEC_INDICES_2, sparseVector.indices); assertArrayEquals( EXPECTED_OUTPUT_SPARSE_VEC_VALUES_2, sparseVector.values, 1.0e-5); } else { assertArrayEquals( EXPECTED_OUTPUT_DENSE_VEC_ARRAY_2, - ((DenseVector) result.getField(outputCol)).values, + ((DenseIntDoubleVector) result.getField(outputCol)).values, 1.0e-5); } } else if (result.getField(0) == (Object) 2) { diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java index fae16692d..c636486c6 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.feature; import org.apache.flink.ml.feature.featurehasher.FeatureHasher; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -48,9 +48,9 @@ public class FeatureHasherTest extends AbstractTestBase { private static final List INPUT_DATA = Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false)); - private static final SparseVector EXPECTED_OUTPUT_DATA_1 = + private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_1 = Vectors.sparse(1000, new int[] {607, 635, 913}, new double[] {1.0, 1.0, 1.0}); - private static final SparseVector EXPECTED_OUTPUT_DATA_2 = + private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_2 = Vectors.sparse(1000, new int[] {242, 869, 913}, new double[] {1.0, 1.0, 1.0}); @Before diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
index 9bed73ec8..287b5ec7a 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
@@ -21,7 +21,7 @@
 import org.apache.flink.ml.feature.idf.IDF;
 import org.apache.flink.ml.feature.idf.IDFModel;
 import org.apache.flink.ml.feature.idf.IDFModelData;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -54,12 +54,12 @@ public class IDFTest extends AbstractTestBase {
     private StreamTableEnvironment tEnv;
     private Table inputTable;

-    private static final List<DenseVector> expectedOutput =
+    private static final List<DenseIntDoubleVector> expectedOutput =
             Arrays.asList(
                     Vectors.dense(0, 0, 0, 0.5753641),
                     Vectors.dense(0, 0, 1.3862943, 0.8630462),
                     Vectors.dense(0, 0, 0, 0));
-    private static final List<DenseVector> expectedOutputMinDocFreqAsTwo =
+    private static final List<DenseIntDoubleVector> expectedOutputMinDocFreqAsTwo =
             Arrays.asList(
                     Vectors.dense(0, 0, 0, 0.5753641),
                     Vectors.dense(0, 0, 0, 0.8630462),
@@ -71,7 +71,7 @@ public void before() {
         env = TestUtils.getExecutionEnvironment();
         tEnv = StreamTableEnvironment.create(env);

-        List<DenseVector> input =
+        List<DenseIntDoubleVector> input =
                 Arrays.asList(
                         Vectors.dense(0, 1, 0, 2),
                         Vectors.dense(0, 1, 2, 3),
@@ -81,11 +81,12 @@ public void before() {

     @SuppressWarnings("unchecked")
     private void verifyPredictionResult(
-            List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
+            List<DenseIntDoubleVector> expectedOutput, Table output, String predictionCol)
+            throws Exception {
         List<Row> collectedResult =
                 IteratorUtils.toList(
                         tEnv.toDataStream(output.select($(predictionCol))).executeAndCollect());
-        List<DenseVector> actualOutputs = new ArrayList<>(expectedOutput.size());
+        List<DenseIntDoubleVector> actualOutputs = new ArrayList<>(expectedOutput.size());
         collectedResult.forEach(x -> actualOutputs.add((x.getFieldAs(0))));

         actualOutputs.sort(TestUtils::compare);
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
index 0b0cc9957..59a672ac2 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
@@ -19,9 +19,9 @@
 package org.apache.flink.ml.feature;

 import org.apache.flink.ml.feature.interaction.Interaction;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -63,20 +63,20 @@ public class InteractionTest extends AbstractTestBase {
                             Vectors.sparse(17, new int[] {0, 2, 14}, new double[] {5.0, 4.0, 1.0})),
                     Row.of(3, null, null, null));

-    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT =
             Arrays.asList(
-                    new DenseVector(new double[] {3.0, 4.0, 6.0, 8.0}),
-                    new DenseVector(new double[] {12.0, 16.0, 20.0, 48.0, 64.0, 80.0}));
+                    new DenseIntDoubleVector(new double[] {3.0, 4.0, 6.0, 8.0}),
+                    new DenseIntDoubleVector(new double[] {12.0, 16.0, 20.0, 48.0, 64.0, 80.0}));

-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
-                    new SparseVector(
+                    new SparseIntDoubleVector(
                             68,
                             new int[] {0, 3, 9, 17, 20, 26, 34, 37, 43, 51, 54, 60},
                             new double[] {
                                 3.0, 6.0, 21.0, 4.0, 8.0, 28.0, 6.0, 12.0, 42.0, 8.0, 16.0, 56.0
                             }),
-                    new SparseVector(
+                    new SparseIntDoubleVector(
                             102,
                             new int[] {
                                 0, 2, 14, 17, 19, 31, 34, 36, 48, 51, 53, 65, 68, 70, 82, 85, 87, 99
@@ -94,14 +94,14 @@ public void before() {
         inputDataTable = tEnv.fromDataStream(dataStream).as("f0", "f1", "f2", "f3");
     }

-    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
-            throws Exception {
+    private void verifyOutputResult(
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Row> stream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
-        List<Vector> resultVec = new ArrayList<>(results.size());
+        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
         for (Row row : results) {
             if (row.getField(outputCol) != null) {
                 resultVec.add(row.getFieldAs(outputCol));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
index e6959baae..99786659c 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
@@ -22,7 +22,7 @@
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -138,7 +138,8 @@ private void verifyPredictionResult(
                 collectedResult,
                 (o1, o2) ->
                         TestUtils.compare(
-                                (DenseVector) o1.getField(0), (DenseVector) o2.getField(0)));
+                                (DenseIntDoubleVector) o1.getField(0),
+                                (DenseIntDoubleVector) o2.getField(0)));
     }

     @Test
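The three test diffs above establish the pattern that repeats through the rest of this patch: `DenseVector`, `SparseVector`, and `Vector` become `DenseIntDoubleVector`, `SparseIntDoubleVector`, and `IntDoubleVector`, while the `Vectors` factory methods keep their names. Below is a minimal, illustrative sketch of the resulting usage; it assumes the factories return the renamed concrete types, as the casts and `toArray(DenseIntDoubleVector[]::new)` calls elsewhere in this patch indicate.

```java
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.SparseIntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class VectorRenameSketch {
    public static void main(String[] args) {
        // Expected test values are now declared with the renamed concrete types.
        DenseIntDoubleVector dense = Vectors.dense(0, 1, 0, 2);
        SparseIntDoubleVector sparse =
                Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 2.0});

        // IntDoubleVector replaces Vector as the common supertype used in
        // helper signatures such as verifyOutputResult(...).
        IntDoubleVector either = sparse;
        System.out.println(dense.size() + " / " + either.size());
    }
}
```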
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
index 6b8edf05e..037115660 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
@@ -21,7 +21,8 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
@@ -94,14 +95,14 @@ public class MaxAbsScalerTest {
                             Row.of(Vectors.sparse(4, new int[] {}, new double[] {})),
                             Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 2.0}))));

-    private static final List<Vector> EXPECTED_DATA =
+    private static final List<IntDoubleVector> EXPECTED_DATA =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.dense(0.25, 0.1, 1.0),
                             Vectors.dense(0.5, 0.125, 0.5),
                             Vectors.dense(0.75, 0.225, 1.0)));

-    private static final List<Vector> EXPECTED_SPARSE_DATA =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_DATA =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.sparse(4, new int[] {0, 1}, new double[] {1.0, 0.5}),
@@ -124,7 +125,7 @@ public void before() {
     }

     private static void verifyPredictionResult(
-            Table output, String outputCol, List<Vector> expectedData) throws Exception {
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
@@ -134,7 +135,7 @@ private static void verifyPredictionResult(
                                 (MapFunction<Row, Vector>) row -> row.getFieldAs(outputCol),
                                 VectorTypeInfo.INSTANCE);
-        List<Vector> result = IteratorUtils.toList(stream.executeAndCollect());
+        List<IntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expectedData, result, TestUtils::compare);
     }
@@ -237,7 +238,8 @@ public void testGetModelData() throws Exception {
         DataStream<Row> output = tEnv.toDataStream(modelData);
         List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
         assertEquals(
-                new DenseVector(new double[] {200.0, 400.0, 0.0}), modelRows.get(0).getField(0));
+                new DenseIntDoubleVector(new double[] {200.0, 400.0, 0.0}),
+                modelRows.get(0).getField(0));
     }

     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
index d92f062d0..442a4c3a7 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
@@ -21,9 +21,9 @@
 import org.apache.flink.ml.feature.lsh.MinHashLSH;
 import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
 import org.apache.flink.ml.feature.lsh.MinHashLSHModelData;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -90,10 +90,10 @@ private static List<Row> convertToOutputFormat(List<double[][]> arrays) {
         return arrays.stream()
                 .map(
                         array -> {
-                            DenseVector[] denseVectors =
+                            DenseIntDoubleVector[] denseVectors =
                                     Arrays.stream(array)
                                             .map(Vectors::dense)
-                                            .toArray(DenseVector[]::new);
+                                            .toArray(DenseIntDoubleVector[]::new);
                             return Row.of((Object) denseVectors);
                         })
                 .collect(Collectors.toList());
@@ -133,7 +133,7 @@ public void before() {
         Schema schema =
                 Schema.newBuilder()
                         .column("f0", DataTypes.INT())
-                        .column("f1", DataTypes.of(SparseVector.class))
+                        .column("f1", DataTypes.of(SparseIntDoubleVector.class))
                         .build();
         DataStream<Row> dataStream = env.fromCollection(inputRows);
@@ -144,8 +144,9 @@
     public void testHashFunction() {
         MinHashLSHModelData lshModelData =
                 new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
-        Vector vec = Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
-        DenseVector[] result = lshModelData.hashFunction(vec);
+        IntDoubleVector vec =
+                Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
+        DenseIntDoubleVector[] result = lshModelData.hashFunction(vec);
         Assert.assertEquals(3, result.length);
         Assert.assertEquals(Vectors.dense(1.), result[0]);
         Assert.assertEquals(Vectors.dense(5.), result[1]);
@@ -158,9 +159,10 @@ public void testHashFunctionEqualWithSparseDenseVector() {
         // least non-zero index.
         MinHashLSHModelData lshModelData = MinHashLSHModelData.generateModelData(3, 1, 10, 2022L);
         new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
-        Vector vec = Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
-        DenseVector[] denseResults = lshModelData.hashFunction(vec.toDense());
-        DenseVector[] sparseResults = lshModelData.hashFunction(vec.toSparse());
+        IntDoubleVector vec =
+                Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
+        DenseIntDoubleVector[] denseResults = lshModelData.hashFunction(vec.toDense());
+        DenseIntDoubleVector[] sparseResults = lshModelData.hashFunction(vec.toSparse());
         Assert.assertArrayEquals(denseResults, sparseResults);
     }
@@ -168,7 +170,7 @@ public void testHashFunctionEqualWithSparseDenseVector() {
     public void testHashFunctionWithEmptyVector() {
         MinHashLSHModelData lshModelData =
                 new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
-        Vector vec = Vectors.sparse(10, new int[] {}, new double[] {});
+        IntDoubleVector vec = Vectors.sparse(10, new int[] {}, new double[] {});
         lshModelData.hashFunction(vec);
     }
@@ -375,7 +377,7 @@ public void testApproxNearestNeighbors() {
         MinHashLSHModel lshModel = lsh.fit(inputTable);
         List<Row> expected = Arrays.asList(Row.of(0, .75), Row.of(1, .75));
-        Vector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1.0, 1.0});
+        IntDoubleVector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1.0, 1.0});
         Table output =
                 lshModel.approxNearestNeighbors(inputTable, key, 2).select($("id"), $("distCol"));
         List<Row> results = IteratorUtils.toList(output.execute().collect());
@@ -408,7 +410,7 @@ public void testApproxSimilarityJoin() {
         Schema schema =
                 Schema.newBuilder()
                         .column("f0", DataTypes.INT())
-                        .column("f1", DataTypes.of(SparseVector.class))
+                        .column("f1", DataTypes.of(SparseIntDoubleVector.class))
                         .build();
         Table dataB = tEnv.fromDataStream(env.fromCollection(inputRowsB), schema).as("id", "vec");
@@ -426,9 +428,9 @@
                         .thenComparingDouble(r -> r.getFieldAs(2)));
     }

-    private static class DenseVectorArrayComparator implements Comparator<DenseVector[]> {
+    private static class DenseVectorArrayComparator implements Comparator<DenseIntDoubleVector[]> {
         @Override
-        public int compare(DenseVector[] o1, DenseVector[] o2) {
+        public int compare(DenseIntDoubleVector[] o1, DenseIntDoubleVector[] o2) {
             if (o1.length != o2.length) {
                 return o1.length - o2.length;
             }
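The MinHashLSH hunks above leave the model-data API intact apart from the rename: `hashFunction` now consumes an `IntDoubleVector` and still returns `DenseIntDoubleVector[]`. A compressed sketch of the dense/sparse equivalence the test asserts, built only from calls visible in this diff:

```java
import org.apache.flink.ml.feature.lsh.MinHashLSHModelData;
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

import java.util.Arrays;

public class MinHashSketch {
    public static void main(String[] args) {
        // Same hand-picked hash coefficients as testHashFunction above.
        MinHashLSHModelData modelData =
                new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
        IntDoubleVector vec =
                Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});

        // Hashing must not depend on the dense/sparse representation.
        DenseIntDoubleVector[] fromDense = modelData.hashFunction(vec.toDense());
        DenseIntDoubleVector[] fromSparse = modelData.hashFunction(vec.toSparse());
        System.out.println(Arrays.equals(fromDense, fromSparse));
    }
}
```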
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
index 324516359..f75b5f8a7 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -21,8 +21,8 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -70,7 +70,7 @@ public class MinMaxScalerTest extends AbstractTestBase {
                             Row.of(Vectors.dense(50.0, 40.0)),
                             Row.of(Vectors.dense(100.0, 50.0))));
     private static final double EPS = 1.0e-5;
-    private static final List<DenseVector> EXPECTED_DATA =
+    private static final List<DenseIntDoubleVector> EXPECTED_DATA =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.dense(0.25, 0.1),
@@ -86,15 +86,15 @@ public void before() {
     }

     private static void verifyPredictionResult(
-            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+            Table output, String outputCol, List<DenseIntDoubleVector> expected) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
-        DataStream<DenseVector> stream =
+        DataStream<DenseIntDoubleVector> stream =
                 tEnv.toDataStream(output)
                         .map(
-                                (MapFunction<Row, DenseVector>)
-                                        row -> (DenseVector) row.getField(outputCol));
-        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+                                (MapFunction<Row, DenseIntDoubleVector>)
+                                        row -> (DenseIntDoubleVector) row.getField(outputCol));
+        List<DenseIntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expected, result, TestUtils::compare);
     }
@@ -158,9 +158,10 @@ public void testInputTypeConversion() throws Exception {
         trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable);
         predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
+                new Class[] {SparseIntDoubleVector.class},
+                TestUtils.getColumnDataTypes(trainDataTable));
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(predictDataTable));

         MinMaxScaler minMaxScaler = new MinMaxScaler();
@@ -202,8 +203,11 @@ public void testGetModelData() throws Exception {
                 modelData.getResolvedSchema().getColumnNames());
         DataStream<Row> output = tEnv.toDataStream(modelData);
         List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
-        assertEquals(new DenseVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
-        assertEquals(new DenseVector(new double[] {200.0, 400.0}), modelRows.get(0).getField(1));
+        assertEquals(
+                new DenseIntDoubleVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
+        assertEquals(
+                new DenseIntDoubleVector(new double[] {200.0, 400.0}),
+                modelRows.get(0).getField(1));
     }

     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
index 13e28c68f..7dfd1fa44 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.feature;

 import org.apache.flink.ml.feature.normalizer.Normalizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -56,7 +56,7 @@ public class NormalizerTest extends AbstractTestBase {
                     Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1),
                     Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {0.1, 0.2, 0.3})));

-    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT =
             Arrays.asList(
                     Vectors.dense(
                             0.17386300895299714,
@@ -73,7 +73,7 @@ public class NormalizerTest extends AbstractTestBase {
                             0.4608889965995767,
                             0.3705186051094636));

-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
                     Vectors.sparse(
                             5,
@@ -96,14 +96,14 @@ public void before() {
         inputDataTable = tEnv.fromDataStream(dataStream).as("denseVec", "sparseVec");
     }

-    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
-            throws Exception {
+    private void verifyOutputResult(
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Row> stream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
-        List<Vector> resultVec = new ArrayList<>(results.size());
+        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
         for (Row row : results) {
             if (row.getField(outputCol) != null) {
                 resultVec.add(row.getFieldAs(outputCol));
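InteractionTest, NormalizerTest, and (below) PolynomialExpansionTest all route their assertions through the same helper, whose only change here is the `List<Vector>` to `List<IntDoubleVector>` signature. A condensed, self-contained sketch of that recurring collection pattern; the class and method names are hypothetical stand-ins for the test fixtures:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class CollectVectorsSketch {
    // Collects a Table's output column into a typed list, skipping nulls.
    @SuppressWarnings("unchecked")
    public static List<IntDoubleVector> collect(
            StreamTableEnvironment tEnv, Table output, String outputCol) throws Exception {
        List<Row> results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
        for (Row row : results) {
            if (row.getField(outputCol) != null) {
                // getFieldAs(...) infers the element type from the assignment target.
                resultVec.add(row.getFieldAs(outputCol));
            }
        }
        return resultVec;
    }
}
```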
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
index e5a6726a8..c588d0e13 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -23,7 +23,7 @@
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -59,7 +59,7 @@ public class OneHotEncoderTest extends AbstractTestBase {
     private StreamTableEnvironment tEnv;
     private Table trainTable;
     private Table predictTable;
-    private Map<Double, Vector>[] expectedOutput;
+    private Map<Double, IntDoubleVector>[] expectedOutput;
     private OneHotEncoder estimator;

     @Before
@@ -77,7 +77,7 @@ public void before() {
         expectedOutput =
                 new HashMap[] {
-                    new HashMap<Double, Vector>() {
+                    new HashMap<Double, IntDoubleVector>() {
                         {
                             put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
                             put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
@@ -99,9 +99,9 @@ public void before() {
      * @param outputCols Name of the output columns containing one-hot encoding result
      * @return An array of map containing the collected results for each input column
      */
-    private static Map<Double, Vector>[] executeAndCollect(
+    private static Map<Double, IntDoubleVector>[] executeAndCollect(
             Table table, String[] inputCols, String[] outputCols) {
-        Map<Double, Vector>[] maps = new HashMap[inputCols.length];
+        Map<Double, IntDoubleVector>[] maps = new HashMap[inputCols.length];
         for (int i = 0; i < inputCols.length; i++) {
             maps[i] = new HashMap<>();
         }
@@ -110,7 +110,7 @@ private static Map<Double, Vector>[] executeAndCollect(
             for (int i = 0; i < inputCols.length; i++) {
                 maps[i].put(
                         ((Number) row.getField(inputCols[i])).doubleValue(),
-                        (Vector) row.getField(outputCols[i]));
+                        (IntDoubleVector) row.getField(outputCols[i]));
             }
         }
         return maps;
@@ -143,7 +143,7 @@ public void testParam() {
     public void testFitAndPredict() {
         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -158,7 +158,7 @@ public void testInputTypeConversion() throws Exception {

         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -169,7 +169,7 @@ public void testDropLast() {

         expectedOutput =
                 new HashMap[] {
-                    new HashMap<Double, Vector>() {
+                    new HashMap<Double, IntDoubleVector>() {
                         {
                             put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0}));
                             put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0}));
@@ -180,7 +180,7 @@ public void testDropLast() {

         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -196,7 +196,7 @@ public void testInputDataType() {

         expectedOutput =
                 new HashMap[] {
-                    new HashMap<Double, Vector>() {
+                    new HashMap<Double, IntDoubleVector>() {
                         {
                             put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
                             put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
@@ -207,7 +207,7 @@ public void testInputDataType() {

         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -276,7 +276,7 @@ public void testSaveLoad() throws Exception {
                         tempFolder.newFolder().getAbsolutePath(),
                         OneHotEncoderModel::load);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -301,7 +301,7 @@ public void testSetModelData() {
         ParamUtils.updateExistingParams(modelB, modelA.getParamMap());

         Table outputTable = modelB.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, modelB.getInputCols(), modelB.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
index 303713cd1..fdb1515f8 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
@@ -32,10 +32,10 @@
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -115,7 +115,10 @@ public void before() {
                                 inputStream,
                                 Schema.newBuilder()
                                         .column("f0", DataTypes.BIGINT())
-                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
+                                        .column(
+                                                "f1",
+                                                DataTypes.RAW(
+                                                        DenseIntDoubleVectorTypeInfo.INSTANCE))
                                         .build())
                         .as("id", "input");
@@ -136,7 +139,7 @@ public Row map(Row value) throws Exception {
                                 },
                                 new RowTypeInfo(
                                         new TypeInformation[] {
-                                            Types.LONG, DenseVectorTypeInfo.INSTANCE
+                                            Types.LONG, DenseIntDoubleVectorTypeInfo.INSTANCE
                                         },
                                         new String[] {"id", "input"}))
                         .setParallelism(1);
@@ -155,7 +158,10 @@
                                 inputStreamWithEventTime,
                                 Schema.newBuilder()
                                         .column("f0", DataTypes.BIGINT())
-                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
+                                        .column(
+                                                "f1",
+                                                DataTypes.RAW(
+                                                        DenseIntDoubleVectorTypeInfo.INSTANCE))
                                         .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
                                         .watermark("rowtime", "SOURCE_WATERMARK()")
                                         .build())
@@ -295,7 +301,7 @@ public void testFitAndPredictWithGlobalWindow() throws Exception {
                         .as("input");

         // Tests withMean option.
-        List<DenseVector> expectedResWithMean =
+        List<DenseIntDoubleVector> expectedResWithMean =
                 Arrays.asList(
                         Vectors.dense(-2.8, 8, 1),
                         Vectors.dense(1.1, -6, 1),
@@ -304,7 +310,7 @@ public void testFitAndPredictWithGlobalWindow() throws Exception {
         verifyPredictionResult(expectedResWithMean, output, standardScaler.getOutputCol());

         // Tests withStd option.
-        List<DenseVector> expectedResWithStd =
+        List<DenseIntDoubleVector> expectedResWithStd =
                 Arrays.asList(
                         Vectors.dense(-1.0231819, 1.2480754, 0.5773502),
                         Vectors.dense(0.5729819, -0.6933752, 0.5773503),
@@ -313,7 +319,7 @@ public void testFitAndPredictWithGlobalWindow() throws Exception {
         verifyPredictionResult(expectedResWithStd, output, standardScaler.getOutputCol());

         // Tests withMean, withStd Option.
-        List<DenseVector> expectedResWithMeanAndStd =
+        List<DenseIntDoubleVector> expectedResWithMeanAndStd =
                 Arrays.asList(
                         Vectors.dense(-1.1459637, 1.1094004, 0.5773503),
                         Vectors.dense(0.45020003, -0.8320503, 0.5773503),
@@ -423,13 +429,14 @@ private void verifyUsedModelVersion(

     @SuppressWarnings("unchecked")
     private void verifyPredictionResult(
-            List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
+            List<DenseIntDoubleVector> expectedOutput, Table output, String predictionCol)
+            throws Exception {
         List<Row> collectedResult =
                 IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
-        List<DenseVector> predictions = new ArrayList<>(collectedResult.size());
+        List<DenseIntDoubleVector> predictions = new ArrayList<>(collectedResult.size());

         for (Row r : collectedResult) {
-            Vector vec = (Vector) r.getField(predictionCol);
+            IntDoubleVector vec = (IntDoubleVector) r.getField(predictionCol);
             predictions.add(vec.toDense());
         }
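The OnlineStandardScaler hunks show both places where the renamed type information surfaces: Table API schemas and DataStream `RowTypeInfo`. A sketch of the two declarations side by side, using only calls already present in this diff:

```java
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;

public class TypeInfoSketch {
    public static void main(String[] args) {
        // Table API: the vector column is declared as a RAW type backed by
        // the renamed type info, mirroring the schema change above.
        Schema schema =
                Schema.newBuilder()
                        .column("f0", DataTypes.BIGINT())
                        .column("f1", DataTypes.RAW(DenseIntDoubleVectorTypeInfo.INSTANCE))
                        .build();

        // DataStream API: the same type info slots into a RowTypeInfo.
        RowTypeInfo rowTypeInfo =
                new RowTypeInfo(
                        new TypeInformation[] {Types.LONG, DenseIntDoubleVectorTypeInfo.INSTANCE},
                        new String[] {"id", "input"});
        System.out.println(schema + " / " + rowTypeInfo);
    }
}
```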
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
index 5454eabff..9b06c0860 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.feature;

 import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -56,19 +56,19 @@ public class PolynomialExpansionTest extends AbstractTestBase {
                     Vectors.dense(2.0, 3.0),
                     Vectors.sparse(5, new int[] {1, 4}, new double[] {2.0, 1.0})));

-    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT =
             Arrays.asList(
                     Vectors.dense(1.0, 1.0, 2.0, 2.0, 4.0, 3.0, 3.0, 6.0, 9.0),
                     Vectors.dense(2.0, 4.0, 3.0, 6.0, 9.0));

-    private static final List<Vector> EXPECTED_DENSE_OUTPUT_WITH_DEGREE_3 =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT_WITH_DEGREE_3 =
             Arrays.asList(
                     Vectors.dense(
                             1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 4.0, 4.0, 8.0, 3.0, 3.0, 3.0, 6.0, 6.0,
                             12.0, 9.0, 9.0, 18.0, 27.0),
                     Vectors.dense(2.0, 4.0, 8.0, 3.0, 6.0, 12.0, 9.0, 18.0, 27.0));

-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
                     Vectors.sparse(
                             55,
@@ -87,14 +87,14 @@ public void before() {
         inputDataTable = tEnv.fromDataStream(dataStream).as("denseVec", "sparseVec");
     }

-    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
-            throws Exception {
+    private void verifyOutputResult(
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Row> stream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
-        List<Vector> resultVec = new ArrayList<>(results.size());
+        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
         for (Row row : results) {
             if (row.getField(outputCol) != null) {
                 resultVec.add(row.getFieldAs(outputCol));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
index e8179024a..a97f57a84 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
@@ -21,8 +21,8 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.robustscaler.RobustScaler;
 import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -79,7 +79,7 @@ public class RobustScalerTest extends AbstractTestBase {
                             Row.of(Vectors.dense(99.0, -99.0))));
     private static final double EPS = 1.0e-5;

-    private static final List<DenseVector> EXPECTED_OUTPUT =
+    private static final List<DenseIntDoubleVector> EXPECTED_OUTPUT =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.dense(0.75, -0.75),
@@ -95,15 +95,15 @@ public void before() {
     }

     private static void verifyPredictionResult(
-            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+            Table output, String outputCol, List<DenseIntDoubleVector> expected) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
-        DataStream<DenseVector> stream =
+        DataStream<DenseIntDoubleVector> stream =
                 tEnv.toDataStream(output)
                         .map(
-                                (MapFunction<Row, DenseVector>)
-                                        row -> (DenseVector) row.getField(outputCol));
-        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+                                (MapFunction<Row, DenseIntDoubleVector>)
+                                        row -> (DenseIntDoubleVector) row.getField(outputCol));
+        List<DenseIntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expected, result, TestUtils::compare);
     }
@@ -162,9 +162,10 @@ public void testInputTypeConversion() throws Exception {
                         tEnv, trainDataTable.select(Expressions.$("input")));
         predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
+                new Class[] {SparseIntDoubleVector.class},
+                TestUtils.getColumnDataTypes(trainDataTable));
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(predictDataTable));

         RobustScaler robustScaler = new RobustScaler();
@@ -217,7 +218,7 @@ public void testWithCentering() throws Exception {
         RobustScaler robustScaler = new RobustScaler().setWithCentering(true);
         RobustScalerModel model = robustScaler.fit(trainDataTable);
         Table output = model.transform(predictDataTable)[0];
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(-0.25, 0.25),
@@ -231,7 +232,7 @@ public void testWithoutScaling() throws Exception {
         RobustScaler robustScaler = new RobustScaler().setWithCentering(true).setWithScaling(false);
         RobustScalerModel model = robustScaler.fit(trainDataTable);
         Table output = model.transform(predictDataTable)[0];
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(-1, 1),
@@ -273,7 +274,7 @@ public void testZeroRange() throws Exception {
                         Row.of(2, Vectors.dense(1.0, 1.0)),
                         Row.of(3, Vectors.dense(1.0, 1.0)),
                         Row.of(4, Vectors.dense(4.0, 4.0))));
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(0.0, -0.0),
@@ -297,7 +298,7 @@ public void testNaNData() throws Exception {
                         Row.of(3, Vectors.dense(2.0, -2.0)),
                         Row.of(4, Vectors.dense(3.0, -3.0)),
                         Row.of(5, Vectors.dense(4.0, -4.0))));
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(0.0, Double.NaN),
@@ -322,11 +323,11 @@ public void testGetModelData() throws Exception {
                 Arrays.asList("medians", "ranges"), modelData.getResolvedSchema().getColumnNames());
         DataStream<Row> output = tEnv.toDataStream(modelData);
         List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
-        DenseVector medians = (DenseVector) modelRows.get(0).getField(0);
-        DenseVector ranges = (DenseVector) modelRows.get(0).getField(1);
+        DenseIntDoubleVector medians = (DenseIntDoubleVector) modelRows.get(0).getField(0);
+        DenseIntDoubleVector ranges = (DenseIntDoubleVector) modelRows.get(0).getField(1);

-        DenseVector expectedMedians = Vectors.dense(4.0, -4.0);
-        DenseVector expectedRanges = Vectors.dense(4.0, 4.0);
+        DenseIntDoubleVector expectedMedians = Vectors.dense(4.0, -4.0);
+        DenseIntDoubleVector expectedRanges = Vectors.dense(4.0, 4.0);
         assertEquals(expectedMedians, medians);
         assertEquals(expectedRanges, ranges);
     }
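The RobustScaler expectations above follow from the operator's semantics: with centering on, each dimension has its median subtracted (so a value equal to the median maps to 0.0); with scaling on, it is divided by the per-dimension quantile range. A minimal fit/transform sketch under those options; the class and table names are hypothetical:

```java
import org.apache.flink.ml.feature.robustscaler.RobustScaler;
import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
import org.apache.flink.table.api.Table;

public class RobustScalerSketch {
    // trainData and predictData are assumed to carry an "input" vector column.
    public static Table scale(Table trainData, Table predictData) {
        RobustScaler robustScaler =
                new RobustScaler().setWithCentering(true).setWithScaling(true);
        RobustScalerModel model = robustScaler.fit(trainData);
        return model.transform(predictData)[0];
    }
}
```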
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
index 7d08c70ce..5c18ee727 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
@@ -21,9 +21,9 @@
 import org.apache.flink.ml.feature.standardscaler.StandardScaler;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -61,19 +61,19 @@ public class StandardScalerTest extends AbstractTestBase {
                     Row.of(Vectors.dense(1.4, -5, 1)),
                     Row.of(Vectors.dense(2, -1, -2)));

-    private final List<DenseVector> expectedResWithMean =
+    private final List<DenseIntDoubleVector> expectedResWithMean =
             Arrays.asList(
                     Vectors.dense(-2.8, 8, 1),
                     Vectors.dense(1.1, -6, 1),
                     Vectors.dense(1.7, -2, -2));

-    private final List<DenseVector> expectedResWithStd =
+    private final List<DenseIntDoubleVector> expectedResWithStd =
             Arrays.asList(
                     Vectors.dense(-1.0231819, 1.2480754, 0.5773502),
                     Vectors.dense(0.5729819, -0.6933752, 0.5773503),
                     Vectors.dense(0.8185455, -0.1386750, -1.1547005));

-    private final List<DenseVector> expectedResWithMeanAndStd =
+    private final List<DenseIntDoubleVector> expectedResWithMeanAndStd =
             Arrays.asList(
                     Vectors.dense(-1.1459637, 1.1094004, 0.5773503),
                     Vectors.dense(0.45020003, -0.8320503, 0.5773503),
@@ -92,13 +92,14 @@ public void before() {

     @SuppressWarnings("unchecked")
     private void verifyPredictionResult(
-            List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
+            List<DenseIntDoubleVector> expectedOutput, Table output, String predictionCol)
+            throws Exception {
         List<Row> collectedResult =
                 IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
-        List<DenseVector> predictions = new ArrayList<>(collectedResult.size());
+        List<DenseIntDoubleVector> predictions = new ArrayList<>(collectedResult.size());

         for (Row r : collectedResult) {
-            Vector vec = (Vector) r.getField(predictionCol);
+            IntDoubleVector vec = (IntDoubleVector) r.getField(predictionCol);
             predictions.add(vec.toDense());
         }
@@ -179,7 +180,8 @@ public void testFitAndPredictWithMeanAndStd() throws Exception {
     public void testInputTypeConversion() throws Exception {
         denseTable = TestUtils.convertDataTypesToSparseInt(tEnv, denseTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(denseTable));
+                new Class[] {SparseIntDoubleVector.class},
+                TestUtils.getColumnDataTypes(denseTable));

         StandardScaler standardScaler = new StandardScaler().setWithMean(true);
         Table output = standardScaler.fit(denseTable).transform(denseTable)[0];
@@ -257,7 +259,7 @@ public void testSparseInput() throws Exception {
                         Row.of(Vectors.sparse(3, new int[] {0, 2}, new double[] {1.4, 1})));
         Table sparseTable = tEnv.fromDataStream(env.fromCollection(sparseInput)).as("input");

-        final List<DenseVector> expectedResWithStd =
+        final List<DenseIntDoubleVector> expectedResWithStd =
                 Arrays.asList(
                         Vectors.dense(-1.2653836, 1, 0),
                         Vectors.dense(0, 2, -1.30930734),
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
index f722a78a3..af4551d55 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
@@ -21,7 +21,7 @@
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.util.TestUtils;
@@ -512,11 +512,13 @@ private void verifyOutputResult(Table table, int... expectedIndices) throws Exception {
         CloseableIterator<Row> rowIterator = tEnv.toDataStream(table).executeAndCollect();
         while (rowIterator.hasNext()) {
             Row row = rowIterator.next();
-            assertEquals(expectedIndices.length, ((Vector) row.getField("output")).size());
+            assertEquals(
+                    expectedIndices.length,
+                    ((IntDoubleVector) (row.getField("output"))).size().intValue());
             for (int i = 0; i < expectedIndices.length; i++) {
                 assertEquals(
-                        ((Vector) row.getField("features")).get(expectedIndices[i]),
-                        ((Vector) row.getField("output")).get(i),
+                        ((IntDoubleVector) row.getField("features")).get(expectedIndices[i]),
+                        ((IntDoubleVector) row.getField("output")).get(i),
                         EPS);
             }
         }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
index 43a893ce5..35aca0606 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
@@ -22,6 +22,7 @@
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
@@ -74,7 +75,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase {
                     Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)),
                     Row.of(Vectors.sparse(6, new int[] {0, 3, 4}, new double[] {0.1, 0.3, 0.5})));

-    private static final List<Vector> EXPECTED_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_OUTPUT =
             Arrays.asList(
                     Vectors.dense(1.0, 4.0, 6.0),
                     Vectors.dense(0.1, 0.4, 0.6),
@@ -98,7 +99,7 @@ public void before() {
     }

     private static void verifyPredictionResult(
-            Table output, String outputCol, List<Vector> expected) throws Exception {
+            Table output, String outputCol, List<IntDoubleVector> expected) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Vector> stream =
@@ -106,7 +107,7 @@ private static void verifyPredictionResult(
                         .map(
                                 (MapFunction<Row, Vector>) row -> (Vector) row.getField(outputCol),
                                 VectorTypeInfo.INSTANCE);
-        List<Vector> result = IteratorUtils.toList(stream.executeAndCollect());
+        List<IntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expected, result, TestUtils::compare);
     }
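The `((IntDoubleVector) ...).size().intValue()` call above is the one behavioral wrinkle in this otherwise mechanical rename: `size()` apparently now returns a boxed value rather than a primitive `int` (the BLAS hunks later in this patch compare sizes the same way). A hedged sketch of the adjusted call-site style:

```java
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class SizeSketch {
    public static void main(String[] args) {
        IntDoubleVector vec = Vectors.dense(1.0, 2.0, 3.0);
        // Comparing boxed values with == would test identity, not equality,
        // so call sites unwrap explicitly before comparing.
        int dim = vec.size().intValue();
        System.out.println(dim == 3);
    }
}
```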
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
index f70d95fc4..2280b957d 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
@@ -20,8 +20,8 @@

 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -111,15 +111,15 @@ public class VectorAssemblerTest extends AbstractTestBase {
                     Vectors.sparse(
                             5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));

-    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
+    private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_1 =
             Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0});
-    private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
+    private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_2 =
             Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
-    private static final DenseVector EXPECTED_OUTPUT_DATA_3 =
+    private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_3 =
             Vectors.dense(2.0, 2.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
-    private static final DenseVector EXPECTED_OUTPUT_DATA_4 =
+    private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_4 =
             Vectors.dense(Double.NaN, Double.NaN, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
-    private static final DenseVector EXPECTED_OUTPUT_DATA_5 =
+    private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_5 =
             Vectors.dense(2.0, 2.1, Double.NaN, 0.0, 1.0, 2.0, 3.0, 4.0);

     @Before
@@ -373,7 +373,10 @@ public void testInputTypeConversion() throws Exception {
         inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
         assertArrayEquals(
                 new Class[] {
-                    Integer.class, SparseVector.class, Integer.class, SparseVector.class
+                    Integer.class,
+                    SparseIntDoubleVector.class,
+                    Integer.class,
+                    SparseIntDoubleVector.class
                 },
                 TestUtils.getColumnDataTypes(inputDataTable));
@@ -409,7 +412,8 @@ public void testNumber2Vector() throws Exception {
         List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
         for (Row result : results) {
             if (result.getField(2) != null) {
-                assertEquals(result.getField(2), ((DenseVector) result.getField(4)).values[0]);
+                assertEquals(
+                        result.getField(2), ((DenseIntDoubleVector) result.getField(4)).values[0]);
             }
         }
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
index fb9838001..9bb31757f 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
@@ -19,8 +19,8 @@
 package org.apache.flink.ml.feature;

 import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -59,12 +59,12 @@ public class VectorSlicerTest extends AbstractTestBase {
                     Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1),
                     Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {0.1, 0.2, 0.3})));

-    private static final DenseVector EXPECTED_OUTPUT_DATA_1 = Vectors.dense(2.1, 3.1, 2.3);
-    private static final DenseVector EXPECTED_OUTPUT_DATA_2 = Vectors.dense(2.3, 4.1, 1.3);
+    private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_1 = Vectors.dense(2.1, 3.1, 2.3);
+    private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_2 = Vectors.dense(2.3, 4.1, 1.3);

-    private static final SparseVector EXPECTED_OUTPUT_DATA_3 =
+    private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_3 =
             Vectors.sparse(3, new int[] {1}, new double[] {0.1});
-    private static final SparseVector EXPECTED_OUTPUT_DATA_4 =
+    private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_4 =
             Vectors.sparse(3, new int[] {1, 2}, new double[] {0.1, 0.2});

     @Before
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
index 1c051e17e..bd7c078ff 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
@@ -21,9 +21,9 @@
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.regression.linearregression.LinearRegression;
 import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
 import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
@@ -93,7 +93,9 @@ public void before() {
                                 trainData,
                                 new RowTypeInfo(
                                         new TypeInformation[] {
-                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
+                                            DenseIntDoubleVectorTypeInfo.INSTANCE,
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
                                         },
                                         new String[] {"features", "label", "weight"})));
     }
@@ -173,7 +175,7 @@ public void testFitAndPredict() throws Exception {
     public void testInputTypeConversion() throws Exception {
         trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class, Integer.class, Integer.class},
+                new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class},
                 TestUtils.getColumnDataTypes(trainDataTable));

         LinearRegression linearRegression = new LinearRegression().setWeightCol("weight");
@@ -246,7 +248,9 @@ public void testMoreSubtaskThanData() throws Exception {
                                 trainData,
                                 new RowTypeInfo(
                                         new TypeInformation[] {
-                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
+                                            DenseIntDoubleVectorTypeInfo.INSTANCE,
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
                                         },
                                         new String[] {"features", "label", "weight"})));
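LinearRegressionTest above, like the scaler tests earlier, verifies the sparse code path by rewriting its input table and asserting the resulting column classes. A small sketch of that shared idiom, using just the helpers visible in these hunks (the wrapper class is hypothetical):

```java
import static org.junit.Assert.assertArrayEquals;

import org.apache.flink.ml.linalg.SparseIntDoubleVector;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;

public class SparseConversionSketch {
    // Rewrites a table's columns to sparse int-double vectors and integers,
    // then confirms the reported column classes before fitting an estimator.
    public static Table toSparse(StreamTableEnvironment tEnv, Table trainDataTable) {
        Table converted = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable);
        assertArrayEquals(
                new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class},
                TestUtils.getColumnDataTypes(converted));
        return converted;
    }
}
```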
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
index 1b2270852..891a8f50e 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
@@ -18,7 +18,7 @@

 package org.apache.flink.ml.stats;

-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.stats.anovatest.ANOVATest;
 import org.apache.flink.ml.util.TestUtils;
@@ -328,13 +328,13 @@ private static void verifyTransformationResult(Table output, Row expected) throws
         Row result = results.get(0);
         assertEquals(3, result.getArity());
         assertArrayEquals(
-                ((Vector) expected.getField(0)).toArray(),
-                ((Vector) result.getField(0)).toArray(),
+                ((IntDoubleVector) expected.getField(0)).toArray(),
+                ((IntDoubleVector) result.getField(0)).toArray(),
                 EPS);
         assertArrayEquals((long[]) expected.getField(1), (long[]) result.getField(1));
         assertArrayEquals(
-                ((Vector) expected.getField(2)).toArray(),
-                ((Vector) result.getField(2)).toArray(),
+                ((IntDoubleVector) expected.getField(2)).toArray(),
+                ((IntDoubleVector) result.getField(2)).toArray(),
                 EPS);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
index faf8fe576..a79e82463 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
@@ -18,7 +18,7 @@

 package org.apache.flink.ml.stats;

-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.stats.fvaluetest.FValueTest;
 import org.apache.flink.ml.util.TestUtils;
@@ -346,13 +346,13 @@ private static void verifyTransformationResult(Table output, Row expected) throws
         Row result = results.get(0);
         assertEquals(3, result.getArity());
         assertArrayEquals(
-                ((Vector) expected.getField(0)).toArray(),
-                ((Vector) result.getField(0)).toArray(),
+                ((IntDoubleVector) expected.getField(0)).toArray(),
+                ((IntDoubleVector) result.getField(0)).toArray(),
                 EPS);
         assertArrayEquals((long[]) expected.getField(1), (long[]) result.getField(1));
         assertArrayEquals(
-                ((Vector) expected.getField(2)).toArray(),
-                ((Vector) result.getField(2)).toArray(),
+                ((IntDoubleVector) expected.getField(2)).toArray(),
+                ((IntDoubleVector) result.getField(2)).toArray(),
                 EPS);
     }
diff --git a/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py b/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py
index e8091d1f6..c2a218cb2 100644
--- a/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py
+++ b/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py
@@ -120,7 +120,7 @@ def test_save_load_and_predict(self):
         model = regression.fit(self.binomial_data_table)
         self.assertEqual(
             model.get_model_data()[0].get_schema().get_field_names(),
-            ['coefficient', 'modelVersion'])
+            ['coefficient', "startIndex", "endIndex", 'modelVersion'])
         output = model.transform(self.binomial_data_table)[0]
         field_names = output.get_schema().get_field_names()
         self.verify_predict_result(
diff --git a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py
index db59df0b7..2ce8786a9 100644
--- a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py
+++ b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py
@@ -103,6 +103,12 @@ def module(self):
         from pyflink.ml import classification
         return classification

+    def exclude_java_stage(self) -> List[str]:
+        # TODO: Add python support for LogisticRegressionWithFtrl.
+        return [
+            "logisticregression.LogisticRegressionWithFtrl",
+        ]
+

 class ClusteringCompletenessTest(CompletenessTest, MLLibTest):
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java
new file mode 100644
index 000000000..871dbf36c
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java
@@ -0,0 +1,40 @@
+/// *
+// * Licensed to the Apache Software Foundation (ASF) under one
+// * or more contributor license agreements.  See the NOTICE file
+// * distributed with this work for additional information
+// * regarding copyright ownership.  The ASF licenses this file
+// * to you under the Apache License, Version 2.0 (the
+// * "License"); you may not use this file except in compliance
+// * with the License.  You may obtain a copy of the License at
+// *
+// *     http://www.apache.org/licenses/LICENSE-2.0
+// *
+// * Unless required by applicable law or agreed to in writing, software
+// * distributed under the License is distributed on an "AS IS" BASIS,
+// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// * See the License for the specific language governing permissions and
+// * limitations under the License.
+// */
+//
+// package org.apache.flink.ml.common.feature;
+//
+// import org.apache.flink.api.java.tuple.Tuple2;
+//
+/// ** A data point to represent values that use long as index and double as values. */
+// public class LabeledLargePointWithWeight {
+//     public Tuple2<long[], double[]> features;
+//
+//     public double label;
+//
+//     public double weight;
+//
+//     public LabeledLargePointWithWeight(
+//             Tuple2<long[], double[]> features, double label, double weight) {
+//         this.features = features;
+//         this.label = label;
+//         this.weight = weight;
+//     }
+//
+//     /** Makes it pojo to use flink serializer. */
+//     public LabeledLargePointWithWeight() {}
+// }
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
index 8440bc97d..6a6344ece 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
@@ -18,46 +18,23 @@

 package org.apache.flink.ml.common.feature;

-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;

 /** Utility class to represent a data point that contains features, label and weight. */
 public class LabeledPointWithWeight {

-    private DenseVector features;
+    public Vector features;

-    private double label;
+    public double label;

-    private double weight;
+    public double weight;

-    public LabeledPointWithWeight(DenseVector features, double label, double weight) {
+    public LabeledPointWithWeight(Vector features, double label, double weight) {
         this.features = features;
         this.label = label;
         this.weight = weight;
     }

+    /** Makes it as pojo. */
     public LabeledPointWithWeight() {}
-
-    public DenseVector getFeatures() {
-        return features;
-    }
-
-    public void setFeatures(DenseVector features) {
-        this.features = features;
-    }
-
-    public double getLabel() {
-        return label;
-    }
-
-    public void setLabel(double label) {
-        this.label = label;
-    }
-
-    public double getWeight() {
-        return weight;
-    }
-
-    public void setWeight(double weight) {
-        this.weight = weight;
-    }
 }
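With its getters and setters dropped, `LabeledPointWithWeight` is now a mutable POJO with public fields; the `/** Makes it as pojo. */` comment marks the no-arg constructor Flink's serializer requires. A usage sketch of the field-access style that call sites migrate to (illustrative only):

```java
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.linalg.Vectors;

public class LabeledPointSketch {
    public static void main(String[] args) {
        LabeledPointWithWeight point =
                new LabeledPointWithWeight(Vectors.dense(1.0, 2.0), 1.0, 0.5);
        // Direct field access replaces the removed point.setWeight(...) /
        // point.getWeight() accessor pair.
        point.weight = 1.0;
        System.out.println(point.label + " " + point.weight);
    }
}
```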
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
index 643f1e0d8..6bc96aebc 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -27,67 +27,70 @@ public class BLAS {
             dev.ludovic.netlib.JavaBLAS.getInstance();

     /** \sum_i |x_i| . */
-    public static double asum(DenseVector x) {
+    public static double asum(DenseIntDoubleVector x) {
         return JAVA_BLAS.dasum(x.size(), x.values, 0, 1);
     }

     /** y += a * x . */
-    public static void axpy(double a, Vector x, DenseVector y) {
-        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+    public static void axpy(double a, IntDoubleVector x, DenseIntDoubleVector y) {
+        Preconditions.checkArgument(
+                x.size().intValue() == y.size().intValue(), "Vector size mismatched.");
         axpy(a, x, y, x.size());
     }

     /** y += a * x for the first k dimensions, with the other dimensions unchanged. */
-    public static void axpy(double a, Vector x, DenseVector y, int k) {
+    public static void axpy(double a, IntDoubleVector x, DenseIntDoubleVector y, int k) {
         Preconditions.checkArgument(x.size() >= k && y.size() >= k);
-        if (x instanceof SparseVector) {
-            axpy(a, (SparseVector) x, y, k);
+        if (x instanceof SparseIntDoubleVector) {
+            axpy(a, (SparseIntDoubleVector) x, y, k);
         } else {
-            axpy(a, (DenseVector) x, y, k);
+            axpy(a, (DenseIntDoubleVector) x, y, k);
         }
     }

     /** Computes the hadamard product of the two vectors (y = y \hdot x). */
-    public static void hDot(Vector x, Vector y) {
-        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
-        if (x instanceof SparseVector) {
-            if (y instanceof SparseVector) {
-                hDot((SparseVector) x, (SparseVector) y);
+    public static void hDot(IntDoubleVector x, IntDoubleVector y) {
+        Preconditions.checkArgument(
+                x.size().intValue() == y.size().intValue(), "Vector size mismatched.");
+        if (x instanceof SparseIntDoubleVector) {
+            if (y instanceof SparseIntDoubleVector) {
+                hDot((SparseIntDoubleVector) x, (SparseIntDoubleVector) y);
             } else {
-                hDot((SparseVector) x, (DenseVector) y);
+                hDot((SparseIntDoubleVector) x, (DenseIntDoubleVector) y);
             }
         } else {
-            if (y instanceof SparseVector) {
-                hDot((DenseVector) x, (SparseVector) y);
+            if (y instanceof SparseIntDoubleVector) {
+                hDot((DenseIntDoubleVector) x, (SparseIntDoubleVector) y);
             } else {
-                hDot((DenseVector) x, (DenseVector) y);
+                hDot((DenseIntDoubleVector) x, (DenseIntDoubleVector) y);
             }
         }
     }

     /** Computes the dot of the two vectors (y \dot x). */
-    public static double dot(Vector x, Vector y) {
-        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
-        if (x instanceof SparseVector) {
-            if (y instanceof SparseVector) {
-                return dot((SparseVector) x, (SparseVector) y);
+    public static double dot(IntDoubleVector x, IntDoubleVector y) {
+        Preconditions.checkArgument(
+                x.size().intValue() == y.size().intValue(), "Vector size mismatched.");
+        if (x instanceof SparseIntDoubleVector) {
+            if (y instanceof SparseIntDoubleVector) {
+                return dot((SparseIntDoubleVector) x, (SparseIntDoubleVector) y);
             } else {
-                return dot((DenseVector) y, (SparseVector) x);
+                return dot((DenseIntDoubleVector) y, (SparseIntDoubleVector) x);
             }
         } else {
-            if (y instanceof SparseVector) {
-                return dot((DenseVector) x, (SparseVector) y);
+            if (y instanceof SparseIntDoubleVector) {
+                return dot((DenseIntDoubleVector) x, (SparseIntDoubleVector) y);
             } else {
-                return dot((DenseVector) x, (DenseVector) y);
+                return dot((DenseIntDoubleVector) x, (DenseIntDoubleVector) y);
             }
         }
     }

-    private static double dot(DenseVector x, DenseVector y) {
+    private static double dot(DenseIntDoubleVector x, DenseIntDoubleVector y) {
         return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
     }

-    private static double dot(DenseVector x, SparseVector y) {
+    private static double dot(DenseIntDoubleVector x, SparseIntDoubleVector y) {
         double dotValue = 0.0;
         for (int i = 0; i < y.indices.length; ++i) {
             dotValue += y.values[i] * x.values[y.indices[i]];
@@ -95,7 +98,7 @@ private static double dot(DenseVector x, SparseVector y) {
         return dotValue;
     }

-    private static double dot(SparseVector x, SparseVector y) {
+    private static double dot(SparseIntDoubleVector x, SparseIntDoubleVector y) {
         double dotValue = 0;
         int p0 = 0;
         int p1 = 0;
@@ -114,27 +117,30 @@ private static double dot(SparseVector x, SparseVector y) {
     }

     /** \sqrt(\sum_i x_i * x_i) . */
-    public static double norm2(Vector x) {
-        if (x instanceof DenseVector) {
-            return norm2((DenseVector) x);
+    public static double norm2(IntDoubleVector x) {
+        if (x instanceof DenseIntDoubleVector) {
+            return norm2((DenseIntDoubleVector) x);
+        } else {
+            return norm2((SparseIntDoubleVector) x);
         }
-        return norm2((SparseVector) x);
     }

-    private static double norm2(DenseVector x) {
+    private static double norm2(DenseIntDoubleVector x) {
         return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
     }

-    private static double norm2(SparseVector x) {
+    private static double norm2(SparseIntDoubleVector x) {
         return JAVA_BLAS.dnrm2(x.values.length, x.values, 1);
     }

     /** Calculates the p-norm of the vector x. */
-    public static double norm(Vector x, double p) {
+    public static double norm(IntDoubleVector x, double p) {
         Preconditions.checkArgument(p >= 1.0, "p value must >= 1.0, but the current p is : " + p);
         double norm = 0.0;
         double[] data =
-                (x instanceof DenseVector) ? ((DenseVector) x).values : ((SparseVector) x).values;
+                (x instanceof DenseIntDoubleVector)
+                        ? ((DenseIntDoubleVector) x).values
+                        : ((SparseIntDoubleVector) x).values;

         if (p == 1.0) {
             for (double datum : data) {
@@ -157,11 +163,11 @@ public static double norm(IntDoubleVector x, double p) {
     }

     /** x = x * a . */
-    public static void scal(double a, Vector x) {
-        if (x instanceof DenseVector) {
-            JAVA_BLAS.dscal(x.size(), a, ((DenseVector) x).values, 1);
+    public static void scal(double a, IntDoubleVector x) {
+        if (x instanceof DenseIntDoubleVector) {
+            JAVA_BLAS.dscal(x.size(), a, ((DenseIntDoubleVector) x).values, 1);
         } else {
-            double[] values = ((SparseVector) x).values;
+            double[] values = ((SparseIntDoubleVector) x).values;
             JAVA_BLAS.dscal(values.length, a, values, 1);
         }
     }
@@ -180,9 +186,9 @@ public static void gemv(
             double alpha,
             DenseMatrix matrix,
             boolean transMatrix,
-            DenseVector x,
+            DenseIntDoubleVector x,
             double beta,
-            DenseVector y) {
+            DenseIntDoubleVector y) {
         Preconditions.checkArgument(
                 transMatrix
                         ? (matrix.numRows() == x.size() && matrix.numCols() == y.size())
@@ -203,11 +209,11 @@ public static void gemv(
                 1);
     }

-    private static void axpy(double a, DenseVector x, DenseVector y, int k) {
+    private static void axpy(double a, DenseIntDoubleVector x, DenseIntDoubleVector y, int k) {
         JAVA_BLAS.daxpy(k, a, x.values, 1, y.values, 1);
     }

-    private static void axpy(double a, SparseVector x, DenseVector y, int k) {
+    private static void axpy(double a, SparseIntDoubleVector x, DenseIntDoubleVector y, int k) {
         for (int i = 0; i < x.indices.length; i++) {
             int index = x.indices[i];
             if (index >= k) {
@@ -217,7 +223,7 @@ private static void axpy(double a, SparseIntDoubleVector x, DenseIntDoubleVector y, int k) {
         }
     }

-    private static void hDot(SparseVector x, SparseVector y) {
+    private static void hDot(SparseIntDoubleVector x, SparseIntDoubleVector y) {
         int idx = 0;
         int idy = 0;
         while (idx < x.indices.length && idy < y.indices.length) {
@@ -238,7 +244,7 @@
         }
     }

-    private static void hDot(SparseVector x, DenseVector y) {
+    private static void hDot(SparseIntDoubleVector x, DenseIntDoubleVector y) {
         int idx = 0;
         for (int i = 0; i < y.size(); i++) {
             if (idx < x.indices.length && x.indices[idx] == i) {
@@ -250,13 +256,13 @@
         }
     }

-    private static void hDot(DenseVector x, SparseVector y) {
+    private static void hDot(DenseIntDoubleVector x, SparseIntDoubleVector y) {
         for (int i = 0; i < y.values.length; i++) {
             y.values[i] *= x.values[y.indices[i]];
         }
     }

-    private static void hDot(DenseVector x, DenseVector y) {
+    private static void hDot(DenseIntDoubleVector x, DenseIntDoubleVector y) {
         for (int i = 0; i < x.values.length; i++) {
             y.values[i] *= x.values[i];
         }
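The BLAS hunks above change only the parameter types; the netlib-backed arithmetic is untouched, and the public methods still dispatch on the dense/sparse runtime type internally. A short usage sketch against the renamed signatures (a hedged illustration, not a definitive API reference):

```java
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.IntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class BlasSketch {
    public static void main(String[] args) {
        DenseIntDoubleVector y = Vectors.dense(1.0, 1.0, 1.0);
        IntDoubleVector x = Vectors.sparse(3, new int[] {0, 2}, new double[] {2.0, 4.0});

        BLAS.axpy(0.5, x, y);      // y += 0.5 * x; sparse x is dispatched internally
        double d = BLAS.dot(x, y); // mixed dense/sparse dot product
        double n = BLAS.norm2(y);  // Euclidean norm
        BLAS.scal(2.0, y);         // y *= 2.0 in place

        System.out.println(d + " " + n);
    }
}
```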
*/ - public static void scal(double a, Vector x) { - if (x instanceof DenseVector) { - JAVA_BLAS.dscal(x.size(), a, ((DenseVector) x).values, 1); + public static void scal(double a, IntDoubleVector x) { + if (x instanceof DenseIntDoubleVector) { + JAVA_BLAS.dscal(x.size(), a, ((DenseIntDoubleVector) x).values, 1); } else { - double[] values = ((SparseVector) x).values; + double[] values = ((SparseIntDoubleVector) x).values; JAVA_BLAS.dscal(values.length, a, values, 1); } } @@ -180,9 +186,9 @@ public static void gemv( double alpha, DenseMatrix matrix, boolean transMatrix, - DenseVector x, + DenseIntDoubleVector x, double beta, - DenseVector y) { + DenseIntDoubleVector y) { Preconditions.checkArgument( transMatrix ? (matrix.numRows() == x.size() && matrix.numCols() == y.size()) @@ -203,11 +209,11 @@ public static void gemv( 1); } - private static void axpy(double a, DenseVector x, DenseVector y, int k) { + private static void axpy(double a, DenseIntDoubleVector x, DenseIntDoubleVector y, int k) { JAVA_BLAS.daxpy(k, a, x.values, 1, y.values, 1); } - private static void axpy(double a, SparseVector x, DenseVector y, int k) { + private static void axpy(double a, SparseIntDoubleVector x, DenseIntDoubleVector y, int k) { for (int i = 0; i < x.indices.length; i++) { int index = x.indices[i]; if (index >= k) { @@ -217,7 +223,7 @@ private static void axpy(double a, SparseVector x, DenseVector y, int k) { } } - private static void hDot(SparseVector x, SparseVector y) { + private static void hDot(SparseIntDoubleVector x, SparseIntDoubleVector y) { int idx = 0; int idy = 0; while (idx < x.indices.length && idy < y.indices.length) { @@ -238,7 +244,7 @@ private static void hDot(SparseVector x, SparseVector y) { } } - private static void hDot(SparseVector x, DenseVector y) { + private static void hDot(SparseIntDoubleVector x, DenseIntDoubleVector y) { int idx = 0; for (int i = 0; i < y.size(); i++) { if (idx < x.indices.length && x.indices[idx] == i) { @@ -250,13 +256,13 @@ private static void hDot(SparseVector x, DenseVector y) { } } - private static void hDot(DenseVector x, SparseVector y) { + private static void hDot(DenseIntDoubleVector x, SparseIntDoubleVector y) { for (int i = 0; i < y.values.length; i++) { y.values[i] *= x.values[y.indices[i]]; } } - private static void hDot(DenseVector x, DenseVector y) { + private static void hDot(DenseIntDoubleVector x, DenseIntDoubleVector y) { for (int i = 0; i < x.values.length; i++) { y.values[i] *= x.values[i]; } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseIntDoubleVector.java similarity index 71% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseIntDoubleVector.java index e26a93a4f..699fc269a 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseIntDoubleVector.java @@ -20,36 +20,36 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.typeinfo.TypeInfo; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfoFactory; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfoFactory; import java.util.Arrays; -/** A dense vector of double values. */ -@TypeInfo(DenseVectorTypeInfoFactory.class) +/** A dense vector with int as keys and double as values. 
*/ +@TypeInfo(DenseIntDoubleVectorTypeInfoFactory.class) @PublicEvolving -public class DenseVector implements Vector { +public class DenseIntDoubleVector implements IntDoubleVector { public final double[] values; - public DenseVector(double[] values) { + public DenseIntDoubleVector(double[] values) { this.values = values; } - public DenseVector(int size) { + public DenseIntDoubleVector(int size) { this.values = new double[size]; } @Override - public int size() { + public Integer size() { return values.length; } @Override - public double get(int i) { + public Double get(Integer i) { return values[i]; } @Override - public void set(int i, double value) { + public void set(Integer i, Double value) { values[i] = value; } @@ -59,12 +59,12 @@ public double[] toArray() { } @Override - public DenseVector toDense() { + public DenseIntDoubleVector toDense() { return this; } @Override - public SparseVector toSparse() { + public SparseIntDoubleVector toSparse() { int numNonZeros = 0; for (double value : values) { if (value != 0.0) { @@ -84,7 +84,7 @@ public SparseVector toSparse() { k++; } - return new SparseVector(size(), nonZeroIndices, numZeroValues); + return new SparseIntDoubleVector(size(), nonZeroIndices, numZeroValues); } @Override @@ -94,10 +94,10 @@ public String toString() { @Override public boolean equals(Object obj) { - if (!(obj instanceof DenseVector)) { + if (!(obj instanceof DenseIntDoubleVector)) { return false; } - return Arrays.equals(values, ((DenseVector) obj).values); + return Arrays.equals(values, ((DenseIntDoubleVector) obj).values); } @Override @@ -106,7 +106,7 @@ public int hashCode() { } @Override - public DenseVector clone() { - return new DenseVector(values.clone()); + public DenseIntDoubleVector clone() { + return new DenseIntDoubleVector(values.clone()); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java new file mode 100644 index 000000000..2bed89d96 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfoFactory; + +/** A vector with int as keys and double as values. 
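Continuing the sketch above, an illustration of the boxed accessor contract that the renamed dense class now inherits from `IntDoubleVector` (values hand-checked):

```java
DenseIntDoubleVector dense = new DenseIntDoubleVector(new double[] {0.0, 7.0, 0.0});

Integer size = dense.size(); // 3, boxed to satisfy Vector<Integer, Double>
Double v = dense.get(1);     // 7.0; the int argument is autoboxed
dense.set(2, 9.0);           // autoboxed set(Integer, Double)

SparseIntDoubleVector sparse = dense.toSparse(); // indices [1, 2], values [7.0, 9.0]
```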
+ */
+@TypeInfo(VectorTypeInfoFactory.class)
+@PublicEvolving
+public interface IntDoubleVector extends Vector<Integer, Double> {
+
+    @Override
+    double[] toArray();
+
+    @Override
+    DenseIntDoubleVector toDense();
+
+    @Override
+    SparseIntDoubleVector toSparse();
+
+    @Override
+    IntDoubleVector clone();
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java
new file mode 100644
index 000000000..e0265f86f
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+/** A vector with long as keys and double as values. */
+public interface LongDoubleVector extends Vector<Long, Double> {
+    @Override
+    default double[] toArray() {
+        throw new UnsupportedOperationException(
+                "LongDoubleVector cannot be converted to a dense array.");
+    }
+
+    @Override
+    default Vector<Long, Double> toDense() {
+        throw new UnsupportedOperationException(
+                "LongDoubleVector cannot be converted to a dense vector.");
+    }
+
+    @Override
+    SparseLongDoubleVector toSparse();
+
+    @Override
+    LongDoubleVector clone();
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseIntDoubleVector.java
similarity index 86%
rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseIntDoubleVector.java
index 54e707f63..f1b656673 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseIntDoubleVector.java
@@ -20,21 +20,21 @@
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.api.common.typeinfo.TypeInfo;
-import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory;
+import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfoFactory;
 import org.apache.flink.util.Preconditions;
 
 import java.util.Arrays;
 import java.util.Objects;
 
-/** A sparse vector of double values. */
+/** A sparse vector with int as keys and double as values.
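Continuing the sketch, a hand-written example of why the long-keyed interface deliberately rejects densification; the index below does not fit in an int, which is the point of `LongDoubleVector` (the `SparseLongDoubleVector` class it relies on appears later in this diff):

```java
SparseLongDoubleVector v =
        new SparseLongDoubleVector(1L << 33, new long[] {0L, 1L << 32}, new double[] {1.0, 2.0});

Long size = v.size();           // 8589934592, beyond Integer.MAX_VALUE
Double value = v.get(1L << 32); // 2.0
v.toSparse();                   // returns v itself
// v.toDense() and v.toArray() throw UnsupportedOperationException by design:
// a long-indexed key space cannot be materialized as a double[].
```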
*/ +@TypeInfo(SparseIntDoubleVectorTypeInfoFactory.class) @PublicEvolving -public class SparseVector implements Vector { +public class SparseIntDoubleVector implements IntDoubleVector { public final int n; public int[] indices; public double[] values; - public SparseVector(int n, int[] indices, double[] values) { + public SparseIntDoubleVector(int n, int[] indices, double[] values) { this.n = n; this.indices = indices; this.values = values; @@ -45,12 +45,12 @@ public SparseVector(int n, int[] indices, double[] values) { } @Override - public int size() { + public Integer size() { return n; } @Override - public double get(int i) { + public Double get(Integer i) { int pos = Arrays.binarySearch(indices, i); if (pos >= 0) { return values[pos]; @@ -59,7 +59,7 @@ public double get(int i) { } @Override - public void set(int i, double value) { + public void set(Integer i, Double value) { int pos = Arrays.binarySearch(indices, i); if (pos >= 0) { values[pos] = value; @@ -88,12 +88,16 @@ public double[] toArray() { } @Override - public DenseVector toDense() { - return new DenseVector(toArray()); + public DenseIntDoubleVector toDense() { + double[] result = new double[n]; + for (int i = 0; i < indices.length; i++) { + result[indices[i]] = values[i]; + } + return new DenseIntDoubleVector(result); } @Override - public SparseVector toSparse() { + public SparseIntDoubleVector toSparse() { return this; } @@ -105,7 +109,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) { return false; } - SparseVector that = (SparseVector) o; + SparseIntDoubleVector that = (SparseIntDoubleVector) o; return n == that.n && Arrays.equals(indices, that.indices) && Arrays.equals(values, that.values); @@ -199,7 +203,7 @@ public String toString() { } @Override - public SparseVector clone() { - return new SparseVector(n, indices.clone(), values.clone()); + public SparseIntDoubleVector clone() { + return new SparseIntDoubleVector(n, indices.clone(), values.clone()); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java new file mode 100644 index 000000000..9f7ca4bbe --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.SparseLongDoubleVectorTypeInfoFactory;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+
+/**
+ * A sparse vector with long as keys and double as values.
+ *
+ * <p>TODO: Add processing logic for {@link SparseLongDoubleVector} for existing algorithms.
+ */
+@TypeInfo(SparseLongDoubleVectorTypeInfoFactory.class)
+@PublicEvolving
+public class SparseLongDoubleVector implements LongDoubleVector {
+
+    public final long n;
+    public long[] indices;
+    public double[] values;
+
+    public SparseLongDoubleVector(long n, long[] indices, double[] values) {
+        this.n = n;
+        this.indices = indices;
+        this.values = values;
+        if (!isIndicesSorted()) {
+            sortIndices();
+        }
+        validateSortedData();
+    }
+
+    @Override
+    public Long size() {
+        return n;
+    }
+
+    @Override
+    public Double get(Long i) {
+        int pos = Arrays.binarySearch(indices, i);
+        if (pos >= 0) {
+            return values[pos];
+        }
+        return 0.;
+    }
+
+    @Override
+    public void set(Long i, Double value) {
+        int pos = Arrays.binarySearch(indices, i);
+        if (pos >= 0) {
+            values[pos] = value;
+        } else if (value != 0.0) {
+            Preconditions.checkArgument(i < n, "Index out of bounds: " + i);
+            long[] indices = new long[this.indices.length + 1];
+            double[] values = new double[this.indices.length + 1];
+            System.arraycopy(this.indices, 0, indices, 0, -pos - 1);
+            System.arraycopy(this.values, 0, values, 0, -pos - 1);
+            indices[-pos - 1] = i;
+            values[-pos - 1] = value;
+            System.arraycopy(this.indices, -pos - 1, indices, -pos, this.indices.length + pos + 1);
+            System.arraycopy(this.values, -pos - 1, values, -pos, this.indices.length + pos + 1);
+            this.indices = indices;
+            this.values = values;
+        }
+    }
+
+    @Override
+    public SparseLongDoubleVector toSparse() {
+        return this;
+    }
+
+    /**
+     * Checks whether the input data is valid.
+     *
+     * <p>This function does the following checks:
+     *
+     * <ul>
+     *   <li>The indices array and values array are of the same size.
+     *   <li>Vector indices are in valid range.
+     *   <li>Vector indices are unique.
+     * </ul>
+     *
+     * <p>This function works as expected only when indices are sorted.
+     */
+    private void validateSortedData() {
+        Preconditions.checkArgument(
+                indices.length == values.length,
+                "Indices size and values size should be the same.");
+        if (this.indices.length > 0) {
+            Preconditions.checkArgument(
+                    this.indices[0] >= 0 && this.indices[this.indices.length - 1] < this.n,
+                    "Index out of bound.");
+        }
+        for (int i = 1; i < this.indices.length; i++) {
+            Preconditions.checkArgument(
+                    this.indices[i] > this.indices[i - 1], "Indices duplicated.");
+        }
+    }
+
+    private boolean isIndicesSorted() {
+        for (int i = 1; i < this.indices.length; i++) {
+            if (this.indices[i] < this.indices[i - 1]) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** Sorts the indices and values. */
+    private void sortIndices() {
+        sortImpl(this.indices, this.values, 0, this.indices.length - 1);
+    }
+
+    /** Sorts the indices and values using quick sort. */
+    private static void sortImpl(long[] indices, double[] values, int low, int high) {
+        int pivotPos = (low + high) / 2;
+        long pivot = indices[pivotPos];
+        swapIndexAndValue(indices, values, pivotPos, high);
+
+        int pos = low - 1;
+        for (int i = low; i <= high; i++) {
+            if (indices[i] <= pivot) {
+                pos++;
+                swapIndexAndValue(indices, values, pos, i);
+            }
+        }
+        if (high > pos + 1) {
+            sortImpl(indices, values, pos + 1, high);
+        }
+        if (pos - 1 > low) {
+            sortImpl(indices, values, low, pos - 1);
+        }
+    }
+
+    private static void swapIndexAndValue(long[] indices, double[] values, int index1, int index2) {
+        long tempIndex = indices[index1];
+        indices[index1] = indices[index2];
+        indices[index2] = tempIndex;
+        double tempValue = values[index1];
+        values[index1] = values[index2];
+        values[index2] = tempValue;
+    }
+
+    @Override
+    public SparseLongDoubleVector clone() {
+        return new SparseLongDoubleVector(n, indices.clone(), values.clone());
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
index c63f1d629..df9e09fb3 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
@@ -24,29 +24,29 @@
 
 import java.io.Serializable;
 
-/** A vector of double values. */
+/** A vector representation of numbers. */
 @TypeInfo(VectorTypeInfoFactory.class)
 @PublicEvolving
-public interface Vector extends Serializable {
+public interface Vector<K extends Number, V extends Number> extends Serializable {
 
     /** Gets the size of the vector. */
-    int size();
+    K size();
 
     /** Gets the value of the ith element. */
-    double get(int i);
+    V get(K i);
 
     /** Sets the value of the ith element. */
-    void set(int i, double value);
+    void set(K i, V value);
 
-    /** Converts the instance to a double array. */
-    double[] toArray();
+    /** Converts the instance to a primitive array. */
+    Object toArray();
 
     /** Converts the instance to a dense vector. */
-    DenseVector toDense();
+    Vector<K, V> toDense();
 
     /** Converts the instance to a sparse vector. */
-    SparseVector toSparse();
+    Vector<K, V> toSparse();
 
     /** Makes a deep copy of the vector.
     */
-    Vector clone();
+    Vector<K, V> clone();
 }
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
index bb78ef2ce..d8867daf3 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java
@@ -25,15 +25,15 @@
 /** A vector with its norm. */
 @TypeInfo(VectorWithNormTypeInfoFactory.class)
 public class VectorWithNorm {
-    public final Vector vector;
+    public final IntDoubleVector vector;
 
     public final double l2Norm;
 
-    public VectorWithNorm(Vector vector) {
+    public VectorWithNorm(IntDoubleVector vector) {
         this(vector, BLAS.norm2(vector));
     }
 
-    public VectorWithNorm(Vector vector, double l2Norm) {
+    public VectorWithNorm(IntDoubleVector vector, double l2Norm) {
         this.vector = vector;
         this.l2Norm = l2Norm;
     }
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
index 99c508080..8d10a53b5 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java
@@ -25,12 +25,12 @@
 public class Vectors {
 
     /** Creates a dense vector from its values. */
-    public static DenseVector dense(double... values) {
-        return new DenseVector(values);
+    public static DenseIntDoubleVector dense(double... values) {
+        return new DenseIntDoubleVector(values);
     }
 
     /** Creates a sparse vector from its values. */
-    public static SparseVector sparse(int size, int[] indices, double[] values) {
-        return new SparseVector(size, indices, values);
+    public static SparseIntDoubleVector sparse(int size, int[] indices, double[] values) {
+        return new SparseIntDoubleVector(size, indices, values);
     }
 }
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorSerializer.java
similarity index 74%
rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorSerializer.java
index 5b6f984aa..88cf6247c 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorSerializer.java
@@ -24,15 +24,15 @@
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.util.Bits;
 
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Objects;
 
-/** Specialized serializer for {@link DenseVector}. */
-public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
+/** Specialized serializer for {@link DenseIntDoubleVector}.
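Continuing the sketch, a typing illustration for the reworked `Vector<K, V>` interface above: both vector families now share one abstraction and differ only in their key type (assignments hand-checked against the declarations in this diff):

```java
Vector<Integer, Double> intKeyed = Vectors.dense(1.0, 2.0); // a DenseIntDoubleVector
Vector<Long, Double> longKeyed =
        new SparseLongDoubleVector(10L, new long[] {3L}, new double[] {4.0});

Integer intSize = intKeyed.size(); // K = Integer
Long longSize = longKeyed.size();  // K = Long
```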
*/ +public final class DenseIntDoubleVectorSerializer extends TypeSerializer { private static final long serialVersionUID = 1L; @@ -46,22 +46,22 @@ public boolean isImmutableType() { } @Override - public TypeSerializer duplicate() { - return new DenseVectorSerializer(); + public TypeSerializer duplicate() { + return new DenseIntDoubleVectorSerializer(); } @Override - public DenseVector createInstance() { - return new DenseVector(EMPTY); + public DenseIntDoubleVector createInstance() { + return new DenseIntDoubleVector(EMPTY); } @Override - public DenseVector copy(DenseVector from) { - return new DenseVector(Arrays.copyOf(from.values, from.values.length)); + public DenseIntDoubleVector copy(DenseIntDoubleVector from) { + return new DenseIntDoubleVector(Arrays.copyOf(from.values, from.values.length)); } @Override - public DenseVector copy(DenseVector from, DenseVector reuse) { + public DenseIntDoubleVector copy(DenseIntDoubleVector from, DenseIntDoubleVector reuse) { if (from.values.length == reuse.values.length) { System.arraycopy(from.values, 0, reuse.values, 0, from.values.length); return reuse; @@ -75,7 +75,7 @@ public int getLength() { } @Override - public void serialize(DenseVector vector, DataOutputView target) throws IOException { + public void serialize(DenseIntDoubleVector vector, DataOutputView target) throws IOException { if (vector == null) { throw new IllegalArgumentException("The vector must not be null."); } @@ -93,11 +93,11 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept } @Override - public DenseVector deserialize(DataInputView source) throws IOException { + public DenseIntDoubleVector deserialize(DataInputView source) throws IOException { int len = source.readInt(); double[] values = new double[len]; readDoubleArray(values, source, len); - return new DenseVector(values); + return new DenseIntDoubleVector(values); } // Reads `len` double values from `source` into `dst`. @@ -116,7 +116,8 @@ private void readDoubleArray(double[] dst, DataInputView source, int len) throws } @Override - public DenseVector deserialize(DenseVector reuse, DataInputView source) throws IOException { + public DenseIntDoubleVector deserialize(DenseIntDoubleVector reuse, DataInputView source) + throws IOException { int len = source.readInt(); if (len == reuse.values.length) { readDoubleArray(reuse.values, source, len); @@ -125,7 +126,7 @@ public DenseVector deserialize(DenseVector reuse, DataInputView source) throws I double[] values = new double[len]; readDoubleArray(values, source, len); - return new DenseVector(values); + return new DenseIntDoubleVector(values); } @Override @@ -137,28 +138,28 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public boolean equals(Object o) { - return o instanceof DenseVectorSerializer; + return o instanceof DenseIntDoubleVectorSerializer; } @Override public int hashCode() { - return Objects.hashCode(DenseVectorSerializer.class); + return Objects.hashCode(DenseIntDoubleVectorSerializer.class); } // ------------------------------------------------------------------------ @Override - public TypeSerializerSnapshot snapshotConfiguration() { + public TypeSerializerSnapshot snapshotConfiguration() { return new DenseVectorSerializerSnapshot(); } /** Serializer configuration snapshot for compatibility and format evolution. 
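Reviewer note: a round-trip sketch for the renamed serializer, assuming Flink's `DataOutputSerializer`/`DataInputDeserializer` from `org.apache.flink.core.memory` as in-memory views (illustrative, not part of this patch):

```java
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;

public class DenseVectorSerializerSketch {
    public static void main(String[] args) throws Exception {
        DenseIntDoubleVectorSerializer serializer = new DenseIntDoubleVectorSerializer();

        DataOutputSerializer out = new DataOutputSerializer(64);
        serializer.serialize(Vectors.dense(1.0, 2.0, 3.0), out);

        // Wire format: an int length followed by the double values.
        DenseIntDoubleVector copy =
                serializer.deserialize(
                        new DataInputDeserializer(out.getSharedBuffer(), 0, out.length()));
        // copy.values == [1.0, 2.0, 3.0]
    }
}
```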
*/ @SuppressWarnings("WeakerAccess") public static final class DenseVectorSerializerSnapshot - extends SimpleTypeSerializerSnapshot { + extends SimpleTypeSerializerSnapshot { public DenseVectorSerializerSnapshot() { - super(DenseVectorSerializer::new); + super(DenseIntDoubleVectorSerializer::new); } } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfo.java similarity index 71% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfo.java index 72f9a47d9..1ceca7bb0 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfo.java @@ -21,15 +21,15 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; -/** A {@link TypeInformation} for the {@link DenseVector} type. */ -public class DenseVectorTypeInfo extends TypeInformation { +/** A {@link TypeInformation} for the {@link DenseIntDoubleVector} type. */ +public class DenseIntDoubleVectorTypeInfo extends TypeInformation { private static final long serialVersionUID = 1L; - public static final DenseVectorTypeInfo INSTANCE = new DenseVectorTypeInfo(); + public static final DenseIntDoubleVectorTypeInfo INSTANCE = new DenseIntDoubleVectorTypeInfo(); - public DenseVectorTypeInfo() {} + public DenseIntDoubleVectorTypeInfo() {} @Override public int getArity() { @@ -42,8 +42,8 @@ public int getTotalFields() { } @Override - public Class getTypeClass() { - return DenseVector.class; + public Class getTypeClass() { + return DenseIntDoubleVector.class; } @Override @@ -62,8 +62,8 @@ public boolean isKeyType() { } @Override - public TypeSerializer createSerializer(ExecutionConfig executionConfig) { - return new DenseVectorSerializer(); + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new DenseIntDoubleVectorSerializer(); } // -------------------------------------------------------------------------------------------- @@ -75,12 +75,12 @@ public int hashCode() { @Override public boolean equals(Object obj) { - return obj instanceof DenseVectorTypeInfo; + return obj instanceof DenseIntDoubleVectorTypeInfo; } @Override public boolean canEqual(Object obj) { - return obj instanceof DenseVectorTypeInfo; + return obj instanceof DenseIntDoubleVectorTypeInfo; } @Override diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfoFactory.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfoFactory.java similarity index 81% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfoFactory.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfoFactory.java index 367f70638..e5920fa23 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfoFactory.java +++ 
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfoFactory.java @@ -21,21 +21,21 @@ import org.apache.flink.api.common.typeinfo.TypeInfoFactory; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import java.lang.reflect.Type; import java.util.Map; /** * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link - * DenseVector}. + * DenseIntDoubleVector}. */ -public class DenseVectorTypeInfoFactory extends TypeInfoFactory { +public class DenseIntDoubleVectorTypeInfoFactory extends TypeInfoFactory { @Override - public TypeInformation createTypeInfo( + public TypeInformation createTypeInfo( Type t, Map> genericParameters) { - return new DenseVectorTypeInfo(); + return new DenseIntDoubleVectorTypeInfo(); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorSerializer.java similarity index 76% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorSerializer.java index a20a4ffff..93c6679b3 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorSerializer.java @@ -24,13 +24,14 @@ import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.io.IOException; import java.util.Arrays; -/** Specialized serializer for {@link SparseVector}. */ -public final class SparseVectorSerializer extends TypeSerializerSingleton { +/** Specialized serializer for {@link SparseIntDoubleVector}. */ +public final class SparseIntDoubleVectorSerializer + extends TypeSerializerSingleton { private static final long serialVersionUID = 1L; @@ -38,7 +39,8 @@ public final class SparseVectorSerializer extends TypeSerializerSingleton snapshotConfiguration() { + public TypeSerializerSnapshot snapshotConfiguration() { return new SparseVectorSerializerSnapshot(); } /** Serializer configuration snapshot for compatibility and format evolution. 
*/ @SuppressWarnings("WeakerAccess") public static final class SparseVectorSerializerSnapshot - extends SimpleTypeSerializerSnapshot { + extends SimpleTypeSerializerSnapshot { public SparseVectorSerializerSnapshot() { super(() -> INSTANCE); diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfo.java similarity index 70% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfo.java index 06686f088..bdf6430ee 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfo.java @@ -22,11 +22,12 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; -/** A {@link TypeInformation} for the {@link SparseVector} type. */ -public class SparseVectorTypeInfo extends TypeInformation { - public static final SparseVectorTypeInfo INSTANCE = new SparseVectorTypeInfo(); +/** A {@link TypeInformation} for the {@link SparseIntDoubleVector} type. */ +public class SparseIntDoubleVectorTypeInfo extends TypeInformation { + public static final SparseIntDoubleVectorTypeInfo INSTANCE = + new SparseIntDoubleVectorTypeInfo(); @Override public boolean isBasicType() { @@ -49,8 +50,8 @@ public int getTotalFields() { } @Override - public Class getTypeClass() { - return SparseVector.class; + public Class getTypeClass() { + return SparseIntDoubleVector.class; } @Override @@ -59,8 +60,8 @@ public boolean isKeyType() { } @Override - public TypeSerializer createSerializer(ExecutionConfig executionConfig) { - return SparseVectorSerializer.INSTANCE; + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return SparseIntDoubleVectorSerializer.INSTANCE; } @Override @@ -70,7 +71,7 @@ public String toString() { @Override public boolean equals(Object obj) { - return obj instanceof SparseVectorTypeInfo; + return obj instanceof SparseIntDoubleVectorTypeInfo; } @Override @@ -80,6 +81,6 @@ public int hashCode() { @Override public boolean canEqual(Object obj) { - return obj instanceof SparseVectorTypeInfo; + return obj instanceof SparseIntDoubleVectorTypeInfo; } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfoFactory.java similarity index 80% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfoFactory.java index 01c10367e..07446d34f 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfoFactory.java @@ -22,19 +22,19 @@ import org.apache.flink.api.common.typeinfo.TypeInfoFactory; import 
org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.lang.reflect.Type; import java.util.Map; /** * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link - * SparseVector}. + * SparseIntDoubleVector}. */ -public class SparseVectorTypeInfoFactory extends TypeInfoFactory { +public class SparseIntDoubleVectorTypeInfoFactory extends TypeInfoFactory { @Override - public TypeInformation createTypeInfo( + public TypeInformation createTypeInfo( Type type, Map> map) { - return new SparseVectorTypeInfo(); + return new SparseIntDoubleVectorTypeInfo(); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java new file mode 100644 index 000000000..f1cb239d0 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@link SparseLongDoubleVector}. 
+ */
+public final class SparseLongDoubleVectorSerializer
+        extends TypeSerializerSingleton<SparseLongDoubleVector> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY_DOUBLE_ARRAY = new double[0];
+
+    private static final long[] EMPTY_LONG_ARRAY = new long[0];
+
+    public static final SparseLongDoubleVectorSerializer INSTANCE =
+            new SparseLongDoubleVectorSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public SparseLongDoubleVector createInstance() {
+        return new SparseLongDoubleVector(0, EMPTY_LONG_ARRAY, EMPTY_DOUBLE_ARRAY);
+    }
+
+    @Override
+    public SparseLongDoubleVector copy(SparseLongDoubleVector from) {
+        return new SparseLongDoubleVector(
+                from.n,
+                Arrays.copyOf(from.indices, from.indices.length),
+                Arrays.copyOf(from.values, from.values.length));
+    }
+
+    @Override
+    public SparseLongDoubleVector copy(SparseLongDoubleVector from, SparseLongDoubleVector reuse) {
+        if (from.values.length == reuse.values.length && from.n == reuse.n) {
+            System.arraycopy(from.values, 0, reuse.values, 0, from.values.length);
+            System.arraycopy(from.indices, 0, reuse.indices, 0, from.indices.length);
+            return reuse;
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(SparseLongDoubleVector vector, DataOutputView target) throws IOException {
+        if (vector == null) {
+            throw new IllegalArgumentException("The vector must not be null.");
+        }
+
+        target.writeLong(vector.n);
+        final int len = vector.values.length;
+        target.writeInt(len);
+        // TODO: optimize the serialization/deserialization process of SparseVectorSerializer.
+        for (int i = 0; i < len; i++) {
+            target.writeLong(vector.indices[i]);
+            target.writeDouble(vector.values[i]);
+        }
+    }
+
+    // Reads `len` long values from `source` into `indices` and `len` double values from `source`
+    // into `values`.
+    private void readSparseVectorArrays(
+            long[] indices, double[] values, DataInputView source, int len) throws IOException {
+        for (int i = 0; i < len; i++) {
+            indices[i] = source.readLong();
+            values[i] = source.readDouble();
+        }
+    }
+
+    @Override
+    public SparseLongDoubleVector deserialize(DataInputView source) throws IOException {
+        long n = source.readLong();
+        int len = source.readInt();
+        long[] indices = new long[len];
+        double[] values = new double[len];
+        readSparseVectorArrays(indices, values, source, len);
+        return new SparseLongDoubleVector(n, indices, values);
+    }
+
+    @Override
+    public SparseLongDoubleVector deserialize(SparseLongDoubleVector reuse, DataInputView source)
+            throws IOException {
+        long n = source.readLong();
+        int len = source.readInt();
+        if (reuse.n == n && reuse.values.length == len) {
+            readSparseVectorArrays(reuse.indices, reuse.values, source, len);
+            return reuse;
+        }
+
+        long[] indices = new long[len];
+        double[] values = new double[len];
+        readSparseVectorArrays(indices, values, source, len);
+        return new SparseLongDoubleVector(n, indices, values);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        long n = source.readLong();
+        int len = source.readInt();
+
+        target.writeLong(n);
+        target.writeInt(len);
+
+        target.write(source, len * 16);
+    }
+
+    @Override
+    public TypeSerializerSnapshot<SparseLongDoubleVector> snapshotConfiguration() {
+        return new SparseLongDoubleVectorSerializerSnapshot();
+    }
+
+    /** Serializer configuration snapshot for compatibility and format evolution.
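Reviewer note: a back-of-envelope check of the wire format implied by `serialize`/`copy` above; `copy` forwards `len * 16` bytes because each entry is one long index (8 bytes) plus one double value (8 bytes). Continuing the earlier round-trip sketch:

```java
SparseLongDoubleVectorSerializer serializer = SparseLongDoubleVectorSerializer.INSTANCE;

DataOutputSerializer out = new DataOutputSerializer(64);
serializer.serialize(
        new SparseLongDoubleVector(100L, new long[] {7L, 42L}, new double[] {1.5, 2.5}), out);

// 8 bytes (n) + 4 bytes (len) + 2 * (8 + 8) entry bytes = 44 bytes in total.
int totalBytes = out.length(); // 44
```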
+     */
+    @SuppressWarnings("WeakerAccess")
+    public static final class SparseLongDoubleVectorSerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<SparseLongDoubleVector> {
+
+        public SparseLongDoubleVectorSerializerSnapshot() {
+            super(() -> INSTANCE);
+        }
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java
new file mode 100644
index 000000000..d574043eb
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.SparseLongDoubleVector;
+
+/** A {@link TypeInformation} for the {@link SparseLongDoubleVector} type. */
+public class SparseLongDoubleVectorTypeInfo extends TypeInformation<SparseLongDoubleVector> {
+    public static final SparseLongDoubleVectorTypeInfo INSTANCE =
+            new SparseLongDoubleVectorTypeInfo();
+
+    @Override
+    public boolean isBasicType() {
+        return false;
+    }
+
+    @Override
+    public boolean isTupleType() {
+        return false;
+    }
+
+    @Override
+    public int getArity() {
+        return 3;
+    }
+
+    @Override
+    public int getTotalFields() {
+        return 3;
+    }
+
+    @Override
+    public Class<SparseLongDoubleVector> getTypeClass() {
+        return SparseLongDoubleVector.class;
+    }
+
+    @Override
+    public boolean isKeyType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<SparseLongDoubleVector> createSerializer(
+            ExecutionConfig executionConfig) {
+        return SparseLongDoubleVectorSerializer.INSTANCE;
+    }
+
+    @Override
+    public String toString() {
+        return "SparseLongDoubleVectorType";
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        return obj instanceof SparseLongDoubleVectorTypeInfo;
+    }
+
+    @Override
+    public int hashCode() {
+        return getClass().hashCode();
+    }
+
+    @Override
+    public boolean canEqual(Object obj) {
+        return obj instanceof SparseLongDoubleVectorTypeInfo;
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java
new file mode 100644
index 000000000..fb8753382
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.SparseLongDoubleVector;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/**
+ * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link
+ * SparseLongDoubleVector}.
+ */
+public class SparseLongDoubleVectorTypeInfoFactory extends TypeInfoFactory<SparseLongDoubleVector> {
+    @Override
+    public TypeInformation<SparseLongDoubleVector> createTypeInfo(
+            Type type, Map<String, TypeInformation<?>> map) {
+        return new SparseLongDoubleVectorTypeInfo();
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java
index 51b9307aa..d635d2296 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java
@@ -24,23 +24,29 @@
 import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseLongDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
 
 import java.io.IOException;
 
-/** Specialized serializer for {@link Vector}. */
+/** Specialized serializer for {@link IntDoubleVector}.
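Reviewer note: a sketch of how the `@TypeInfo` factory wiring above is picked up, assuming a standard `StreamExecutionEnvironment` (requires the streaming dependency; hand-written, not part of this patch); no explicit type hint should be needed:

```java
import org.apache.flink.ml.linalg.SparseLongDoubleVector;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class TypeInfoSketch {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        DataStream<SparseLongDoubleVector> vectors =
                env.fromElements(
                        new SparseLongDoubleVector(10L, new long[] {1L}, new double[] {2.0}));
        // TypeExtractor consults SparseLongDoubleVectorTypeInfoFactory via the @TypeInfo
        // annotation, so the stream is serialized with SparseLongDoubleVectorSerializer.
    }
}
```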
*/ public final class VectorSerializer extends TypeSerializerSingleton { private static final long serialVersionUID = 1L; private static final double[] EMPTY = new double[0]; - private final DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer denseVectorSerializer = + new DenseIntDoubleVectorSerializer(); - private static final SparseVectorSerializer SPARSE_VECTOR_SERIALIZER = - SparseVectorSerializer.INSTANCE; + private static final SparseIntDoubleVectorSerializer SPARSE_INT_DOUBLE_VECTOR_SERIALIZER = + SparseIntDoubleVectorSerializer.INSTANCE; + + private static final SparseLongDoubleVectorSerializer SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER = + SparseLongDoubleVectorSerializer.INSTANCE; @Override public boolean isImmutableType() { @@ -49,25 +55,32 @@ public boolean isImmutableType() { @Override public Vector createInstance() { - return new DenseVector(EMPTY); + return new DenseIntDoubleVector(EMPTY); } @Override public Vector copy(Vector from) { - if (from instanceof DenseVector) { - return denseVectorSerializer.copy((DenseVector) from); + if (from instanceof DenseIntDoubleVector) { + return denseVectorSerializer.copy((DenseIntDoubleVector) from); + } else if (from instanceof SparseIntDoubleVector) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.copy((SparseIntDoubleVector) from); } else { - return SPARSE_VECTOR_SERIALIZER.copy((SparseVector) from); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.copy((SparseLongDoubleVector) from); } } @Override public Vector copy(Vector from, Vector reuse) { assert from.getClass() == reuse.getClass(); - if (from instanceof DenseVector) { - return denseVectorSerializer.copy((DenseVector) from, (DenseVector) reuse); + if (from instanceof DenseIntDoubleVector) { + return denseVectorSerializer.copy( + (DenseIntDoubleVector) from, (DenseIntDoubleVector) reuse); + } else if (from instanceof SparseIntDoubleVector) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.copy( + (SparseIntDoubleVector) from, (SparseIntDoubleVector) reuse); } else { - return SPARSE_VECTOR_SERIALIZER.copy((SparseVector) from, (SparseVector) reuse); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.copy( + (SparseLongDoubleVector) from, (SparseLongDoubleVector) reuse); } } @@ -78,12 +91,15 @@ public int getLength() { @Override public void serialize(Vector vector, DataOutputView target) throws IOException { - if (vector instanceof DenseVector) { + if (vector instanceof DenseIntDoubleVector) { target.writeByte(0); - denseVectorSerializer.serialize((DenseVector) vector, target); - } else { + denseVectorSerializer.serialize((DenseIntDoubleVector) vector, target); + } else if (vector instanceof SparseIntDoubleVector) { target.writeByte(1); - SPARSE_VECTOR_SERIALIZER.serialize((SparseVector) vector, target); + SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.serialize((SparseIntDoubleVector) vector, target); + } else { + target.writeByte(2); + SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.serialize((SparseLongDoubleVector) vector, target); } } @@ -92,20 +108,25 @@ public Vector deserialize(DataInputView source) throws IOException { byte type = source.readByte(); if (type == 0) { return denseVectorSerializer.deserialize(source); + } else if (type == 1) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } else { - return SPARSE_VECTOR_SERIALIZER.deserialize(source); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } } @Override public Vector deserialize(Vector reuse, DataInputView source) throws IOException { byte type = 
source.readByte(); - assert type == 0 && reuse instanceof DenseVector - || type == 1 && reuse instanceof SparseVector; + assert type == 0 && reuse instanceof DenseIntDoubleVector + || type == 1 && reuse instanceof SparseIntDoubleVector + || type == 2 && reuse instanceof SparseLongDoubleVector; if (type == 0) { return denseVectorSerializer.deserialize(source); + } else if (type == 1) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } else { - return SPARSE_VECTOR_SERIALIZER.deserialize(source); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java index 672dedf63..71b7a28d3 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java @@ -22,9 +22,10 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vector; -/** A {@link TypeInformation} for the {@link Vector} type. */ +/** A {@link TypeInformation} for the {@link IntDoubleVector} type. */ public class VectorTypeInfo extends TypeInformation { public static final VectorTypeInfo INSTANCE = new VectorTypeInfo(); diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java index 92d1de165..8f8c4a21c 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; import java.io.IOException; @@ -50,18 +50,18 @@ public TypeSerializer duplicate() { @Override public VectorWithNorm createInstance() { - return new VectorWithNorm(new DenseVector(EMPTY)); + return new VectorWithNorm(new DenseIntDoubleVector(EMPTY)); } @Override public VectorWithNorm copy(VectorWithNorm from) { - Vector vector = vectorSerializer.copy(from.vector); + IntDoubleVector vector = (IntDoubleVector) vectorSerializer.copy(from.vector); return new VectorWithNorm(vector, from.l2Norm); } @Override public VectorWithNorm copy(VectorWithNorm from, VectorWithNorm reuse) { - Vector vector = vectorSerializer.copy(from.vector, reuse.vector); + IntDoubleVector vector = (IntDoubleVector) vectorSerializer.copy(from.vector, reuse.vector); return new VectorWithNorm(vector, from.l2Norm); } @@ -78,7 +78,7 @@ public void serialize(VectorWithNorm from, DataOutputView dataOutputView) throws @Override public VectorWithNorm deserialize(DataInputView dataInputView) throws IOException { - Vector vector = vectorSerializer.deserialize(dataInputView); + IntDoubleVector vector = 
(IntDoubleVector) vectorSerializer.deserialize(dataInputView); double l2NormSquare = dataInputView.readDouble(); return new VectorWithNorm(vector, l2NormSquare); } @@ -86,7 +86,8 @@ public VectorWithNorm deserialize(DataInputView dataInputView) throws IOExceptio @Override public VectorWithNorm deserialize(VectorWithNorm reuse, DataInputView dataInputView) throws IOException { - Vector vector = vectorSerializer.deserialize(reuse.vector, dataInputView); + IntDoubleVector vector = + (IntDoubleVector) vectorSerializer.deserialize(reuse.vector, dataInputView); double l2NormSquare = dataInputView.readDouble(); return new VectorWithNorm(vector, l2NormSquare); } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java index 7fbf32b96..179fda831 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java @@ -18,30 +18,30 @@ package org.apache.flink.ml.param; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.util.List; import java.util.Map; /** Class for the Vector parameter. */ -public class VectorParam extends Param { +public class VectorParam extends Param { public VectorParam( String name, String description, - Vector defaultValue, - ParamValidator validator) { - super(name, Vector.class, description, defaultValue, validator); + IntDoubleVector defaultValue, + ParamValidator validator) { + super(name, IntDoubleVector.class, description, defaultValue, validator); } - public VectorParam(String name, String description, Vector defaultValue) { + public VectorParam(String name, String description, IntDoubleVector defaultValue) { this(name, description, defaultValue, ParamValidators.alwaysTrue()); } @Override - public Vector jsonDecode(Object object) { + public IntDoubleVector jsonDecode(Object object) { Map vecValues = (Map) object; if (vecValues.size() == 1) { List list = (List) vecValues.get("values"); @@ -49,7 +49,7 @@ public Vector jsonDecode(Object object) { for (int i = 0; i < values.length; ++i) { values[i] = list.get(i); } - return new DenseVector(values); + return new DenseIntDoubleVector(values); } else if (vecValues.size() == 3) { List valuesList = (List) vecValues.get("values"); List indicesList = (List) vecValues.get("indices"); @@ -60,7 +60,7 @@ public Vector jsonDecode(Object object) { values[i] = valuesList.get(i); indices[i] = indicesList.get(i); } - return new SparseVector(n, indices, values); + return new SparseIntDoubleVector(n, indices, values); } else { throw new UnsupportedOperationException("Vector parameter is invalid."); } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java index 8de3a44d4..fff889edc 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java @@ -18,6 +18,8 @@ package org.apache.flink.ml.util; +import org.apache.flink.api.java.tuple.Tuple2; + /** * Utility methods for packing/unpacking primitive values in/out of byte arrays using big-endian * byte 
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
index 8de3a44d4..fff889edc 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
@@ -18,6 +18,8 @@
 package org.apache.flink.ml.util;
 
+import org.apache.flink.api.java.tuple.Tuple2;
+
 /**
  * Utility methods for packing/unpacking primitive values in/out of byte arrays using big-endian
  * byte ordering. Referenced from java.io.Bits.
@@ -44,6 +46,17 @@ public static double getDouble(byte[] b, int off) {
         return Double.longBitsToDouble(getLong(b, off));
     }
 
+    public static int getInt(byte[] b, int off) {
+        return ((b[off + 3] & 0xFF))
+                + ((b[off + 2] & 0xFF) << 8)
+                + ((b[off + 1] & 0xFF) << 16)
+                + ((b[off]) << 24);
+    }
+
+    public static char getChar(byte[] b, int off) {
+        return (char) ((b[off + 1] & 0xFF) + (b[off] << 8));
+    }
+
     /*
      * Methods for packing primitive values into byte arrays starting at given
      * offsets.
@@ -63,4 +76,107 @@ public static void putLong(byte[] b, int off, long val) {
     public static void putDouble(byte[] b, int off, double val) {
         putLong(b, off, Double.doubleToLongBits(val));
     }
+
+    public static void putInt(byte[] b, int off, int val) {
+        b[off + 3] = (byte) (val);
+        b[off + 2] = (byte) (val >>> 8);
+        b[off + 1] = (byte) (val >>> 16);
+        b[off] = (byte) (val >>> 24);
+    }
+
+    public static void putChar(byte[] b, int off, char val) {
+        b[off + 1] = (byte) (val);
+        b[off] = (byte) (val >>> 8);
+    }
+
+    /** Gets a long array from the byte array starting from the given offset. */
+    public static long[] getLongArray(byte[] bytes, int offset) {
+        int size = Bits.getInt(bytes, offset);
+        offset += Integer.BYTES;
+        long[] result = new long[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = Bits.getLong(bytes, offset);
+            offset += Long.BYTES;
+        }
+        return result;
+    }
+
+    /**
+     * Puts a long array to the byte array starting from the given offset.
+     *
+     * @return the next position to write on.
+     */
+    public static int putLongArray(long[] array, byte[] bytes, int offset) {
+        Bits.putInt(bytes, offset, array.length);
+        offset += Integer.BYTES;
+        for (int i = 0; i < array.length; i++) {
+            Bits.putLong(bytes, offset, array[i]);
+            offset += Long.BYTES;
+        }
+        return offset;
+    }
+
+    /** Returns the size of a long array in bytes. */
+    public static int getLongArraySizeInBytes(long[] array) {
+        return Integer.BYTES + array.length * Long.BYTES;
+    }
+
+    /** Gets a double array from the byte array starting from the given offset. */
+    public static double[] getDoubleArray(byte[] bytes, int offset) {
+        int size = Bits.getInt(bytes, offset);
+        offset += Integer.BYTES;
+        double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = Bits.getDouble(bytes, offset);
+            offset += Double.BYTES;
+        }
+        return result;
+    }
+
+    /**
+     * Puts a double array to the byte array starting from the given offset.
+     *
+     * @return the next position to write on.
+     */
+    public static int putDoubleArray(double[] array, byte[] bytes, int offset) {
+        Bits.putInt(bytes, offset, array.length);
+        offset += Integer.BYTES;
+        for (int i = 0; i < array.length; i++) {
+            Bits.putDouble(bytes, offset, array[i]);
+            offset += Double.BYTES;
+        }
+        return offset;
+    }
+
+    /** Returns the size of a double array in bytes. */
+    public static int getDoubleArraySizeInBytes(double[] array) {
+        return Integer.BYTES + array.length * Double.BYTES;
+    }
+
+    /** Gets a long-double array from the byte array starting from the given offset. */
+    public static Tuple2<long[], double[]> getLongDoubleArray(byte[] bytes, int offset) {
+        long[] indices = getLongArray(bytes, offset);
+        offset += getLongArraySizeInBytes(indices);
+        double[] values = getDoubleArray(bytes, offset);
+        return Tuple2.of(indices, values);
+    }
+
+    /**
+     * Puts a long-double array to the byte array starting from the given offset.
+     *
+     * @return the next position to write on.
+     */
+    public static int putLongDoubleArray(
+            Tuple2<long[], double[]> longDoubleArray, byte[] bytes, int offset) {
+        offset = putLongArray(longDoubleArray.f0, bytes, offset);
+        offset = putDoubleArray(longDoubleArray.f1, bytes, offset);
+
+        return offset;
+    }
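All of the new array helpers share one layout: a big-endian int length prefix followed by the fixed-width elements, with the put methods returning the next write position so calls can be chained. A minimal round-trip sketch of the Tuple2-based long-double variant added above:

```java
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.util.Bits;

public class BitsRoundTripSketch {
    public static void main(String[] args) {
        Tuple2<long[], double[]> pair =
                Tuple2.of(new long[] {0L, 5L, 9L}, new double[] {1.0, 2.5, -3.0});

        // Layout: [int len][3 longs][int len][3 doubles] = 4 + 24 + 4 + 24 bytes.
        byte[] buf = new byte[Bits.getLongDoubleArraySizeInBytes(pair)];
        int next = Bits.putLongDoubleArray(pair, buf, 0);
        // next == buf.length, i.e. "the next position to write on".

        Tuple2<long[], double[]> decoded = Bits.getLongDoubleArray(buf, 0);
        // decoded.f0 and decoded.f1 hold the original indices and values.
    }
}
```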
+
+    /** Returns the size of a long-double array in bytes. */
+    public static int getLongDoubleArraySizeInBytes(Tuple2<long[], double[]> longDoubleArray) {
+        return getLongArraySizeInBytes(longDoubleArray.f0)
+                + getDoubleArraySizeInBytes(longDoubleArray.f1);
+    }
 }
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java
index c38e3d2f8..55a45d738 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java
@@ -25,13 +25,14 @@
 import org.apache.flink.ml.servable.api.TransformerServable;
 import org.apache.flink.ml.servable.builder.PipelineModelServable;
 import org.apache.flink.util.InstantiationUtil;
-import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.SequenceInputStream;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -143,10 +144,10 @@ public static InputStream loadModelData(String path) throws IOException {
         FileSystem fileSystem = modelDataPath.getFileSystem();
 
         FileStatus[] files = fileSystem.listStatus(modelDataPath);
-        Preconditions.checkState(
-                files.length == 1,
-                "Only one model data file is expected in the directory %s.",
-                path);
-        return fileSystem.open(files[0].getPath());
+        List<InputStream> inputStreams = new ArrayList<>();
+        for (FileStatus file : files) {
+            inputStreams.add(fileSystem.open(file.getPath()));
+        }
+        return new SequenceInputStream(Collections.enumeration(inputStreams));
     }
 }
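The `loadModelData` change replaces the single-file assumption with a concatenation of every file in the model-data directory; `SequenceInputStream` is what makes this transparent to callers. A self-contained sketch of that behavior, using in-memory streams as stand-ins for the part files:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.util.Arrays;
import java.util.Collections;

public class SequenceStreamSketch {
    public static void main(String[] args) throws IOException {
        // Stand-ins for two model-data part files.
        InputStream part0 = new ByteArrayInputStream(new byte[] {1, 2});
        InputStream part1 = new ByteArrayInputStream(new byte[] {3, 4});

        InputStream merged =
                new SequenceInputStream(
                        Collections.enumeration(Arrays.asList(part0, part1)));

        // Prints 1 2 3 4: the parts read as one continuous stream, so a
        // decoder can simply read until EOF to consume every segment.
        int b;
        while ((b = merged.read()) != -1) {
            System.out.print(b + " ");
        }
    }
}
```

Note that `listStatus` makes no ordering guarantee, which is presumably why the model-data segments introduced later in this patch carry explicit start and end indices rather than relying on file order.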
diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
index 5fff534a6..cfa553d41 100644
--- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
+++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
@@ -27,7 +27,7 @@ public class BLASTest {
 
     private static final double TOLERANCE = 1e-7;
 
-    private static final DenseVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5);
+    private static final DenseIntDoubleVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5);
     private static final DenseMatrix inputDenseMat =
             new DenseMatrix(2, 5, new double[] {1, -2, 3, 4, -5, 1, -2, 3, 4, -5});
 
@@ -39,13 +39,14 @@ public void testAsum() {
     @Test
     public void testAxpy() {
         // Tests axpy(dense, dense).
-        DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
+        DenseIntDoubleVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
         BLAS.axpy(1, inputDenseVec, anotherDenseVec);
         double[] expectedResult = new double[] {2, 0, 6, 8, 0};
         assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
 
         // Tests axpy(sparse, dense).
-        SparseVector sparseVec = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+        SparseIntDoubleVector sparseVec =
+                Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
         BLAS.axpy(2, sparseVec, anotherDenseVec);
         expectedResult = new double[] {4, 0, 12, 8, 10};
         assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
@@ -54,13 +55,14 @@ public void testAxpyK() {
     @Test
     public void testAxpyK() {
         // Tests axpy(dense, dense, k).
-        DenseVector anotherDenseVec = Vectors.dense(1, 2, 3);
+        DenseIntDoubleVector anotherDenseVec = Vectors.dense(1, 2, 3);
         BLAS.axpy(1, inputDenseVec, anotherDenseVec, 3);
         double[] expectedResult = new double[] {2, 0, 6};
         assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
 
         // Tests axpy(sparse, dense, k).
-        SparseVector sparseVec = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+        SparseIntDoubleVector sparseVec =
+                Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
         anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5, 6, 7);
         BLAS.axpy(2, sparseVec, anotherDenseVec, 5);
         expectedResult = new double[] {3, 2, 9, 4, 15, 6, 7};
@@ -69,10 +71,10 @@ public void testDot() {
 
     @Test
     public void testDot() {
-        DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
-        SparseVector sparseVector1 =
+        DenseIntDoubleVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
+        SparseIntDoubleVector sparseVector1 =
                 Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {1., 1., 4.});
-        SparseVector sparseVector2 =
+        SparseIntDoubleVector sparseVector2 =
                 Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {1., 2., 1.});
 
         // Tests dot(dense, dense).
         assertEquals(-3, BLAS.dot(inputDenseVec, anotherDenseVec), TOLERANCE);
@@ -88,7 +90,8 @@ public void testNorm2() {
     public void testNorm2() {
         assertEquals(Math.sqrt(55), BLAS.norm2(inputDenseVec), TOLERANCE);
 
-        SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+        SparseIntDoubleVector sparseVector =
+                Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
         assertEquals(Math.sqrt(35), BLAS.norm2(sparseVector), TOLERANCE);
     }
 
@@ -96,7 +99,8 @@ public void testNorm() {
     public void testNorm() {
         assertEquals(Math.sqrt(55), BLAS.norm(inputDenseVec, 2.0), TOLERANCE);
 
-        SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+        SparseIntDoubleVector sparseVector =
+                Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
         assertEquals(5.0, BLAS.norm(sparseVector, Double.POSITIVE_INFINITY), TOLERANCE);
         assertEquals(5.348481241239363, BLAS.norm(sparseVector, 3.0), TOLERANCE);
@@ -109,7 +113,7 @@ public void testScal() {
         double[] expectedDenseResult = new double[] {2, -4, 6, 8, -10};
         assertArrayEquals(expectedDenseResult, inputDenseVec.values, TOLERANCE);
 
-        SparseVector inputSparseVector =
+        SparseIntDoubleVector inputSparseVector =
                 Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
 
         BLAS.scal(1.5, inputSparseVector);
@@ -122,7 +126,7 @@ public void testGemv() {
     @Test
     public void testGemv() {
-        DenseVector anotherDenseVec = Vectors.dense(1.0, 2.0);
+        DenseIntDoubleVector anotherDenseVec = Vectors.dense(1.0, 2.0);
         BLAS.gemv(-2.0, inputDenseMat, false, inputDenseVec, 0.0, anotherDenseVec);
         double[] expectedResult = new double[] {96.0, -60.0};
         assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
@@ -131,16 +135,18 @@ public void testHDot() {
     @Test
     public void testHDot() {
         // Tests hDot(sparse, sparse).
-        SparseVector sparseVec1 = Vectors.sparse(5, new int[] {0, 2, 3}, new double[] {1, 3, 5});
-        SparseVector sparseVec2 = Vectors.sparse(5, new int[] {0, 1, 4}, new double[] {1, 3, 5});
+        SparseIntDoubleVector sparseVec1 =
+                Vectors.sparse(5, new int[] {0, 2, 3}, new double[] {1, 3, 5});
+        SparseIntDoubleVector sparseVec2 =
+                Vectors.sparse(5, new int[] {0, 1, 4}, new double[] {1, 3, 5});
         BLAS.hDot(sparseVec1, sparseVec2);
-        assertEquals(5, sparseVec2.size());
+        assertEquals(5, sparseVec2.size().intValue());
         assertArrayEquals(new int[] {0, 1, 4}, sparseVec2.indices);
         assertArrayEquals(new double[] {1, 0, 0}, sparseVec2.values, TOLERANCE);
 
         // Tests hDot(dense, dense).
-        DenseVector denseVec1 = Vectors.dense(1, 2, 3, 4, 5);
-        DenseVector denseVec2 = Vectors.dense(1, 2, 3, 4, 5);
+        DenseIntDoubleVector denseVec1 = Vectors.dense(1, 2, 3, 4, 5);
+        DenseIntDoubleVector denseVec2 = Vectors.dense(1, 2, 3, 4, 5);
         BLAS.hDot(denseVec1, denseVec2);
         double[] expectedResult = new double[] {1, 4, 9, 16, 25};
         assertArrayEquals(expectedResult, denseVec2.values, TOLERANCE);
@@ -151,9 +157,9 @@ public void testHDot() {
         assertArrayEquals(expectedResult, denseVec1.values, TOLERANCE);
 
         // Tests hDot(dense, sparse).
-        DenseVector denseVec3 = Vectors.dense(1, 2, 3, 4, 5);
+        DenseIntDoubleVector denseVec3 = Vectors.dense(1, 2, 3, 4, 5);
         BLAS.hDot(denseVec3, sparseVec1);
-        assertEquals(5, sparseVec1.size());
+        assertEquals(5, sparseVec1.size().intValue());
         assertArrayEquals(new int[] {0, 2, 3}, sparseVec1.indices);
         assertArrayEquals(new double[] {1, 9, 20}, sparseVec1.values, TOLERANCE);
     }
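The renamed fixtures above exercise the same BLAS entry points as before, just with the int-indexed, double-valued vector types. A short sketch of the core calls, with results matching the expected values asserted in the tests:

```java
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.SparseIntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class BlasSketch {
    public static void main(String[] args) {
        DenseIntDoubleVector y = Vectors.dense(1, 2, 3, 4, 5);
        SparseIntDoubleVector x =
                Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});

        BLAS.axpy(2, x, y);        // y := 2 * x + y -> [3, 2, 9, 4, 15]
        double d = BLAS.dot(x, y); // sparse-dense dot product
        double n = BLAS.norm2(x);  // sqrt(1 + 9 + 25) = sqrt(35)
        BLAS.scal(1.5, x);         // x.values -> [1.5, 4.5, 7.5]
    }
}
```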
diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
index 427403d17..3ffc429ea 100644
--- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
+++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
@@ -23,15 +23,15 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
-/** Tests the behavior of {@link DenseVector}. */
+/** Tests the behavior of {@link DenseIntDoubleVector}. */
 public class DenseVectorTest {
 
     private static final double TOLERANCE = 1e-7;
 
     @Test
     public void testClone() {
-        DenseVector denseVec = Vectors.dense(1, 2, 3);
-        DenseVector clonedDenseVec = denseVec.clone();
+        DenseIntDoubleVector denseVec = Vectors.dense(1, 2, 3);
+        DenseIntDoubleVector clonedDenseVec = denseVec.clone();
         assertArrayEquals(clonedDenseVec.values, new double[] {1, 2, 3}, TOLERANCE);
 
         clonedDenseVec.values[0] = -1;
@@ -41,10 +41,10 @@ public void testClone() {
 
     @Test
     public void testGetAndSet() {
-        DenseVector denseVec = Vectors.dense(1, 2, 3);
+        DenseIntDoubleVector denseVec = Vectors.dense(1, 2, 3);
         assertEquals(1, denseVec.get(0), TOLERANCE);
 
-        denseVec.set(0, 2);
+        denseVec.set(0, 2.0);
         assertEquals(2, denseVec.get(0), TOLERANCE);
     }
 }
diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
index 916963f18..e928c6b14 100644
--- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
+++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
@@ -20,7 +20,7 @@
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.typeinfo.SparseVectorSerializer;
+import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorSerializer;
 
 import org.apache.commons.io.output.ByteArrayOutputStream;
 import org.junit.Assert;
@@ -32,7 +32,7 @@
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
-/** Tests the behavior of {@link SparseVector}. */
+/** Tests the behavior of {@link SparseIntDoubleVector}. */
 public class SparseVectorTest {
 
     private static final double TOLERANCE = 1e-7;
@@ -42,7 +42,7 @@ public void testConstructor() {
         int[] indices = new int[] {0, 2, 3};
         double[] values = new double[] {0.1, 0.3, 0.4};
 
-        SparseVector vector = Vectors.sparse(n, indices, values);
+        SparseIntDoubleVector vector = Vectors.sparse(n, indices, values);
         assertEquals(n, vector.n);
         assertArrayEquals(indices, vector.indices);
         assertArrayEquals(values, vector.values, 1e-5);
@@ -67,13 +67,13 @@ public void testDuplicateIndex() {
     @Test
     public void testAllZeroVector() {
         int n = 4;
-        SparseVector vector = Vectors.sparse(n, new int[0], new double[0]);
+        SparseIntDoubleVector vector = Vectors.sparse(n, new int[0], new double[0]);
         assertArrayEquals(vector.toArray(), new double[n], 1e-5);
     }
 
     @Test
     public void testUnsortedIndex() {
-        SparseVector vector;
+        SparseIntDoubleVector vector;
 
         vector = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
         assertEquals(4, vector.n);
@@ -115,8 +115,8 @@ public void testSerializer() throws IOException {
         int n = 4;
         int[] indices = new int[] {0, 2, 3};
         double[] values = new double[] {0.1, 0.3, 0.4};
-        SparseVector vector = Vectors.sparse(n, indices, values);
-        SparseVectorSerializer serializer = SparseVectorSerializer.INSTANCE;
+        SparseIntDoubleVector vector = Vectors.sparse(n, indices, values);
+        SparseIntDoubleVectorSerializer serializer = SparseIntDoubleVectorSerializer.INSTANCE;
 
         ByteArrayOutputStream bOutput = new ByteArrayOutputStream(1024);
         DataOutputViewStreamWrapper output = new DataOutputViewStreamWrapper(bOutput);
@@ -125,7 +125,7 @@ public void testSerializer() throws IOException {
         byte[] b = bOutput.toByteArray();
         ByteArrayInputStream bInput = new ByteArrayInputStream(b);
         DataInputViewStreamWrapper input = new DataInputViewStreamWrapper(bInput);
-        SparseVector vector2 = serializer.deserialize(input);
+        SparseIntDoubleVector vector2 = serializer.deserialize(input);
 
         assertEquals(vector.n, vector2.n);
         assertArrayEquals(vector.indices, vector2.indices);
@@ -134,9 +134,9 @@ public void testSerializer() throws IOException {
 
     @Test
     public void testClone() {
-        SparseVector sparseVec = Vectors.sparse(3, new int[] {0, 2}, new double[] {1, 3});
-        SparseVector clonedSparseVec = sparseVec.clone();
-        assertEquals(3, clonedSparseVec.size());
+        SparseIntDoubleVector sparseVec = Vectors.sparse(3, new int[] {0, 2}, new double[] {1, 3});
+        SparseIntDoubleVector clonedSparseVec = sparseVec.clone();
+        assertEquals(3, clonedSparseVec.size().intValue());
         assertArrayEquals(clonedSparseVec.indices, new int[] {0, 2});
         assertArrayEquals(clonedSparseVec.values, new double[] {1, 3}, TOLERANCE);
 
@@ -150,7 +150,7 @@ public void testClone() {
 
     @Test
     public void testGetAndSet() {
-        SparseVector sparseVec = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
+        SparseIntDoubleVector sparseVec = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
         assertEquals(0, sparseVec.get(0), TOLERANCE);
         assertEquals(0.3, sparseVec.get(2), TOLERANCE);
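Since the tests above lean on it heavily, it is worth spelling out the sparse representation: `n` is the logical size, while `indices` and `values` hold only the non-zero entries. A small sketch, consistent with the `toArray` and `get` behavior asserted in `testAllZeroVector` and `testGetAndSet`:

```java
import org.apache.flink.ml.linalg.SparseIntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class SparseVectorSketch {
    public static void main(String[] args) {
        // Logical size 4, non-zeros at positions 0, 2 and 3.
        SparseIntDoubleVector v =
                Vectors.sparse(4, new int[] {0, 2, 3}, new double[] {0.1, 0.3, 0.4});

        double[] dense = v.toArray(); // [0.1, 0.0, 0.3, 0.4]
        double zero = v.get(1);       // positions without an index read as 0.0
    }
}
```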
diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
index 25b45b089..31d6cf691 100644
--- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
+++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java
@@ -26,12 +26,13 @@ public class VectorWithNormTest {
 
     @Test
     public void testL2Norm() {
-        DenseVector denseVector = Vectors.dense(1, 2, 3);
+        DenseIntDoubleVector denseVector = Vectors.dense(1, 2, 3);
         VectorWithNorm denseVectorWithNorm = new VectorWithNorm(denseVector);
         assertEquals(denseVector, denseVectorWithNorm.vector);
         assertEquals(Math.sqrt(14), denseVectorWithNorm.l2Norm, 1e-7);
 
-        SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 2, 3});
+        SparseIntDoubleVector sparseVector =
+                Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 2, 3});
         VectorWithNorm sparseVectorWithNorm = new VectorWithNorm(sparseVector);
         assertEquals(sparseVector, sparseVectorWithNorm.vector);
         assertEquals(Math.sqrt(14), sparseVectorWithNorm.l2Norm, 1e-7);
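`VectorWithNorm` pairs a vector with its L2 norm, presumably so that distance-based operators can reuse the norm instead of recomputing it per comparison. A sketch consistent with the test above:

```java
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.Vectors;

public class VectorWithNormSketch {
    public static void main(String[] args) {
        DenseIntDoubleVector v = Vectors.dense(1, 2, 3);

        // The single-argument constructor computes the norm eagerly:
        // l2Norm = sqrt(1*1 + 2*2 + 3*3) = sqrt(14).
        VectorWithNorm withNorm = new VectorWithNorm(v);
        double n = withNorm.l2Norm;
    }
}
```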
You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.classification.logisticregression;
-
-import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-
-/** Model data of {@link LogisticRegressionModelServable}. */
-public class LogisticRegressionModelData {
-
-    public DenseVector coefficient;
-
-    public long modelVersion;
-
-    public LogisticRegressionModelData() {}
-
-    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
-        this.coefficient = coefficient;
-        this.modelVersion = modelVersion;
-    }
-
-    /**
-     * Serializes the instance and writes to the output stream.
-     *
-     * @param outputStream The stream to write to.
-     */
-    @VisibleForTesting
-    public void encode(OutputStream outputStream) throws IOException {
-        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
-                new DataOutputViewStreamWrapper(outputStream);
-
-        DenseVectorSerializer serializer = new DenseVectorSerializer();
-        serializer.serialize(coefficient, dataOutputViewStreamWrapper);
-        dataOutputViewStreamWrapper.writeLong(modelVersion);
-    }
-
-    /**
-     * Reads and deserializes the model data from the input stream.
-     *
-     * @param inputStream The stream to read from.
-     * @return The model data instance.
-     */
-    static LogisticRegressionModelData decode(InputStream inputStream) throws IOException {
-        DataInputViewStreamWrapper dataInputViewStreamWrapper =
-                new DataInputViewStreamWrapper(inputStream);
-
-        DenseVectorSerializer serializer = new DenseVectorSerializer();
-        DenseVector coefficient = serializer.deserialize(dataInputViewStreamWrapper);
-        long modelVersion = dataInputViewStreamWrapper.readLong();
-
-        return new LogisticRegressionModelData(coefficient, modelVersion);
-    }
-}
diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataSegment.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataSegment.java
new file mode 100644
index 000000000..127e00371
--- /dev/null
+++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataSegment.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+
+/** Segment model data of {@link LogisticRegressionModelServable}. */
+public class LogisticRegressionModelDataSegment {
+
+    public DenseIntDoubleVector coefficient;
+
+    public long startIndex;
+
+    public long endIndex;
+
+    public long modelVersion;
+
+    public LogisticRegressionModelDataSegment() {}
+
+    public LogisticRegressionModelDataSegment(DenseIntDoubleVector coefficient, long modelVersion) {
+        this(coefficient, 0L, coefficient.size(), modelVersion);
+    }
+
+    public LogisticRegressionModelDataSegment(
+            DenseIntDoubleVector coefficient, long startIndex, long endIndex, long modelVersion) {
+        this.coefficient = coefficient;
+        this.startIndex = startIndex;
+        this.endIndex = endIndex;
+        this.modelVersion = modelVersion;
+    }
+
+    /**
+     * Serializes the instance and writes to the output stream.
+     *
+     * @param outputStream The stream to write to.
+     */
+    @VisibleForTesting
+    public void encode(OutputStream outputStream) throws IOException {
+        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                new DataOutputViewStreamWrapper(outputStream);
+
+        DenseIntDoubleVectorSerializer serializer = new DenseIntDoubleVectorSerializer();
+        serializer.serialize(coefficient, dataOutputViewStreamWrapper);
+        dataOutputViewStreamWrapper.writeLong(startIndex);
+        dataOutputViewStreamWrapper.writeLong(endIndex);
+        dataOutputViewStreamWrapper.writeLong(modelVersion);
+    }
+
+    /**
+     * Reads and deserializes the model data from the input stream.
+     *
+     * @param inputStream The stream to read from.
+     * @return The model data instance.
+     */
+    static LogisticRegressionModelDataSegment decode(InputStream inputStream) throws IOException {
+        DataInputViewStreamWrapper dataInputViewStreamWrapper =
+                new DataInputViewStreamWrapper(inputStream);
+
+        DenseIntDoubleVectorSerializer serializer = new DenseIntDoubleVectorSerializer();
+        DenseIntDoubleVector coefficient = serializer.deserialize(dataInputViewStreamWrapper);
+        long startIndex = dataInputViewStreamWrapper.readLong();
+        long endIndex = dataInputViewStreamWrapper.readLong();
+        long modelVersion = dataInputViewStreamWrapper.readLong();
+
+        return new LogisticRegressionModelDataSegment(
+                coefficient, startIndex, endIndex, modelVersion);
+    }
+
+    @VisibleForTesting
+    public static LogisticRegressionModelDataSegment mergeSegments(
+            List<LogisticRegressionModelDataSegment> segments) {
+        long dim = 0;
+        for (LogisticRegressionModelDataSegment segment : segments) {
+            dim = Math.max(dim, segment.endIndex);
+        }
+        Preconditions.checkState(
+                dim < Integer.MAX_VALUE,
+                "The dimension of logistic regression model is larger than INT.MAX. Please consider using distributed inference.");
+        int intDim = (int) dim;
+        DenseIntDoubleVector mergedCoefficient = new DenseIntDoubleVector(intDim);
+        for (LogisticRegressionModelDataSegment segment : segments) {
+            int startIndex = (int) segment.startIndex;
+            int endIndex = (int) segment.endIndex;
+            System.arraycopy(
+                    segment.coefficient.values,
+                    0,
+                    mergedCoefficient.values,
+                    startIndex,
+                    endIndex - startIndex);
+        }
+        return new LogisticRegressionModelDataSegment(
+                mergedCoefficient, 0, mergedCoefficient.size(), segments.get(0).modelVersion);
+    }
+}
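Compared with the deleted `LogisticRegressionModelData`, a segment additionally records the `[startIndex, endIndex)` slice of the full coefficient vector that it covers, so a sharded model can be reassembled regardless of the order in which its part files are read. A hypothetical two-segment merge:

```java
import java.util.Arrays;

import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment;
import org.apache.flink.ml.linalg.Vectors;

public class MergeSegmentsSketch {
    public static void main(String[] args) {
        // Two slices of a dimension-4 model, both at model version 7.
        LogisticRegressionModelDataSegment s0 =
                new LogisticRegressionModelDataSegment(Vectors.dense(0.1, 0.2), 0L, 2L, 7L);
        LogisticRegressionModelDataSegment s1 =
                new LogisticRegressionModelDataSegment(Vectors.dense(0.3, 0.4), 2L, 4L, 7L);

        LogisticRegressionModelDataSegment merged =
                LogisticRegressionModelDataSegment.mergeSegments(Arrays.asList(s0, s1));
        // merged.coefficient.values == [0.1, 0.2, 0.3, 0.4]
        // merged.startIndex == 0, merged.endIndex == 4, merged.modelVersion == 7
    }
}
```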
diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
index 4cec85131..655ea13b4 100644
--- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
+++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
@@ -20,8 +20,8 @@
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.servable.api.DataFrame;
@@ -47,13 +47,13 @@ public class LogisticRegressionModelServable
 
     private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
-    private LogisticRegressionModelData modelData;
+    private LogisticRegressionModelDataSegment modelData;
 
     public LogisticRegressionModelServable() {
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
-    LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
+    LogisticRegressionModelServable(LogisticRegressionModelDataSegment modelData) {
         this();
         this.modelData = modelData;
     }
@@ -61,12 +61,12 @@ public LogisticRegressionModelServable() {
     @Override
     public DataFrame transform(DataFrame input) {
         List<Double> predictionResults = new ArrayList<>();
-        List<DenseVector> rawPredictionResults = new ArrayList<>();
+        List<DenseIntDoubleVector> rawPredictionResults = new ArrayList<>();
 
         int featuresColIndex = input.getIndex(getFeaturesCol());
         for (Row row : input.collect()) {
-            Vector features = (Vector) row.get(featuresColIndex);
-            Tuple2<Double, DenseVector> dataPoint = transform(features);
+            IntDoubleVector features = (IntDoubleVector) row.get(featuresColIndex);
+            Tuple2<Double, DenseIntDoubleVector> dataPoint = transform(features);
             predictionResults.add(dataPoint.f0);
             rawPredictionResults.add(dataPoint.f1);
         }
@@ -81,8 +81,19 @@ public DataFrame transform(DataFrame input) {
     public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
             throws IOException {
         Preconditions.checkArgument(modelDataInputs.length == 1);
+        List<LogisticRegressionModelDataSegment> modelSegments = new ArrayList<>();
+        while (true) {
+            try {
+                LogisticRegressionModelDataSegment segment =
+                        LogisticRegressionModelDataSegment.decode(modelDataInputs[0]);
+                modelSegments.add(segment);
+            } catch (IOException e) {
+                // Reached the end of model stream.
+                break;
+            }
+        }
 
-        modelData = LogisticRegressionModelData.decode(modelDataInputs[0]);
+        modelData = LogisticRegressionModelDataSegment.mergeSegments(modelSegments);
         return this;
     }
@@ -103,7 +114,7 @@ public static LogisticRegressionModelServable load(String path) throws IOException {
      * @param feature The input feature.
      * @return The prediction label and the raw probabilities.
      */
-    protected Tuple2<Double, DenseVector> transform(Vector feature) {
+    protected Tuple2<Double, DenseIntDoubleVector> transform(IntDoubleVector feature) {
         double dotValue = BLAS.dot(feature, modelData.coefficient);
         double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue));
         return Tuple2.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob));
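The protected `transform` method is a plain logistic model: `prob = 1 - 1 / (1 + e^dot)` is exactly `sigmoid(dot)`, and the label is thresholded at `dot >= 0`. A worked sketch with a hypothetical two-dimensional coefficient:

```java
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseIntDoubleVector;
import org.apache.flink.ml.linalg.Vectors;

public class LogisticTransformSketch {
    public static void main(String[] args) {
        DenseIntDoubleVector coefficient = Vectors.dense(2.0, -1.0); // hypothetical model
        DenseIntDoubleVector feature = Vectors.dense(1.0, 1.0);

        double dot = BLAS.dot(feature, coefficient);   // 2.0 - 1.0 = 1.0
        double prob = 1 - 1.0 / (1.0 + Math.exp(dot)); // sigmoid(1.0) ~ 0.731
        double label = dot >= 0 ? 1. : 0.;             // 1.0
        // Raw prediction, as in the hunk above: [1 - prob, prob] ~ [0.269, 0.731]
    }
}
```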
diff --git a/flink-ml-uber/pom.xml b/flink-ml-uber/pom.xml
index f56b6faf3..5cded2fee 100644
--- a/flink-ml-uber/pom.xml
+++ b/flink-ml-uber/pom.xml
@@ -94,6 +94,7 @@ under the License.
                                     <include>org.apache.flink:flink-ml-lib</include>
                                     <include>org.apache.flink:flink-ml-benchmark</include>
                                     <include>dev.ludovic.netlib:blas</include>
+                                    <include>it.unimi.dsi:fastutil</include>