From ce947dea9a1601d0f041d77165d55630cba33e59 Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Mon, 8 May 2023 10:58:18 +0800 Subject: [PATCH 01/18] Support SparseVector as input for LogisticRegression --- .../LogisticRegression.java | 2 +- .../LogisticRegressionTest.java | 39 +++++++++++++++++++ .../feature/LabeledPointWithWeight.java | 10 ++--- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 87cc650c3..2ff61c7c5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -85,7 +85,7 @@ public LogisticRegressionModel fit(Table... inputs) { throw new RuntimeException( "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); } - DenseVector features = + Vector features = ((Vector) dataPoint.getField(getFeaturesCol())) .toDense(); return new LabeledPointWithWeight(features, label, weight); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index f899c281e..a821a9cd0 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -31,6 +31,7 @@ import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.types.BasicType; import org.apache.flink.ml.servable.types.DataTypes; @@ -55,6 +56,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import static org.apache.flink.ml.util.TestUtils.saveAndLoadServable; import static org.junit.Assert.assertArrayEquals; @@ -104,6 +106,7 @@ public class LogisticRegressionTest extends AbstractTestBase { private static final double TOLERANCE = 1e-7; private Table binomialDataTable; + private Table binomialSparseDataTable; private Table multinomialDataTable; @@ -123,6 +126,29 @@ public void before() { DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, new String[] {"features", "label", "weight"}))); + + List binomialSparseTrainData = + binomialTrainData.stream() + .map( + r -> { + DenseVector features = r.getFieldAs(0); + double label = r.getFieldAs(1); + double weight = r.getFieldAs(2); + return Row.of(features.toSparse(), label, weight); + }) + .collect(Collectors.toList()); + binomialSparseDataTable = + tEnv.fromDataStream( + env.fromCollection( + binomialSparseTrainData, + new RowTypeInfo( + new TypeInformation[] { + SparseVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + multinomialDataTable = tEnv.fromDataStream( env.fromCollection( @@ -313,6 +339,19 @@ public void testGetModelData() throws Exception { assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); } + @Test + @SuppressWarnings("unchecked") + public void testGetModelDataFromSparseInput() throws 
Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialSparseDataTable); + List modelData = + IteratorUtils.toList( + LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + assertEquals(1, modelData.size()); + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); + } + @Test public void testSetModelData() throws Exception { LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java index 8440bc97d..136231ac8 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java @@ -18,18 +18,18 @@ package org.apache.flink.ml.common.feature; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; /** Utility class to represent a data point that contains features, label and weight. */ public class LabeledPointWithWeight { - private DenseVector features; + private Vector features; private double label; private double weight; - public LabeledPointWithWeight(DenseVector features, double label, double weight) { + public LabeledPointWithWeight(Vector features, double label, double weight) { this.features = features; this.label = label; this.weight = weight; @@ -37,11 +37,11 @@ public LabeledPointWithWeight(DenseVector features, double label, double weight) public LabeledPointWithWeight() {} - public DenseVector getFeatures() { + public Vector getFeatures() { return features; } - public void setFeatures(DenseVector features) { + public void setFeatures(Vector features) { this.features = features; } From 9d58fc7b337a963945ccf6fe6c8984d1b57577d3 Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Tue, 9 May 2023 16:03:07 +0800 Subject: [PATCH 02/18] [hotfix] Fix TableUtils.getRowTypeInfo when the input contains Tuple --- .../org/apache/flink/ml/common/datastream/TableUtils.java | 1 + .../apache/flink/ml/common/datastream/TableUtilsTest.java | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java index 8d278502a..8af464241 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java @@ -60,6 +60,7 @@ public class TableUtils { LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.MAP); LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.MULTISET); LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.ROW); + LOGICAL_TYPE_ROOTS_USING_EXTERNAL_TYPE_INFO.add(LogicalTypeRoot.STRUCTURED_TYPE); } // Constructs a RowTypeInfo from the given schema. 
Currently, this function does not support diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java index b420ea4d8..e357d80c4 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java @@ -20,6 +20,7 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.linalg.DenseMatrix; import org.apache.flink.ml.linalg.DenseVector; @@ -119,6 +120,12 @@ public void testGetRowTypeInfo() { dataFields.add(new SparseVector(2, new int[] {0}, new double[] {0.1})); preDefinedDataTypes.add(DataTypes.RAW(DenseMatrixTypeInfo.INSTANCE)); dataFields.add(new DenseMatrix(2, 2)); + preDefinedDataTypes.add( + DataTypes.STRUCTURED( + Tuple2.class, + DataTypes.FIELD("f0", DataTypes.BIGINT()), + DataTypes.FIELD("f1", DataTypes.BIGINT()))); + dataFields.add(Tuple2.of(1L, 2L)); Schema.Builder builder = Schema.newBuilder(); for (int i = 0; i < preDefinedDataTypes.size(); i++) { From a92671c465a073154026168eb3630b4790c43e35 Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Fri, 12 May 2023 16:57:05 +0800 Subject: [PATCH 03/18] Expand LogisticRegressionModelData as many pieces --- .../classification/logisticregression.md | 2 +- .../OnlineLogisticRegressionExample.java | 3 +- .../LogisticRegressionModel.java | 15 +++++-- .../LogisticRegressionModelDataUtil.java | 13 +++++- .../LogisticRegressionTest.java | 2 +- .../OnlineLogisticRegressionTest.java | 4 ++ .../tests/test_logisticregression.py | 2 +- .../flink/ml/util/ServableReadWriteUtils.java | 13 +++--- .../LogisticRegressionModelData.java | 17 +++++++- .../LogisticRegressionModelServable.java | 41 ++++++++++++++++++- 10 files changed, 94 insertions(+), 18 deletions(-) diff --git a/docs/content/docs/operators/classification/logisticregression.md b/docs/content/docs/operators/classification/logisticregression.md index b45b8480a..edd9f8d33 100644 --- a/docs/content/docs/operators/classification/logisticregression.md +++ b/docs/content/docs/operators/classification/logisticregression.md @@ -323,7 +323,7 @@ public class OnlineLogisticRegressionExample { // Creates an online LogisticRegression object and initializes its parameters and initial // model data. 
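// As of this change, a model data row has four fields:
// (coefficient, startIndex, endIndex, modelVersion). The 0L, 2L, 0L below
// mark a single piece covering coefficient indices [0, 2) at model version 0.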
- Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L); + Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L); Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData)); OnlineLogisticRegression olr = new OnlineLogisticRegression() diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java index d4e7b2f27..9523050ee 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java @@ -96,7 +96,8 @@ public static void main(String[] args) { // Creates an online LogisticRegression object and initializes its parameters and initial // model data. - Row initModelData = Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L); + Row initModelData = + Row.of(Vectors.dense(0.41233679404769874, -0.18088118293232122), 0L, 2L, 0L); Table initModelDataTable = tEnv.fromDataStream(env.fromElements(initModelData)); OnlineLogisticRegression olr = new OnlineLogisticRegression() diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index e777c5faa..248cf9438 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -43,6 +43,7 @@ import java.io.IOException; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** A Model which classifies data using the model data computed by {@link LogisticRegression}. 
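 *
 * <p>As the change below shows, the broadcast model data may now arrive as several pieces, each
 * covering a contiguous range of coefficient indices; the pieces are merged back into a single
 * coefficient vector via {@code LogisticRegressionModelServable.mergePieces} before prediction.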
*/ @@ -147,10 +148,16 @@ public PredictLabelFunction(String broadcastModelKey, Map, Object> para @Override public Row map(Row dataPoint) { if (servable == null) { - LogisticRegressionModelData modelData = - (LogisticRegressionModelData) - getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); - servable = new LogisticRegressionModelServable(modelData); + List modelData = + getRuntimeContext().getBroadcastVariable(broadcastModelKey); + + if (modelData.size() == 1) { + servable = new LogisticRegressionModelServable(modelData.get(0)); + } else { + LogisticRegressionModelData mergedModel = + LogisticRegressionModelServable.mergePieces(modelData); + servable = new LogisticRegressionModelServable(mergedModel); + } ParamUtils.updateExistingParams(servable, params); } Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java index e6acb7c73..5b4a4f4a5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java @@ -89,7 +89,13 @@ public static DataStream getModelDataStream(Table m StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); return tEnv.toDataStream(modelData) - .map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1))); + .map( + x -> + new LogisticRegressionModelData( + x.getFieldAs(0), + x.getFieldAs(1), + x.getFieldAs(2), + x.getFieldAs(3))); } /** @@ -107,7 +113,10 @@ public static DataStream getModelDataByteStream(Table modelDataTable) { x -> { LogisticRegressionModelData modelData = new LogisticRegressionModelData( - x.getFieldAs(0), x.getFieldAs(1)); + x.getFieldAs(0), + x.getFieldAs(1), + x.getFieldAs(2), + x.getFieldAs(3)); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); modelData.encode(outputStream); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index a821a9cd0..4f6c3fa35 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -316,7 +316,7 @@ public void testSaveLoadAndPredict() throws Exception { tempFolder.newFolder().getAbsolutePath(), LogisticRegressionModel::load); assertEquals( - Arrays.asList("coefficient", "modelVersion"), + Arrays.asList("coefficient", "startIndex", "endIndex", "modelVersion"), model.getModelData()[0].getResolvedSchema().getColumnNames()); Table output = model.transform(binomialDataTable)[0]; verifyPredictionResult( diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java index cac9473c3..3eb2bb047 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java @@ -256,6 +256,8 @@ public void before() throws Exception { 
new double[] { 0.41233679404769874, -0.18088118293232122 }), + 0L, + 2L, 0L))); initSparseModel = tEnv.fromDataStream( @@ -266,6 +268,8 @@ public void before() throws Exception { 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01 }), + 0L, + 10L, 0L))); } diff --git a/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py b/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py index e8091d1f6..c2a218cb2 100644 --- a/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py +++ b/flink-ml-python/pyflink/ml/classification/tests/test_logisticregression.py @@ -120,7 +120,7 @@ def test_save_load_and_predict(self): model = regression.fit(self.binomial_data_table) self.assertEqual( model.get_model_data()[0].get_schema().get_field_names(), - ['coefficient', 'modelVersion']) + ['coefficient', "startIndex", "endIndex", 'modelVersion']) output = model.transform(self.binomial_data_table)[0] field_names = output.get_schema().get_field_names() self.verify_predict_result( diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java index c38e3d2f8..55a45d738 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/ServableReadWriteUtils.java @@ -25,13 +25,14 @@ import org.apache.flink.ml.servable.api.TransformerServable; import org.apache.flink.ml.servable.builder.PipelineModelServable; import org.apache.flink.util.InstantiationUtil; -import org.apache.flink.util.Preconditions; import java.io.IOException; import java.io.InputStream; +import java.io.SequenceInputStream; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -143,10 +144,10 @@ public static InputStream loadModelData(String path) throws IOException { FileSystem fileSystem = modelDataPath.getFileSystem(); FileStatus[] files = fileSystem.listStatus(modelDataPath); - Preconditions.checkState( - files.length == 1, - "Only one model data file is expected in the directory %s.", - path); - return fileSystem.open(files[0].getPath()); + List inputStreams = new ArrayList<>(); + for (FileStatus file : files) { + inputStreams.add(fileSystem.open(file.getPath())); + } + return new SequenceInputStream(Collections.enumeration(inputStreams)); } } diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java index 28927e475..32934ea30 100644 --- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java @@ -33,12 +33,23 @@ public class LogisticRegressionModelData { public DenseVector coefficient; + public long startIndex; + + public long endIndex; + public long modelVersion; public LogisticRegressionModelData() {} public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) { + this(coefficient, 0L, coefficient.size(), modelVersion); + } + + public LogisticRegressionModelData( + 
DenseVector coefficient, long startIndex, long endIndex, long modelVersion) { this.coefficient = coefficient; + this.startIndex = startIndex; + this.endIndex = endIndex; this.modelVersion = modelVersion; } @@ -54,6 +65,8 @@ public void encode(OutputStream outputStream) throws IOException { DenseVectorSerializer serializer = new DenseVectorSerializer(); serializer.serialize(coefficient, dataOutputViewStreamWrapper); + dataOutputViewStreamWrapper.writeLong(startIndex); + dataOutputViewStreamWrapper.writeLong(endIndex); dataOutputViewStreamWrapper.writeLong(modelVersion); } @@ -69,8 +82,10 @@ static LogisticRegressionModelData decode(InputStream inputStream) throws IOExce DenseVectorSerializer serializer = new DenseVectorSerializer(); DenseVector coefficient = serializer.deserialize(dataInputViewStreamWrapper); + long startIndex = dataInputViewStreamWrapper.readLong(); + long endIndex = dataInputViewStreamWrapper.readLong(); long modelVersion = dataInputViewStreamWrapper.readLong(); - return new LogisticRegressionModelData(coefficient, modelVersion); + return new LogisticRegressionModelData(coefficient, startIndex, endIndex, modelVersion); } } diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java index 4cec85131..c2e14029d 100644 --- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.classification.logisticregression; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseVector; @@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) { public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs) throws IOException { Preconditions.checkArgument(modelDataInputs.length == 1); + List modelPieces = new ArrayList<>(); + while (true) { + try { + LogisticRegressionModelData piece = + LogisticRegressionModelData.decode(modelDataInputs[0]); + modelPieces.add(piece); + } catch (IOException e) { + // Reached the end of model stream. + break; + } + } - modelData = LogisticRegressionModelData.decode(modelDataInputs[0]); + modelData = mergePieces(modelPieces); return this; } + @VisibleForTesting + public static LogisticRegressionModelData mergePieces( + List pieces) { + long dim = 0; + for (LogisticRegressionModelData piece : pieces) { + dim = Math.max(dim, piece.endIndex); + } + // TODO: Add distributed inference for very large models. + Preconditions.checkState( + dim < Integer.MAX_VALUE, + "The dimension of logistic regression model is larger than INT.MAX. 
Please consider using distributed inference."); + int intDim = (int) dim; + DenseVector mergedCoefficient = new DenseVector(intDim); + for (LogisticRegressionModelData piece : pieces) { + int startIndex = (int) piece.startIndex; + int endIndex = (int) piece.endIndex; + System.arraycopy( + piece.coefficient.values, + 0, + mergedCoefficient.values, + startIndex, + endIndex - startIndex); + } + return new LogisticRegressionModelData( + mergedCoefficient, 0, mergedCoefficient.size(), pieces.get(0).modelVersion); + } + public static LogisticRegressionModelServable load(String path) throws IOException { LogisticRegressionModelServable servable = ServableReadWriteUtils.loadServableParam( From 74b4b7cf1a008b57a10e5a35e87a5d64ca658f8e Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Fri, 12 May 2023 17:09:38 +0800 Subject: [PATCH 04/18] [FLINK-27826] Support training very high dimensional logisticRegression --- flink-ml-lib/pom.xml | 6 + .../LogisticRegressionWithFtrl.java | 380 +++++++++++++++++ .../LogisticRegressionWithFtrlParams.java | 115 +++++ .../common/lossfunc/BinaryLogisticLoss.java | 12 + .../flink/ml/common/lossfunc/LossFunc.java | 10 + .../ml/common/ps/MirrorWorkerOperator.java | 122 ++++++ .../flink/ml/common/ps/RangePartitioner.java | 108 +++++ .../flink/ml/common/ps/ServerAgent.java | 81 ++++ .../flink/ml/common/ps/ServerOperator.java | 301 +++++++++++++ .../flink/ml/common/ps/WorkerOperator.java | 293 +++++++++++++ .../ml/common/ps/message/IndicesToPullM.java | 70 ++++ .../ml/common/ps/message/KVsToPushM.java | 73 ++++ .../flink/ml/common/ps/message/Message.java | 35 ++ .../ml/common/ps/message/MessageType.java | 48 +++ .../ml/common/ps/message/MessageUtils.java | 123 ++++++ .../ml/common/ps/message/ValuesPulledM.java | 71 ++++ .../ml/common/ps/message/ZerosToPushM.java | 76 ++++ .../ml/common/ps/training/IterationStage.java | 32 ++ .../ps/training/IterationStageList.java | 52 +++ .../ml/common/ps/training/ProcessStage.java | 33 ++ .../ml/common/ps/training/PullStage.java | 38 ++ .../ml/common/ps/training/PushStage.java | 32 ++ .../ps/training/SerializableConsumer.java | 29 ++ .../common/ps/training/TrainingContext.java | 50 +++ .../ml/common/ps/training/TrainingUtils.java | 164 ++++++++ .../apache/flink/ml/common/updater/FTRL.java | 150 +++++++ .../flink/ml/common/updater/ModelUpdater.java | 52 +++ .../LogisticRegressionWithFtrlTest.java | 394 ++++++++++++++++++ .../ml/tests/test_ml_lib_completeness.py | 6 + .../feature/LabeledLargePointWithWeight.java | 40 ++ .../java/org/apache/flink/ml/util/Bits.java | 23 + flink-ml-uber/pom.xml | 1 + 32 files changed, 3020 insertions(+) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java create mode 100644 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml index 8050dd78c..37f05ea72 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -138,6 +138,12 @@ under the License. test test-jar + + fastutil + fastutil + 5.0.9 + + diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java new file mode 100644 index 000000000..7a56d9edc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.common.ps.training.SerializableConsumer; +import org.apache.flink.ml.common.ps.training.TrainingContext; +import org.apache.flink.ml.common.ps.training.TrainingUtils; +import org.apache.flink.ml.common.updater.FTRL; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableFunction; +import org.apache.flink.util.function.SerializableSupplier; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer. + * + *
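<p>For reference, a sketch of the standard per-coordinate FTRL-Proximal update (McMahan et
+ * al.). The {@code FTRL} updater used below is constructed with (alpha, beta, reg, elasticNet);
+ * the elastic-net split l1 = reg * elasticNet, l2 = reg * (1 - elasticNet) is an assumption
+ * here, not taken from this patch:
+ *
+ * <pre>{@code
+ * // Illustrative update for one coordinate; n and z are accumulated state,
+ * // w is the current weight and g is the incoming gradient.
+ * double sigma = (Math.sqrt(n + g * g) - Math.sqrt(n)) / alpha;
+ * z += g - sigma * w;
+ * n += g * g;
+ * w = Math.abs(z) <= l1
+ *         ? 0.0
+ *         : -(z - Math.signum(z) * l1) / ((beta + Math.sqrt(n)) / alpha + l2);
+ * }</pre>
+ *
+ *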
<p>
See https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegressionWithFtrl + implements Estimator, + LogisticRegressionWithFtrlParams { + + private final Map, Object> paramMap = new HashMap<>(); + + public LogisticRegressionWithFtrl() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public LogisticRegressionModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + String classificationType = getMultiClass(); + Preconditions.checkArgument( + "auto".equals(classificationType) || "binomial".equals(classificationType), + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream trainData = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction) + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : ((Number) + dataPoint.getField( + getWeightCol())) + .doubleValue(); + double label = + ((Number) dataPoint.getField(getLabelCol())) + .doubleValue(); + boolean isBinomial = + Double.compare(0., label) == 0 + || Double.compare(1., label) == 0; + if (!isBinomial) { + throw new RuntimeException( + "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); + } + Tuple2 features = + dataPoint.getFieldAs(getFeaturesCol()); + return new LabeledLargePointWithWeight( + features, label, weight); + }); + + DataStream modelDim; + if (getModelDim() > 0) { + modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim()); + } else { + modelDim = + DataStreamUtils.reduce( + trainData.map(x -> x.features.f0[x.features.f0.length - 1]), + (ReduceFunction) Math::max) + .map((MapFunction) value -> value + 1); + } + + LogisticRegressionWithFtrlTrainingContext trainingContext = + new LogisticRegressionWithFtrlTrainingContext(getParamMap()); + + IterationStageList iterationStages = + new IterationStageList<>(trainingContext); + iterationStages + .addTrainingStage(new ComputeIndices()) + .addTrainingStage( + new PullStage( + (SerializableSupplier) () -> trainingContext.pullIndices, + (SerializableConsumer) + x -> trainingContext.pulledValues = x)) + .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) + .addTrainingStage( + new PushStage( + (SerializableSupplier) () -> trainingContext.pushIndices, + (SerializableSupplier) () -> trainingContext.pushValues)) + .setTerminationCriteria( + (SerializableFunction) + o -> o.iterationId >= getMaxIter()); + FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); + + DataStream> rawModelData = + TrainingUtils.train( + modelDim, + trainData, + ftrl, + iterationStages, + getNumServers(), + getNumServerCores()); + + final long modelVersion = 0L; + + DataStream modelData = + rawModelData.map( + tuple3 -> + new LogisticRegressionModelData( + Vectors.dense(tuple3.f2), + tuple3.f0, + tuple3.f1, + modelVersion)); + + LogisticRegressionModel model = + new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); + ParamUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } +} + +/** + * An iteration 
stage that samples a batch of training data and computes the indices needed to + * compute gradients. + */ +class ComputeIndices extends ProcessStage { + + @Override + public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception { + context.readInNextBatchData(); + context.pullIndices = computeIndices(context.batchData); + } + + public static long[] computeIndices(List dataPoints) { + LongOpenHashSet indices = new LongOpenHashSet(); + for (LabeledLargePointWithWeight dataPoint : dataPoints) { + long[] notZeros = dataPoint.features.f0; + for (long index : notZeros) { + indices.add(index); + } + } + + long[] sortedIndices = new long[indices.size()]; + Iterator iterator = indices.iterator(); + int i = 0; + while (iterator.hasNext()) { + sortedIndices[i++] = iterator.next(); + } + Arrays.sort(sortedIndices); + return sortedIndices; + } +} + +/** + * An iteration stage that uses the pulled model values and sampled batch data to compute the + * gradients. + */ +class ComputeGradients extends ProcessStage { + private final LossFunc lossFunc; + + public ComputeGradients(LossFunc lossFunc) { + this.lossFunc = lossFunc; + } + + @Override + public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException { + long[] indices = ComputeIndices.computeIndices(context.batchData); + double[] pulledModelValues = context.pulledValues; + double[] gradients = computeGradient(context.batchData, indices, pulledModelValues); + + context.pushIndices = indices; + context.pushValues = gradients; + } + + private double[] computeGradient( + List batchData, + long[] sortedBatchIndices, + double[] pulledModelValues) { + Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length); + for (int i = 0; i < sortedBatchIndices.length; i++) { + coefficient.put(sortedBatchIndices[i], pulledModelValues[i]); + } + Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length); + + for (LabeledLargePointWithWeight dataPoint : batchData) { + double dot = dot(dataPoint.features, coefficient); + double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight; + + long[] featureIndices = dataPoint.features.f0; + double[] featureValues = dataPoint.features.f1; + double z; + for (int i = 0; i < featureIndices.length; i++) { + long currentIndex = featureIndices[i]; + z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.); + cumGradients.put(currentIndex, z); + } + } + double[] cumGradientValues = new double[sortedBatchIndices.length]; + for (int i = 0; i < sortedBatchIndices.length; i++) { + cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]); + } + return cumGradientValues; + } + + private static double dot( + Tuple2 features, Long2DoubleOpenHashMap coefficient) { + double dot = 0; + for (int i = 0; i < features.f0.length; i++) { + dot += features.f1[i] * coefficient.get(features.f0[i]); + } + return dot; + } +} + +/** The context information of local computing process. */ +class LogisticRegressionWithFtrlTrainingContext + implements TrainingContext, + LogisticRegressionWithFtrlParams { + /** Parameters of LogisticRegressionWithFtrl. */ + private final Map, Object> paramMap; + /** Current iteration id. */ + int iterationId; + /** The local batch size. */ + private int localBatchSize = -1; + /** The training data. */ + private ResettableIterator trainData; + /** The batch of training data for computing gradients. 
*/ + List batchData; + + private ListState batchDataState; + + /** The placeholder for indices to pull for each iteration. */ + long[] pullIndices; + /** The placeholder for the pulled values for each iteration. */ + double[] pulledValues; + /** The placeholder for indices to push for each iteration. */ + long[] pushIndices; + /** The placeholder for values to push for each iteration. */ + double[] pushValues; + + public LogisticRegressionWithFtrlTrainingContext(Map, Object> paramMap) { + this.paramMap = paramMap; + } + + @Override + public void setIterationId(int iterationId) { + this.iterationId = iterationId; + } + + @Override + public void setWorldInfo(int workerId, int numWorkers) { + int globalBatchSize = getGlobalBatchSize(); + this.localBatchSize = globalBatchSize / numWorkers; + if (globalBatchSize % numWorkers > workerId) { + localBatchSize++; + } + this.batchData = new ArrayList<>(localBatchSize); + } + + @Override + public void setTrainData(ResettableIterator trainData) { + this.trainData = (ResettableIterator) trainData; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + batchDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "batchDataState", + TypeInformation.of(LabeledLargePointWithWeight.class))); + + Iterator batchDataIterator = batchDataState.get().iterator(); + if (batchDataIterator.hasNext()) { + batchData = IteratorUtils.toList(batchDataIterator); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + batchDataState.clear(); + if (batchData.size() > 0) { + batchDataState.addAll(batchData); + } + } + + @Override + public Map, Object> getParamMap() { + return paramMap; + } + + /** Reads in next batch of training data. */ + public void readInNextBatchData() throws IOException { + batchData.clear(); + int i = 0; + while (i < localBatchSize && trainData.hasNext()) { + batchData.add(trainData.next()); + i++; + } + if (!trainData.hasNext()) { + trainData.reset(); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java new file mode 100644 index 000000000..be00aeab4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.ml.common.param.HasElasticNet; +import org.apache.flink.ml.common.param.HasGlobalBatchSize; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasMultiClass; +import org.apache.flink.ml.common.param.HasReg; +import org.apache.flink.ml.common.param.HasTol; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.LongParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** Params for {@link LogisticRegressionWithFtrl}. */ +public interface LogisticRegressionWithFtrlParams + extends HasLabelCol, + HasWeightCol, + HasGlobalBatchSize, + HasReg, + HasElasticNet, + HasMultiClass, + HasMaxIter, + HasTol, + LogisticRegressionModelParams { + + Param NUM_SERVERS = + new IntParam( + "numServers", + "Number of servers to store model parameters.", + 1, + ParamValidators.gtEq(1)); + + Param NUM_SERVER_CORES = + new IntParam( + "numServerCores", + "number of cores that a server can use.", + 1, + ParamValidators.gtEq(1)); + + Param ALPHA = + new DoubleParam( + "alpha", + "The alpha parameter of FTRL optimizer.", + 0.1, + ParamValidators.gt(0.0)); + + Param BETA = + new DoubleParam( + "beta", "The beta parameter of FTRL optimizer.", 0.1, ParamValidators.gt(0.0)); + + Param MODEL_DIM = + new LongParam( + "modelDim", "number of features of input data.", 0L, ParamValidators.gtEq(0)); + + default int getNumServers() { + return get(NUM_SERVERS); + } + + default T setNumServers(Integer value) { + return set(NUM_SERVERS, value); + } + + default int getNumServerCores() { + return get(NUM_SERVER_CORES); + } + + default T setNumServerCores(int value) { + return set(NUM_SERVER_CORES, value); + } + + default double getAlpha() { + return get(ALPHA); + } + + default T setAlpha(Double value) { + return set(ALPHA, value); + } + + default double getBeta() { + return get(BETA); + } + + default T setBeta(Double value) { + return set(BETA, value); + } + + default long getModelDim() { + return get(MODEL_DIM); + } + + default T setModelDim(long value) { + return set(MODEL_DIM, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java index cd24c0684..d4d43cbdd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java @@ -47,4 +47,16 @@ public void computeGradient( dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient, dataPoint.getFeatures().size()); } + + @Override + public double computeLoss(double label, double prediction) { + double labelScaled = 2 * label - 1; + return Math.log(1 + Math.exp(-prediction * labelScaled)); + } + + @Override + public double computeGradient(double label, double prediction) { + double labelScaled = 2 * label - 1; + return -labelScaled / (Math.exp(prediction * labelScaled) + 1); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java index a90967a73..0fead4aab 100644 
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java @@ -48,4 +48,14 @@ public interface LossFunc extends Serializable { */ void computeGradient( LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient); + + /** Computes loss using the label and the prediction. */ + default double computeLoss(double label, double prediction) { + throw new UnsupportedOperationException("Not supported yet."); + } + + /** Computes gradient using the label and the prediction. */ + default double computeGradient(double label, double prediction) { + throw new UnsupportedOperationException("Not supported yet."); + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java new file mode 100644 index 000000000..40d055da4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +/** + * Merges the message from different servers for one pull request. + * + *

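<p>For example, if a pull request was split across two servers, the two {@code ValuesPulledM}
+ * pieces received here are sorted by server id and their value arrays are concatenated into one
+ * merged answer before it is emitted (see {@code trySendingPulls} below).
+ *
+ * <p>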
Note that for each single-thread worker, there are at exactly #numServers pieces for each pull + * request in the feedback edge. + */ +public class MirrorWorkerOperator extends AbstractStreamOperator + implements OneInputStreamOperator, byte[]> { + private final int numServers; + private int workerId; + + /** The received messages from servers for the current pull request. */ + private List messageReceived; + + private ListState messageReceivedState; + + public MirrorWorkerOperator(int numServers) { + this.numServers = numServers; + } + + @Override + public void open() throws Exception { + super.open(); + this.workerId = getRuntimeContext().getIndexOfThisSubtask(); + } + + @Override + public void processElement(StreamRecord> element) throws Exception { + Preconditions.checkState(element.getValue().f0 == workerId); + ValuesPulledM pulledModelM = ValuesPulledM.fromBytes(element.getValue().f1); + messageReceived.add(pulledModelM); + trySendingPulls(numServers); + } + + private void trySendingPulls(int numPieces) { + if (messageReceived.size() == numPieces) { + Comparator comparator = Comparator.comparingInt(o -> o.serverId); + messageReceived.sort(comparator); + int size = 0; + for (ValuesPulledM pulledModelM : messageReceived) { + size += pulledModelM.valuesPulled.length; + } + double[] answer = new double[size]; + int offset = 0; + for (ValuesPulledM pulledModelM : messageReceived) { + double[] values = pulledModelM.valuesPulled; + System.arraycopy(values, 0, answer, offset, values.length); + offset += values.length; + } + ValuesPulledM pulledModelM = new ValuesPulledM(-1, workerId, answer); + output.collect(new StreamRecord<>(pulledModelM.toBytes())); + messageReceived.clear(); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + messageReceivedState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "messageReceivedState", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + messageReceived = new ArrayList<>(); + + Iterator iterator = messageReceivedState.get().iterator(); + if (iterator.hasNext()) { + while (iterator.hasNext()) { + messageReceived.add(ValuesPulledM.fromBytes(iterator.next())); + } + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + messageReceivedState.clear(); + if (messageReceived.size() > 0) { + for (ValuesPulledM valuesPulled : messageReceived) { + messageReceivedState.add(valuesPulled.toBytes()); + } + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java new file mode 100644 index 000000000..d16ef9896 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.util.Preconditions; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.Iterator; + +/** Range partitioner for model data. */ +public class RangePartitioner { + public final long dim; + public final int numServers; + public final long[] ranges; + + public RangePartitioner(long dim, int numServers) { + Preconditions.checkArgument( + dim > 0, + String.format( + "Illegal dimension when using %s: %d", + RangePartitioner.class.getSimpleName(), dim)); + + this.dim = dim; + this.numServers = numServers; + this.ranges = new long[numServers + 1]; + long shardSize = dim / numServers; + + for (int serverId = 0; serverId < numServers; serverId++) { + ranges[serverId] = shardSize * serverId; + } + ranges[numServers] = dim; + } + + /** + * Splits the push/pull request according to the given sorted indices and the corresponding + * values. + * + * @param indices Sorted indices of push/pull request. + * @param values The push values if not null. + * @return The split requests for each server task. + */ + public Iterator> splitRequest( + long[] indices, @Nullable double[] values) { + return new RequestsIterator(numServers, indices, values, ranges); + } + + private static class RequestsIterator implements Iterator> { + private final int numServers; + private final long[] indices; + private final double[] values; + private final long[] ranges; + + private int serverId = 0; + + private int s = 0; + + public RequestsIterator( + int numPss, long[] indices, @Nullable double[] values, long[] ranges) { + Preconditions.checkArgument(values == null || values.length % indices.length == 0); + this.numServers = numPss; + this.indices = indices; + this.values = values; + this.ranges = ranges; + } + + @Override + public boolean hasNext() { + return serverId < numServers; + } + + @Override + public Tuple3 next() { + int e = s; + while (e < indices.length && indices[e] < ranges[serverId + 1]) { + e++; + } + + long[] splitIndices = new long[0]; + double[] splitValues = values == null ? null : new double[0]; + if (s < e) { + splitIndices = Arrays.copyOfRange(indices, s, e); + splitValues = values == null ? null : Arrays.copyOfRange(values, s, e); + } + s = e; + serverId++; + return Tuple3.of(serverId - 1, splitIndices, splitValues); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java new file mode 100644 index 000000000..7a6b5dea1 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.ps.message.IndicesToPullM; +import org.apache.flink.ml.common.ps.message.KVsToPushM; +import org.apache.flink.ml.common.ps.message.ZerosToPushM; +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import java.util.Iterator; + +/** ServerAgent resides on each worker. It serves as an agent for workers to talk with servers. */ +public class ServerAgent { + /** Id of the worker that this agent resides on. */ + private final int workerId; + + private RangePartitioner partitioner; + /** The collector on this worker. */ + private final Output>> output; + + public ServerAgent(int workerId, Output>> output) { + this.workerId = workerId; + this.output = output; + } + + public void setPartitioner(RangePartitioner partitioner) { + this.partitioner = partitioner; + } + + /** Pushes a key-value arrays to servers. */ + public void pushKVs(long[] indices, double[] values) { + Iterator> requests = + partitioner.splitRequest(indices, values); + while (requests.hasNext()) { + Tuple3 request = requests.next(); + KVsToPushM kvToPush = + new KVsToPushM(workerId, request.f0, Tuple2.of(request.f1, request.f2)); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, kvToPush.toBytes()))); + } + } + + /** Sends a request to servers to initialize the values stored as zeros. */ + public void zeros() { + for (int serverId = 0; serverId < partitioner.numServers; serverId++) { + long start = partitioner.ranges[serverId]; + long end = partitioner.ranges[serverId + 1]; + ZerosToPushM zerosToPush = new ZerosToPushM(workerId, serverId, start, end); + output.collect(new StreamRecord<>(Tuple2.of(serverId, zerosToPush.toBytes()))); + } + } + + /** Pulls the values from servers with the specified indices. */ + public void pull(long[] indices) { + Iterator> requests = + partitioner.splitRequest(indices, null); + while (requests.hasNext()) { + Tuple3 request = requests.next(); + IndicesToPullM indicesToPullM = new IndicesToPullM(request.f0, workerId, request.f1); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, indicesToPullM.toBytes()))); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java new file mode 100644 index 000000000..cbdcc1d90 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
 diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java new file mode 100644 index 000000000..cbdcc1d90 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.IndicesToPullM; +import org.apache.flink.ml.common.ps.message.KVsToPushM; +import org.apache.flink.ml.common.ps.message.MessageType; +import org.apache.flink.ml.common.ps.message.MessageUtils; +import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.ml.common.ps.message.ZerosToPushM; +import org.apache.flink.ml.common.updater.ModelUpdater; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.SerializableObject; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * The server operator maintains the shared parameters. It receives push/pull requests from {@link + * WorkerOperator} and sends the answers to {@link MirrorWorkerOperator}. It works closely with + * {@link ModelUpdater} in the following way: + * 
 + *

                                  + *
                                • The server operator deals with the messages from workers and decides when to process + * the received messages (i.e., synchronously vs. asynchronously). + *
                                • The server operator calls {@link ModelUpdater#handlePush(long[], double[])} and {@link + * ModelUpdater#handlePull(long[])} to process the messages in detail. + *
                                • The server operator ensures that {@link ModelUpdater} is robust to failures. + *
                                • The server operator outputs the final model parameters by calling {@link + * ModelUpdater#getModelPieces()}. + *
                                + * + *

                                TODO: Add support for asynchronous operations on servers. + * + *

                                
 TODO: Add support for maintaining multiple parameters on servers. + */ +public class ServerOperator extends AbstractStreamOperator> + implements OneInputStreamOperator, Tuple2>, + IterationListener> { + /** The logic to answer push/pull requests from workers. */ + private final ModelUpdater modelUpdater; + /** Format of model data: start index, end index, dense double array. */ + private final OutputTag> modelOutputTag; + + private int serverId = -1; + + /** + * Lock for output records to downstream operators. Note that we use multiple threads to deal + * with push/pull requests for better performance. + */ + private final SerializableObject lock = new SerializableObject(); + /** Number of threads to answer push/pull requests. */ + private final int numServerCores; + /** Thread pool to answer push/pull requests. */ + private transient ExecutorService fixedThreadPool; + /** The future objects of thread calls in one epoch. */ + private final List> futuresInEpoch = new ArrayList<>(); + /** The accumulated push requests from workers, keyed by thread id. */ + private final ConcurrentHashMap accumulatedKvsByThreadId; + /** The accumulated result of the pushed kvs. */ + private final Long2DoubleOpenHashMap accumulatedKvs; + /** The state for the accumulated kvs. */ + private ListState accumulatedKvsState; + /** The pending pull requests. */ + private ListState pendingPulls; + + public ServerOperator( + ModelUpdater modelUpdater, + OutputTag> modelOutputTag, + int numServerCores) { + this.modelUpdater = modelUpdater; + this.modelOutputTag = modelOutputTag; + this.numServerCores = numServerCores; + this.accumulatedKvsByThreadId = new ConcurrentHashMap<>(); + this.accumulatedKvs = new Long2DoubleOpenHashMap(); + } + + @Override + public void open() throws Exception { + super.open(); + serverId = getRuntimeContext().getIndexOfThisSubtask(); + fixedThreadPool = Executors.newFixedThreadPool(numServerCores); + } + + @Override + public void processElement(StreamRecord> element) throws Exception { + byte[] request = element.getValue().f1; + MessageType type = MessageUtils.getMessageType(request); + if (type == MessageType.INDICES_TO_PULL) { + pendingPulls.add(request); + } else { + processPushRequest(request); + } + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector> collector) + throws Exception { + // Waits until all pushes have been processed. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + + Iterator kvsFromAllThreads = + accumulatedKvsByThreadId.values().iterator(); + if (kvsFromAllThreads.hasNext()) { + Tuple2 kvs = mergeKvsFromAllThreads(kvsFromAllThreads); + modelUpdater.handlePush(kvs.f0, kvs.f1); + accumulatedKvs.clear(); + } + + Iterator pullsIterator = pendingPulls.get().iterator(); + if (pullsIterator.hasNext()) { + // The last iteration contains no pulls. 
 + while (pullsIterator.hasNext()) { + byte[] pull = pullsIterator.next(); + futuresInEpoch.add(fixedThreadPool.submit(() -> processPullRequest(pull))); + } + } + for (Future future : futuresInEpoch) { + future.get(); + } + pendingPulls.clear(); + futuresInEpoch.clear(); + } + + @Override + public void onIterationTerminated( + Context context, Collector> collector) { + Iterator> modelPieces = modelUpdater.getModelPieces(); + while (modelPieces.hasNext()) { + Tuple3 modelPiece = modelPieces.next(); + output.collect(modelOutputTag, new StreamRecord<>(modelPiece)); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + pendingPulls = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pendingPulls", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + modelUpdater.initializeState(context); + + accumulatedKvsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accumulatedKvs", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + byte[] accumulatedKvsInBytes = + OperatorStateUtils.getUniqueElement(accumulatedKvsState, "accumulatedKvs") + .orElse(null); + if (accumulatedKvsInBytes != null) { + Tuple2 kvs = + MessageUtils.readLongDoubleArray(accumulatedKvsInBytes, 0); + accumulatedKvs.clear(); + for (int i = 0; i < kvs.f0.length; i++) { + accumulatedKvs.put(kvs.f0[i], kvs.f1[i]); + } + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + + // Waits for the pending futures to finish. + for (Future future : futuresInEpoch) { + future.get(); + } + futuresInEpoch.clear(); + modelUpdater.snapshotState(context); + + // Snapshots the pending pushes. 
 + Tuple2 kvs = + mergeKvsFromAllThreads(accumulatedKvsByThreadId.values().iterator()); + accumulatedKvsState.clear(); + if (kvs.f0.length > 0) { + byte[] bytes = new byte[MessageUtils.getLongDoubleArraySizeInBytes(kvs)]; + MessageUtils.writeLongDoubleArray(kvs, bytes, 0); + accumulatedKvsState.add(bytes); + } + } + + private void processPushRequest(byte[] pushRpc) { + MessageType type = MessageUtils.getMessageType(pushRpc); + if (type == MessageType.ZEROS_TO_PUSH) { + ZerosToPushM zerosToPush = ZerosToPushM.fromBytes(pushRpc); + Preconditions.checkState(serverId == zerosToPush.serverId); + + long start = zerosToPush.startIndex; + long end = zerosToPush.endIndex; + if (zerosToPush.workerId == 0) { + modelUpdater.open(start, end); + } + } else if (type == MessageType.KVS_TO_PUSH) { + futuresInEpoch.add(fixedThreadPool.submit(() -> processPushedKvs(pushRpc))); + } else { + throw new UnsupportedOperationException("Unsupported message type: " + type + "."); + } + } + + private Object processPushedKvs(byte[] pushKv) { + KVsToPushM kvsToPush = KVsToPushM.fromBytes(pushKv); + Preconditions.checkState(kvsToPush.serverId == serverId); + long threadId = Thread.currentThread().getId(); + accumulatedKvsByThreadId.putIfAbsent(threadId, new Long2DoubleOpenHashMap()); + Long2DoubleOpenHashMap tmpGrad = accumulatedKvsByThreadId.get(threadId); + + Tuple2 pushedGrad = kvsToPush.kvs; + long[] indices = pushedGrad.f0; + double[] values = pushedGrad.f1; + for (int i = 0; i < indices.length; i++) { + tmpGrad.merge(indices[i], values[i], Double::sum); + } + + return new Object(); + } + + private Object processPullRequest(byte[] bytesData) { + IndicesToPullM sparsePullModeM = IndicesToPullM.fromBytes(bytesData); + Preconditions.checkState(serverId == sparsePullModeM.serverId); + int workerId = sparsePullModeM.workerId; + long[] indices = sparsePullModeM.indicesToPull; + double[] pulledValues = modelUpdater.handlePull(indices); + ValuesPulledM pulledModelM = new ValuesPulledM(serverId, workerId, pulledValues); + StreamRecord> record = + new StreamRecord<>(Tuple2.of(workerId, pulledModelM.toBytes())); + + // Holds the lock for output. + synchronized (lock) { + output.collect(record); + } + return new Object(); + } + + private Tuple2 mergeKvsFromAllThreads( + Iterator kvsFromAllThreads) { + while (kvsFromAllThreads.hasNext()) { + Long2DoubleOpenHashMap kv = kvsFromAllThreads.next(); + for (Map.Entry entry : kv.entrySet()) { + accumulatedKvs.merge(entry.getKey(), entry.getValue(), Double::sum); + } + kv.clear(); + } + long[] indices = new long[accumulatedKvs.size()]; + double[] values = new double[indices.length]; + int idx = 0; + for (Map.Entry entry : accumulatedKvs.entrySet()) { + indices[idx] = entry.getKey(); + values[idx] = entry.getValue(); + idx++; + } + return Tuple2.of(indices, values); + } +} 
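 [Reviewer note, not part of the patch: the server's contract with model storage is the ModelUpdater interface driven above (open / handlePush / handlePull / getModelPieces plus two state hooks). As a rough illustration, a minimal in-memory updater for one shard could look like the sketch below; the class name is hypothetical and checkpointing is deliberately a no-op. The FTRL updater later in this patch is a real implementation.] // Sketch: a ModelUpdater that simply accumulates pushed values into its shard. public class AccumulatingUpdater implements ModelUpdater { private long startIndex; private double[] shard; @Override public void open(long startFeatureIndex, long endFeatureIndex) { this.startIndex = startFeatureIndex; this.shard = new double[(int) (endFeatureIndex - startFeatureIndex)]; } @Override public void handlePush(long[] keys, double[] values) { for (int i = 0; i < keys.length; i++) { shard[(int) (keys[i] - startIndex)] += values[i]; // plain additive accumulation } } @Override public double[] handlePull(long[] keys) { double[] values = new double[keys.length]; for (int i = 0; i < keys.length; i++) { values[i] = shard[(int) (keys[i] - startIndex)]; } return values; } @Override public Iterator> getModelPieces() { return Collections.singletonList( Tuple3.of(startIndex, startIndex + shard.length, shard)) .iterator(); } @Override public void initializeState(StateInitializationContext context) {} @Override public void snapshotState(StateSnapshotContext context) {} } 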
 diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java new file mode 100644 index 000000000..0e4ef5fbc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.ml.common.ps.training.IterationStage; +import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.PullStage; +import org.apache.flink.ml.common.ps.training.PushStage; +import org.apache.flink.ml.common.ps.training.TrainingContext; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.util.Iterator; + +/** + * The worker operator that executes the machine learning training process following {@link + * IterationStageList}. + * + *

                                In detail, the worker operator is responsible for the following: + * + *

                                  + *
                                • Caches the training data. + *
                                • Initializes the {@link TrainingContext}. + *
                                • Splits the {@link IterationStageList} by {@link PullStage} into multiple sequences and + * maps them onto flink-ml-iterations. + *
                                • Executes the process function in each {@link ProcessStage}. + *
                                • Executes the push/pull requests in {@link PushStage} and {@link PullStage} and talks to + * servers by reading/writing the {@link TrainingContext}. + *
                                
 + */ +public class WorkerOperator extends AbstractStreamOperator> + implements TwoInputStreamOperator>, + IterationListener> { + /** Number of servers that this worker needs to talk to. */ + private final int numServers; + + /** The agent for each worker to talk with servers. */ + private ServerAgent serverAgent; + + /** The user-defined iteration logic. */ + private final IterationStageList iterationStages; + + /** + * Iteration id in terms of {@link IterationStageList}. When we finish processing all stages + * in the stage list, the iteration id increments by one. + */ + private int iterationId; + + private ListState iterationIdState; + + /** The id of the next stage to execute in iterationStages. */ + private int nextStageToExecute = 0; + + private ListState nextStageToExecuteState; + + /** The cached training data. */ + private ListStateWithCache trainDataState; + + /** The feedback array from iterations. */ + private byte[] feedback; + + private ListState feedbackState; + + /** Dimension of the model. */ + private long modelDim = 0; + + private ListState modelDimState; + + public WorkerOperator(IterationStageList iterationStages, int numServers) { + this.iterationStages = iterationStages; + this.numServers = numServers; + } + + @Override + public void open() { + int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); + int workerId = getRuntimeContext().getIndexOfThisSubtask(); + this.serverAgent = new ServerAgent(workerId, output); + iterationStages.context.setWorldInfo(workerId, numTasks); + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector> collector) + throws Exception { + if (epochWatermark == 0) { + modelDim = Bits.getLong(feedback, 0); + serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers)); + serverAgent.zeros(); + iterationStages.context.setTrainData(new ResettableTrainDataIterator<>(trainDataState)); + nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); + } + } + + @Override + public void onIterationTerminated( + Context context, Collector> collector) { + trainDataState.clear(); + } + + @Override + public void processElement1(StreamRecord streamRecord) throws Exception { + trainDataState.add(streamRecord.getValue()); + } + + @Override + public void processElement2(StreamRecord streamRecord) throws Exception { + feedback = streamRecord.getValue(); + if (modelDim > 0) { + // Decodes the pulled values and puts them into the training context. + PullStage pullStage = (PullStage) iterationStages.stageList.get(nextStageToExecute); + ValuesPulledM valuesPulledMessage = ValuesPulledM.fromBytes(streamRecord.getValue()); + Preconditions.checkState( + getRuntimeContext().getIndexOfThisSubtask() == valuesPulledMessage.workerId); + pullStage.valuesConsumer.accept(valuesPulledMessage.valuesPulled); + nextStageToExecute++; + + nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + feedbackState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "feedbackArrayState", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + OperatorStateUtils.getUniqueElement(feedbackState, "feedbackArrayState") + .ifPresent(x -> feedback = x); + + trainDataState = + new ListStateWithCache<>( + (getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader())), + getContainingTask(), + getRuntimeContext(), + context, + config.getOperatorID()); + + nextStageToExecuteState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("nextStageToExecuteState", Types.INT)); + nextStageToExecute = + OperatorStateUtils.getUniqueElement( + nextStageToExecuteState, "nextStageToExecuteState") + .orElse(0); + + modelDimState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("modelDimState", Types.LONG)); + modelDim = OperatorStateUtils.getUniqueElement(modelDimState, "modelDimState").orElse(0L); + + iterationIdState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("iterationIdState", Types.INT)); + iterationId = + OperatorStateUtils.getUniqueElement(iterationIdState, "iterationIdState").orElse(0); + + if (modelDim > 0) { + serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers)); + } + + iterationStages.context.initializeState(context); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + feedbackState.clear(); + if (feedback != null) { + feedbackState.add(feedback); + } + + nextStageToExecuteState.clear(); + nextStageToExecuteState.add(nextStageToExecute); + modelDimState.clear(); + modelDimState.add(modelDim); + iterationIdState.clear(); + iterationIdState.add(iterationId); + + trainDataState.snapshotState(context); + iterationStages.context.snapshotState(context); + } + + /** + * Processes the stages described in the given iterationStages from the given nextStage id. + * This function processes the stages until it meets a {@link PullStage}. + * + * @param nextStageToExecute id of the next stage to execute in the given iteration stages. + * @param iterationStages iteration stages used to describe the training logic. + * @return the id of the next stage to execute. 
 + */ + @SuppressWarnings("unchecked") + private int processTrainingStage( + int nextStageToExecute, IterationStageList iterationStages) + throws Exception { + while (true) { + if (nextStageToExecute >= iterationStages.stageList.size()) { + iterationId++; + iterationStages.context.setIterationId(iterationId); + if (iterationStages.shouldTerminate.apply(iterationStages.context)) { + return -1; + } + nextStageToExecute -= iterationStages.stageList.size(); + } + IterationStage stage = iterationStages.stageList.get(nextStageToExecute); + + if (stage instanceof PullStage) { + // We are not incrementing nextStageToExecute here, since we will need to pull + // values from servers. + PullStage pullStage = ((PullStage) stage); + serverAgent.pull(pullStage.keysSupplier.get()); + return nextStageToExecute; + } else if (stage instanceof PushStage) { + PushStage pushStage = (PushStage) stage; + serverAgent.pushKVs(pushStage.keysSupplier.get(), pushStage.valuesSupplier.get()); + nextStageToExecute++; + } else if (stage instanceof ProcessStage) { + ((ProcessStage) stage).process(iterationStages.context); + nextStageToExecute++; + } else { + throw new IllegalStateException( + "Illegal type of IterationStage: " + stage.getClass().getSimpleName()); + } + } + } + + /** A resettable iterator for {@link ListStateWithCache}. */ + private static class ResettableTrainDataIterator implements ResettableIterator { + private final ListStateWithCache data; + private Iterator dataIterator; + + public ResettableTrainDataIterator(ListStateWithCache data) throws Exception { + this.data = data; + this.dataIterator = data.get().iterator(); + } + + @Override + public void reset() { + try { + this.dataIterator = data.get().iterator(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean hasNext() { + return dataIterator.hasNext(); + } + + @Override + public T next() { + return dataIterator.next(); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java new file mode 100644 index 000000000..c6742b41f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +/** The indices one worker needs to pull from servers. 
*/ +public class IndicesToPullM implements Message { + public final int serverId; + public final int workerId; + public final long[] indicesToPull; + + public static final MessageType MESSAGE_TYPE = MessageType.INDICES_TO_PULL; + + public IndicesToPullM(int serverId, int workerId, long[] indicesToPull) { + this.serverId = serverId; + this.workerId = workerId; + this.indicesToPull = indicesToPull; + } + + public static IndicesToPullM fromBytes(byte[] bytesData) { + int offset = 0; + char type = Bits.getChar(bytesData, offset); + offset += Character.BYTES; + Preconditions.checkState(type == MESSAGE_TYPE.type); + + int psId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + int workerId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + long[] toPullIndices = MessageUtils.readLongArray(bytesData, offset); + return new IndicesToPullM(psId, workerId, toPullIndices); + } + + @Override + public byte[] toBytes() { + int numBytes = + Character.BYTES + + Integer.BYTES * 2 + + MessageUtils.getLongArraySizeInBytes(indicesToPull); + byte[] buffer = new byte[numBytes]; + int offset = 0; + + Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + offset += Character.BYTES; + Bits.putInt(buffer, offset, this.serverId); + offset += Integer.BYTES; + Bits.putInt(buffer, offset, this.workerId); + offset += Integer.BYTES; + MessageUtils.writeLongArray(this.indicesToPull, buffer, offset); + return buffer; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java new file mode 100644 index 000000000..58be79fdc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +/** The sparse key-values to push from workers to servers. 
 */ +public class KVsToPushM implements Message { + public final int serverId; + public final int workerId; + public final Tuple2 kvs; + public static final MessageType MESSAGE_TYPE = MessageType.KVS_TO_PUSH; + + public KVsToPushM(int workerId, int serverId, Tuple2 kvs) { + this.workerId = workerId; + this.serverId = serverId; + this.kvs = kvs; + } + + public static KVsToPushM fromBytes(byte[] bytesData) { + int offset = 0; + char type = Bits.getChar(bytesData, offset); + offset += Character.BYTES; + Preconditions.checkState(type == MESSAGE_TYPE.type); + + int workerId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + int psId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + Tuple2 grad = MessageUtils.readLongDoubleArray(bytesData, offset); + return new KVsToPushM(workerId, psId, grad); + } + + @Override + public byte[] toBytes() { + int numBytes = + Character.BYTES + + Integer.BYTES + + Integer.BYTES + + MessageUtils.getLongDoubleArraySizeInBytes(kvs); + byte[] buffer = new byte[numBytes]; + int offset = 0; + + Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + offset += Character.BYTES; + + Bits.putInt(buffer, offset, this.workerId); + offset += Integer.BYTES; + Bits.putInt(buffer, offset, this.serverId); + offset += Integer.BYTES; + MessageUtils.writeLongDoubleArray(kvs, buffer, offset); + + return buffer; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java new file mode 100644 index 000000000..39bafbc13 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +/** + * The message to be passed between a worker node and a server node. + * + *

                                NOTE: Every Message subclass should implement a static method with signature {@code static + * T fromBytes(byte[] bytesData)}, where {@code T} refers to the concrete subclass. This static + * method should instantiate a new Message instance based on the data read from the given byte + * array. + */ +public interface Message { + /** + * Serializes the message into a byte array. + * + *

                                Note that the first two bytes of the result buffer are reserved for {@link MessageType}. + */ + byte[] toBytes(); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java new file mode 100644 index 000000000..b6e9a6afd --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +/** Message type between workers and servers. */ +public enum MessageType { + ZEROS_TO_PUSH((char) 0), + INDICES_TO_PULL((char) 1), + VALUES_PULLED((char) 2), + KVS_TO_PUSH((char) 3); + + public final char type; + + MessageType(char type) { + this.type = type; + } + + public static MessageType valueOf(char value) { + switch (value) { + case (char) 0: + return MessageType.ZEROS_TO_PUSH; + case (char) 1: + return MessageType.INDICES_TO_PULL; + case ((char) 2): + return MessageType.VALUES_PULLED; + case ((char) 3): + return MessageType.KVS_TO_PUSH; + default: + throw new UnsupportedOperationException(); + } + } +} 
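 [Reviewer note, not part of the patch: a worked example of the wire format these messages share, using only classes added in this patch. Every message starts with a 2-byte char holding the MessageType, followed by message-specific fields written via Bits.] // Layout of an IndicesToPullM asking server 1 for keys {3, 4} on behalf of worker 0: // offset 0: char 1 (INDICES_TO_PULL) 2 bytes // offset 2: int serverId = 1 4 bytes // offset 6: int workerId = 0 4 bytes // offset 10: int array length = 2 4 bytes // offset 14: long 3, long 4 16 bytes -> 30 bytes in total byte[] bytes = new IndicesToPullM(1, 0, new long[] {3, 4}).toBytes(); IndicesToPullM decoded = IndicesToPullM.fromBytes(bytes); // round-trips losslessly 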
 diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java new file mode 100644 index 000000000..01c4fa4aa --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.util.Bits; + +/** Utility functions for processing messages. */ +public class MessageUtils { + + /** Retrieves the message type from the byte array. */ + public static MessageType getMessageType(byte[] bytesData) { + char type = Bits.getChar(bytesData, 0); + return MessageType.valueOf(type); + } + + /** Reads a long array from the byte array starting from the given offset. */ + public static long[] readLongArray(byte[] bytesData, int offset) { + int size = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + long[] result = new long[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getLong(bytesData, offset); + offset += Long.BYTES; + } + return result; + } + + /** + * Writes a long array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int writeLongArray(long[] array, byte[] bytesData, int offset) { + Bits.putInt(bytesData, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putLong(bytesData, offset, array[i]); + offset += Long.BYTES; + } + return offset; + } + + /** Returns the size of a long array in bytes. */ + public static int getLongArraySizeInBytes(long[] array) { + return Integer.BYTES + array.length * Long.BYTES; + } + + /** Reads a double array from the byte array starting from the given offset. */ + public static double[] readDoubleArray(byte[] bytesData, int offset) { + int size = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getDouble(bytesData, offset); + offset += Double.BYTES; + } + return result; + } + + /** + * Writes a double array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int writeDoubleArray(double[] array, byte[] bytesData, int offset) { + Bits.putInt(bytesData, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putDouble(bytesData, offset, array[i]); + offset += Double.BYTES; + } + return offset; + } + + /** Returns the size of a double array in bytes. */ + public static int getDoubleArraySizeInBytes(double[] array) { + return Integer.BYTES + array.length * Double.BYTES; + } + + /** Reads a long-double array from the byte array starting from the given offset. */ + public static Tuple2 readLongDoubleArray(byte[] bytesData, int offset) { + long[] indices = readLongArray(bytesData, offset); + offset += getLongArraySizeInBytes(indices); + double[] values = readDoubleArray(bytesData, offset); + return Tuple2.of(indices, values); + } + + /** + * Writes a long-double array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int writeLongDoubleArray( + Tuple2 longDoubleArray, byte[] bytesData, int offset) { + offset = writeLongArray(longDoubleArray.f0, bytesData, offset); + offset = writeDoubleArray(longDoubleArray.f1, bytesData, offset); + + return offset; + } + + /** Returns the size of a long-double array in bytes. 
 */ + public static int getLongDoubleArraySizeInBytes(Tuple2 longDoubleArray) { + return getLongArraySizeInBytes(longDoubleArray.f0) + + getDoubleArraySizeInBytes(longDoubleArray.f1); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java new file mode 100644 index 000000000..61bf6900e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +/** The values pulled from servers. */ +public class ValuesPulledM implements Message { + public final int serverId; + public final int workerId; + public final double[] valuesPulled; + public static final MessageType MESSAGE_TYPE = MessageType.VALUES_PULLED; + + public ValuesPulledM(int serverId, int workerId, double[] valuesPulled) { + this.serverId = serverId; + this.workerId = workerId; + this.valuesPulled = valuesPulled; + } + + public static ValuesPulledM fromBytes(byte[] bytesData) { + int offset = 0; + char type = Bits.getChar(bytesData, offset); + offset += Character.BYTES; + Preconditions.checkState(type == MESSAGE_TYPE.type); + + int psId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + int workerId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + double[] pulledValues = MessageUtils.readDoubleArray(bytesData, offset); + return new ValuesPulledM(psId, workerId, pulledValues); + } + + @Override + public byte[] toBytes() { + int numBytes = + Character.BYTES + + Integer.BYTES + + Integer.BYTES + + MessageUtils.getDoubleArraySizeInBytes(valuesPulled); + byte[] buffer = new byte[numBytes]; + int offset = 0; + Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + offset += Character.BYTES; + + Bits.putInt(buffer, offset, this.serverId); + offset += Integer.BYTES; + Bits.putInt(buffer, offset, this.workerId); + offset += Integer.BYTES; + MessageUtils.writeDoubleArray(valuesPulled, buffer, offset); + + return buffer; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java new file mode 100644 index 000000000..d226efad2 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +/** + * Message sent by worker to server that initializes the model as a dense array with defined range. + */ +public class ZerosToPushM implements Message { + public final int workerId; + public final int serverId; + public final long startIndex; + public final long endIndex; + + public static final MessageType MESSAGE_TYPE = MessageType.ZEROS_TO_PUSH; + + public ZerosToPushM(int workerId, int serverId, long startIndex, long endIndex) { + this.workerId = workerId; + this.serverId = serverId; + this.startIndex = startIndex; + this.endIndex = endIndex; + } + + public static ZerosToPushM fromBytes(byte[] bytesData) { + int offset = 0; + char type = Bits.getChar(bytesData, offset); + offset += Character.BYTES; + Preconditions.checkState(type == MESSAGE_TYPE.type); + + int workerId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + int serverId = Bits.getInt(bytesData, offset); + offset += Integer.BYTES; + long startIndex = Bits.getLong(bytesData, offset); + offset += Long.BYTES; + long endIndex = Bits.getLong(bytesData, offset); + return new ZerosToPushM(workerId, serverId, startIndex, endIndex); + } + + @Override + public byte[] toBytes() { + int numBytes = Character.BYTES + Integer.BYTES + Integer.BYTES + Long.BYTES + Long.BYTES; + byte[] buffer = new byte[numBytes]; + int offset = 0; + Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + offset += Character.BYTES; + + Bits.putInt(buffer, offset, this.workerId); + offset += Integer.BYTES; + Bits.putInt(buffer, offset, this.serverId); + offset += Integer.BYTES; + Bits.putLong(buffer, offset, this.startIndex); + offset += Long.BYTES; + Bits.putLong(buffer, offset, this.endIndex); + + return buffer; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java new file mode 100644 index 000000000..13c5909f3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import java.io.Serializable; + +/** + * Iterative machine learning training usually incurs a local computation step (e.g., computing + * gradients) and a global communication step (e.g., using all-reduce or parameter servers to + * aggregate the gradients). + * + *

                                To describe the above iterative training process, we model the training process as a + * sequence of iteration stages. An iteration stage could be either local computation or global + * communication. + */ +public interface IterationStage extends Serializable {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java new file mode 100644 index 000000000..f6e34095d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.util.function.SerializableFunction; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * A list of iteration stages to express the logic of an iterative machine learning training + * process. + */ +public class IterationStageList implements Serializable { + public final T context; + public Function shouldTerminate; + public List stageList; + + public IterationStageList(T context) { + this.stageList = new ArrayList<>(); + this.context = context; + } + + /** Sets the criteria of termination. */ + public void setTerminationCriteria(SerializableFunction shouldTerminate) { + this.shouldTerminate = shouldTerminate; + } + + /** Adds an iteration stage into the stage list. */ + public IterationStageList addTrainingStage(IterationStage stage) { + stageList.add(stage); + return this; + } +} 
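 [Reviewer note, not part of the patch: a minimal sketch of describing one training loop as compute-gradients -> push -> pull, terminating after 100 iterations. MyContext and ComputeGradients are hypothetical placeholders, not classes defined in this patch; the stage classes are the ones added below.] public static class MyContext implements TrainingContext { public int iterationId; public long[] keys; public double[] gradients; public double[] pulledWeights; @Override public void setIterationId(int iterationId) { this.iterationId = iterationId; } } MyContext context = new MyContext(); IterationStageList stages = new IterationStageList<>(context); stages.addTrainingStage(new ComputeGradients()) // a ProcessStage filling keys/gradients .addTrainingStage(new PushStage(() -> context.keys, () -> context.gradients)) .addTrainingStage( new PullStage(() -> context.keys, values -> context.pulledWeights = values)); stages.setTerminationCriteria(ctx -> ctx.iterationId >= 100); 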
 diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java new file mode 100644 index 000000000..2469b1eeb --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +/** + * A local computation stage of the training process. The input and output of a {@link + * ProcessStage} can be accessed via the {@link TrainingContext}. + * + * @param Type of the training context. + */ +public abstract class ProcessStage implements IterationStage { + /** + * Performs the local computation using the information from the given context, e.g., + * computing gradients. + */ + public abstract void process(T context) throws Exception; +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java new file mode 100644 index 000000000..585b23296 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.util.function.SerializableSupplier; + +import java.util.function.Consumer; + +/** + * A communication stage that pulls values from servers using the keys obtained from {@code + * keysSupplier} and hands the pulled values to {@code valuesConsumer}. + */ +public final class PullStage implements IterationStage { + public final SerializableSupplier keysSupplier; + public final Consumer valuesConsumer; + + public PullStage(SerializableSupplier keysSupplier, Consumer valuesConsumer) { + this.keysSupplier = keysSupplier; + this.valuesConsumer = valuesConsumer; + } +} 
 diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java new file mode 100644 index 000000000..952377935 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import java.util.function.Supplier; + +/** A communication stage that pushes (indices, values) to servers. */ +public class PushStage implements IterationStage { + public final Supplier keysSupplier; + public final Supplier valuesSupplier; + + public PushStage(Supplier keysSupplier, Supplier valuesSupplier) { + this.keysSupplier = keysSupplier; + this.valuesSupplier = valuesSupplier; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java new file mode 100644 index 000000000..ec09d3de3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableConsumer.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import java.io.Serializable; +import java.util.function.Consumer; + +/** + * A serializable {@link Consumer}. + * + * @param the type of results consumed by this consumer. + */ +public interface SerializableConsumer extends Consumer, Serializable {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java new file mode 100644 index 000000000..e12864b77 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; + +import java.io.Serializable; + +/** + * Stores the context information that is alive during the training process. Note that the context + * information will be updated by each {@link IterationStage}. + * + *
 + *

                                Note that subclasses should take care of snapshotting any object stored in the {@link + * TrainingContext} whose value is written before a {@link PullStage} and read again by later + * stages. + */ +public interface TrainingContext extends Serializable { + /** Sets the current iteration ID. */ + default void setIterationId(int iterationId) {} + + /** Sets the worker id and the total number of workers. */ + default void setWorldInfo(int workerId, int numWorkers) {} + + /** Sets the training data. */ + default void setTrainData(ResettableIterator trainData) {} + + /** Recovers the context from state. */ + default void initializeState(StateInitializationContext context) throws Exception {} + + /** Snapshots the context to state. */ + default void snapshotState(StateSnapshotContext context) throws Exception {} +} 
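 [Reviewer note, not part of the patch: a sketch of wiring everything together with the TrainingUtils.train utility defined below. It assumes `env` is a StreamExecutionEnvironment, `trainData` is an existing DataStream<LabeledLargePointWithWeight>, and `stages` is the IterationStageList sketched earlier; the FTRL hyper-parameters are arbitrary placeholders.] DataStream modelDim = env.fromElements(100L); // model dimension as a one-element stream DataStream> model = TrainingUtils.train( modelDim, trainData, new FTRL(0.1, 1.0, 0.001, 0.001), // alpha, beta, lambda1, lambda2 stages, 2, // numServers 4); // numServerCores 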
 diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java new file mode 100644 index 000000000..20b187358 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.ps.MirrorWorkerOperator; +import org.apache.flink.ml.common.ps.ServerOperator; +import org.apache.flink.ml.common.ps.WorkerOperator; +import org.apache.flink.ml.common.updater.ModelUpdater; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.util.OutputTag; + +/** Utility functions for describing the iterative training process. */ +public final class TrainingUtils { + + /** + * Executes the training logic described in {@link IterationStageList} and returns the fitted + * model data. + * + * @param modelDim dimension of the input model. + * @param trainData the training data. + * @param modelUpdater the logic to update the model on servers. + * @param iterationStages the iterative training logic. + * @param numServers number of server tasks. + * @param numServerCores number of threads each server uses to answer requests. + * @return the fitted model data. + */ + public static DataStream> train( + DataStream modelDim, + DataStream trainData, + ModelUpdater modelUpdater, + IterationStageList iterationStages, + int numServers, + int numServerCores) { + // TODO: Support different types for model data types. + // TODO: Support incremental training for multiple models. + // TODO: Support user defined model partitioner. + + DataStream variableStream = + modelDim.broadcast() + .map( + (MapFunction) + value -> { + byte[] buffer = new byte[Long.BYTES]; + Bits.putLong(buffer, 0, value); + return buffer; + }); + + DataStreamList resultList = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(variableStream), + ReplayableDataStreamList.notReplay( + trainData.rebalance().map(x -> x, trainData.getType())), + IterationConfig.newBuilder().build(), + new TrainIterationBody( + modelUpdater, iterationStages, numServers, numServerCores)); + + return resultList.get(0); + } + + /** The iteration implementation for the training process. */ + private static class TrainIterationBody implements IterationBody { + private final ModelUpdater modelUpdater; + private final IterationStageList iterationStages; + private final int numServers; + private final int numServerCores; + + public TrainIterationBody( + ModelUpdater modelUpdater, + IterationStageList iterationStages, + int numServers, + int numServerCores) { + this.iterationStages = iterationStages; + this.modelUpdater = modelUpdater; + this.numServers = numServers; + this.numServerCores = numServerCores; + } + + @Override + @SuppressWarnings("unchecked") + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream variableStream = variableStreams.get(0); + DataStream trainData = dataStreams.get(0); + final OutputTag> modelDataOutputTag = + new OutputTag>("MODEL_OUTPUT") {}; + + SingleOutputStreamOperator> messageToServer = + trainData + .connect(variableStream) + .transform( + "workerNode", + new TupleTypeInfo<>( + Types.INT, + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO), + new WorkerOperator(iterationStages, numServers)) + .name("WorkerOp"); + int numWorkers = messageToServer.getParallelism(); + + SingleOutputStreamOperator> messageToWorker = + messageToServer + .partitionCustom( + (Partitioner) + (key, numPartitions) -> key % numPartitions, + (KeySelector, Integer>) + value -> value.f0) + .transform( + "ServerOp", + new TupleTypeInfo<>( + Types.INT, + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO), + new ServerOperator( + modelUpdater, modelDataOutputTag, numServerCores)); + messageToWorker.setParallelism(numServers); + + DataStream combinedMessageToWorker = + messageToWorker + .partitionCustom( + (Partitioner) + (key, numPartitions) -> key % numPartitions, + (KeySelector, Integer>) + value -> value.f0) + .transform( + "MirrorWorkerOp", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + new MirrorWorkerOperator(numServers)) + .setParallelism(numWorkers); + + return new IterationBodyResult( + DataStreamList.of(combinedMessageToWorker), + DataStreamList.of(messageToWorker.getSideOutput(modelDataOutputTag)), + null); + } + } +} diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java new file mode 100644 index 000000000..ad4280cca --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.updater; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** The FTRL model updater. */ +public class FTRL implements ModelUpdater { + private final double alpha; + private final double beta; + private final double lambda1; + private final double lambda2; + + // ------ Model data of FTRL optimizer. 
----- + private long startIndex; + private long endIndex; + private double[] weight; + private double[] sigma; + private double[] z; + private double[] n; + + private ListState boundaryState; + private ListState modelDataState; + + public FTRL(double alpha, double beta, double lambda1, double lambda2) { + this.alpha = alpha; + this.beta = beta; + this.lambda1 = lambda1; + this.lambda2 = lambda2; + } + + @Override + public void open(long startFeatureIndex, long endFeatureIndex) { + this.startIndex = startFeatureIndex; + this.endIndex = endFeatureIndex; + int modelShardSize = (int) (endIndex - startIndex); + weight = new double[modelShardSize]; + sigma = new double[modelShardSize]; + z = new double[modelShardSize]; + n = new double[modelShardSize]; + } + + @Override + public void handlePush(long[] keys, double[] values) { + for (int i = 0; i < keys.length; i++) { + int index = (int) (keys[i] - startIndex); + double gi = values[i]; + updateModelOnOneDim(gi, index, weight); + } + } + + private void updateModelOnOneDim(double gi, int index, double[] weight) { + double gigi = gi * gi; + sigma[index] = 1 / alpha * (Math.sqrt(n[index] + gigi) - Math.sqrt(n[index])); + z[index] += gi - sigma[index] * weight[index]; + n[index] += gigi; + + if (Math.abs(z[index]) <= lambda1) { + weight[index] = 0; + } else { + weight[index] = + -(z[index] - Math.signum(z[index]) * lambda1) + / ((beta + Math.sqrt(n[index])) / alpha + lambda2); + } + } + + @Override + public double[] handlePull(long[] keys) { + double[] values = new double[keys.length]; + for (int i = 0; i < keys.length; i++) { + values[i] = weight[(int) (keys[i] - startIndex)]; + } + return values; + } + + @Override + public Iterator> getModelPieces() { + List> modelPieces = new ArrayList<>(); + modelPieces.add(Tuple3.of(startIndex, endIndex, weight)); + return modelPieces.iterator(); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + boundaryState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("BoundaryState", Types.LONG)); + + Iterator iterator = boundaryState.get().iterator(); + if (iterator.hasNext()) { + startIndex = iterator.next(); + endIndex = iterator.next(); + } + + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "modelDataState", + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); + Iterator modelData = modelDataState.get().iterator(); + if (modelData.hasNext()) { + weight = modelData.next(); + sigma = modelData.next(); + z = modelData.next(); + n = modelData.next(); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + if (weight != null) { + boundaryState.clear(); + boundaryState.add(startIndex); + boundaryState.add(endIndex); + + modelDataState.clear(); + modelDataState.add(weight); + modelDataState.add(sigma); + modelDataState.add(z); + modelDataState.add(n); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java new file mode 100644 index 000000000..fc4c4af8f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that can be used to handle push/pull requests from workers.
+ *
+ * <p>Note that the model updater should also ensure that the model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+    /** Initializes the model data. */
+    void open(long startFeatureIndex, long endFeatureIndex);
+
+    /** Applies the push to update the model data, e.g., using gradients to update the model. */
+    void handlePush(long[] keys, double[] values);
+
+    /** Applies the pull and returns the retrieved model data. */
+    double[] handlePull(long[] keys);
+
+    /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */
+    Iterator> getModelPieces();
+
+    /** Recovers the model data from state. */
+    void initializeState(StateInitializationContext context) throws Exception;
+
+    /** Snapshots the model data to state. */
+    void snapshotState(StateSnapshotContext context) throws Exception;
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java
new file mode 100644
index 000000000..93e9a9a14
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java
@@ -0,0 +1,394 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionWithFtrl; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.servable.api.DataFrame; +import org.apache.flink.ml.servable.types.BasicType; +import org.apache.flink.ml.servable.types.DataTypes; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.ByteArrayInputStream; +import java.io.SequenceInputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.flink.ml.util.TestUtils.saveAndLoadServable; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link LogisticRegressionWithFtrl}. 
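+ *
+ * <p>Most tests below share one fit/transform pattern; a sketch (the table setup lives in
+ * {@code before()}):
+ *
+ * <pre>{@code
+ * LogisticRegressionWithFtrl lr = new LogisticRegressionWithFtrl().setNumServers(2);
+ * LogisticRegressionModel model = lr.fit(trainTable);
+ * Table output = model.transform(testTable)[0];
+ * }</pre>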
*/ +public class LogisticRegressionWithFtrlTest { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private final double[] expectedCoefficient = + new double[] {0.3140991, -0.6776634, -0.5825635, -0.4035519}; + + private static final double TOLERANCE = 1e-7; + + private static final List trainRows = + Arrays.asList( + Row.of(Tuple2.of(new long[] {0, 1}, new double[] {1, 2}), 0., 1.), + Row.of(Tuple2.of(new long[] {0, 2}, new double[] {2, 3}), 0., 2.), + Row.of(Tuple2.of(new long[] {0, 3}, new double[] {3, 4}), 0., 3.), + Row.of(Tuple2.of(new long[] {0, 2}, new double[] {4, 4}), 0., 4.), + Row.of(Tuple2.of(new long[] {0, 1}, new double[] {5, 4}), 0., 5.), + Row.of(Tuple2.of(new long[] {0, 2}, new double[] {11, 3}), 1., 1.), + Row.of(Tuple2.of(new long[] {0, 3}, new double[] {12, 4}), 1., 2.), + Row.of(Tuple2.of(new long[] {0, 1}, new double[] {13, 2}), 1., 3.), + Row.of(Tuple2.of(new long[] {0, 3}, new double[] {14, 4}), 1., 4.), + Row.of(Tuple2.of(new long[] {0, 2}, new double[] {15, 4}), 1., 5.)); + + private static final List testRows = + Arrays.asList( + Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {1, 2}), 0., 1.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {2, 3}), 0., 2.), + Row.of(Vectors.sparse(4, new int[] {0, 3}, new double[] {3, 4}), 0., 3.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {4, 4}), 0., 4.), + Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {5, 4}), 0., 5.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {11, 3}), 1., 1.), + Row.of(Vectors.sparse(4, new int[] {0, 3}, new double[] {12, 4}), 1., 2.), + Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {13, 2}), 1., 3.), + Row.of(Vectors.sparse(4, new int[] {0, 3}, new double[] {14, 4}), 1., 4.), + Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {15, 4}), 1., 5.)); + + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table testTable; + private DataFrame testDataFrame; + + @Before + public void before() { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + trainTable = + tEnv.fromDataStream( + env.fromCollection( + trainRows, + new RowTypeInfo( + new TypeInformation[] { + new TupleTypeInfo<>( + PrimitiveArrayTypeInfo + .LONG_PRIMITIVE_ARRAY_TYPE_INFO, + PrimitiveArrayTypeInfo + .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + testTable = + tEnv.fromDataStream( + env.fromCollection( + testRows, + new RowTypeInfo( + new TypeInformation[] { + SparseVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + + testDataFrame = + TestUtils.constructDataFrame( + new ArrayList<>(Arrays.asList("features", "label", "weight")), + new ArrayList<>( + Arrays.asList( + DataTypes.VECTOR(BasicType.DOUBLE), + DataTypes.DOUBLE, + DataTypes.DOUBLE)), + testRows); + } + + @Test + public void testParam() { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = new LogisticRegressionWithFtrl(); + assertEquals("features", logisticRegressionWithFtrl.getFeaturesCol()); + assertEquals("label", logisticRegressionWithFtrl.getLabelCol()); + assertNull(logisticRegressionWithFtrl.getWeightCol()); + assertEquals(20, logisticRegressionWithFtrl.getMaxIter()); + assertEquals(1e-6, logisticRegressionWithFtrl.getTol(), TOLERANCE); + assertEquals(32, logisticRegressionWithFtrl.getGlobalBatchSize()); + assertEquals(0, 
logisticRegressionWithFtrl.getReg(), TOLERANCE); + assertEquals(0, logisticRegressionWithFtrl.getElasticNet(), TOLERANCE); + assertEquals("auto", logisticRegressionWithFtrl.getMultiClass()); + assertEquals("prediction", logisticRegressionWithFtrl.getPredictionCol()); + assertEquals("rawPrediction", logisticRegressionWithFtrl.getRawPredictionCol()); + + assertEquals(0.1, logisticRegressionWithFtrl.getAlpha(), TOLERANCE); + assertEquals(0.1, logisticRegressionWithFtrl.getBeta(), TOLERANCE); + assertEquals(0L, logisticRegressionWithFtrl.getModelDim()); + assertEquals(1, logisticRegressionWithFtrl.getNumServers()); + assertEquals(1, logisticRegressionWithFtrl.getNumServerCores()); + + logisticRegressionWithFtrl + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setMaxIter(1000) + .setTol(0.001) + .setGlobalBatchSize(1000) + .setReg(0.1) + .setElasticNet(0.5) + .setMultiClass("binomial") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol") + .setAlpha(0.2) + .setBeta(0.2) + .setModelDim(10000000L) + .setNumServers(4) + .setNumServerCores(2); + assertEquals("test_features", logisticRegressionWithFtrl.getFeaturesCol()); + assertEquals("test_label", logisticRegressionWithFtrl.getLabelCol()); + assertEquals("test_weight", logisticRegressionWithFtrl.getWeightCol()); + assertEquals(1000, logisticRegressionWithFtrl.getMaxIter()); + assertEquals(0.001, logisticRegressionWithFtrl.getTol(), TOLERANCE); + assertEquals(1000, logisticRegressionWithFtrl.getGlobalBatchSize()); + assertEquals(0.1, logisticRegressionWithFtrl.getReg(), TOLERANCE); + assertEquals(0.5, logisticRegressionWithFtrl.getElasticNet(), TOLERANCE); + assertEquals("binomial", logisticRegressionWithFtrl.getMultiClass()); + assertEquals("test_predictionCol", logisticRegressionWithFtrl.getPredictionCol()); + assertEquals("test_rawPredictionCol", logisticRegressionWithFtrl.getRawPredictionCol()); + + assertEquals(0.2, logisticRegressionWithFtrl.getAlpha(), TOLERANCE); + assertEquals(0.2, logisticRegressionWithFtrl.getBeta(), TOLERANCE); + assertEquals(10000000L, logisticRegressionWithFtrl.getModelDim()); + assertEquals(4, logisticRegressionWithFtrl.getNumServers()); + assertEquals(2, logisticRegressionWithFtrl.getNumServerCores()); + } + + @Test + public void testOutputSchema() { + Table tempTable = trainTable.as("test_features", "test_label", "test_weight"); + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl() + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + Table output = logisticRegressionWithFtrl.fit(trainTable).transform(tempTable)[0]; + assertEquals( + Arrays.asList( + "test_features", + "test_label", + "test_weight", + "test_predictionCol", + "test_rawPredictionCol"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + @SuppressWarnings("unchecked") + public void testGetModelData() throws Exception { + int numServers = 2; + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setNumServers(numServers).setNumServerCores(1); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + List modelData = + IteratorUtils.toList( + LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + + assertEquals(numServers, modelData.size()); + + 
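// One model piece is produced per server, so the pieces are sorted by startIndex before
+        // the full coefficient vector is stitched back together.
+        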
modelData.sort(Comparator.comparingLong(o -> o.startIndex)); + + double[] collectedCoefficient = new double[4]; + for (LogisticRegressionModelData modelPiece : modelData) { + int startIndex = (int) modelPiece.startIndex; + double[] pieceCoeff = modelPiece.coefficient.values; + System.arraycopy(pieceCoeff, 0, collectedCoefficient, startIndex, pieceCoeff.length); + } + assertArrayEquals(expectedCoefficient, collectedCoefficient, 1e-7); + } + + @Test + public void testFitAndPredict() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setNumServers(2); + Table output = logisticRegressionWithFtrl.fit(trainTable).transform(testTable)[0]; + verifyPredictionResult( + output, + logisticRegressionWithFtrl.getFeaturesCol(), + logisticRegressionWithFtrl.getPredictionCol(), + logisticRegressionWithFtrl.getRawPredictionCol()); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setNumServers(2); + logisticRegressionWithFtrl = + TestUtils.saveAndReload( + tEnv, + logisticRegressionWithFtrl, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionWithFtrl::load); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + model = + TestUtils.saveAndReload( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionModel::load); + assertEquals( + Arrays.asList("coefficient", "startIndex", "endIndex", "modelVersion"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = model.transform(testTable)[0]; + verifyPredictionResult( + output, + logisticRegressionWithFtrl.getFeaturesCol(), + logisticRegressionWithFtrl.getPredictionCol(), + logisticRegressionWithFtrl.getRawPredictionCol()); + } + + @Test + public void testSetModelData() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setNumServers(2); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + + LogisticRegressionModel newModel = new LogisticRegressionModel(); + ParamUtils.updateExistingParams(newModel, model.getParamMap()); + newModel.setModelData(model.getModelData()); + Table output = newModel.transform(testTable)[0]; + verifyPredictionResult( + output, + logisticRegressionWithFtrl.getFeaturesCol(), + logisticRegressionWithFtrl.getPredictionCol(), + logisticRegressionWithFtrl.getRawPredictionCol()); + } + + @Test + public void testSaveLoadServableAndPredict() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setNumServers(2); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + + LogisticRegressionModelServable servable = + saveAndLoadServable( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionModel::loadServable); + + DataFrame output = servable.transform(testDataFrame); + verifyPredictionResult( + output, + servable.getFeaturesCol(), + servable.getPredictionCol(), + servable.getRawPredictionCol()); + } + + @Test + public void testSetModelDataToServable() throws Exception { + LogisticRegressionWithFtrl logisticRegressionWithFtrl = + new LogisticRegressionWithFtrl().setNumServers(2); + LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); + List serializedModelData = + IteratorUtils.toList( + LogisticRegressionModelDataUtil.getModelDataByteStream( + model.getModelData()[0]) + 
.executeAndCollect()); + + LogisticRegressionModelServable servable = new LogisticRegressionModelServable(); + ParamUtils.updateExistingParams(servable, model.getParamMap()); + + List modelStreams = + serializedModelData.stream() + .map(ByteArrayInputStream::new) + .collect(Collectors.toList()); + servable.setModelData(new SequenceInputStream(Collections.enumeration(modelStreams))); + DataFrame output = servable.transform(testDataFrame); + verifyPredictionResult( + output, + servable.getFeaturesCol(), + servable.getPredictionCol(), + servable.getRawPredictionCol()); + } + + private void verifyPredictionResult( + Table output, String featuresCol, String predictionCol, String rawPredictionCol) + throws Exception { + List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predResult) { + DenseVector feature = ((Vector) predictionRow.getField(featuresCol)).toDense(); + double prediction = (double) predictionRow.getField(predictionCol); + DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } + + private void verifyPredictionResult( + DataFrame output, String featuresCol, String predictionCol, String rawPredictionCol) { + int featuresColIndex = output.getIndex(featuresCol); + int predictionColIndex = output.getIndex(predictionCol); + int rawPredictionColIndex = output.getIndex(rawPredictionCol); + + for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) { + DenseVector feature = ((Vector) predictionRow.get(featuresColIndex)).toDense(); + double prediction = (double) predictionRow.get(predictionColIndex); + DenseVector rawPrediction = (DenseVector) predictionRow.get(rawPredictionColIndex); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } +} diff --git a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py index db59df0b7..2ce8786a9 100644 --- a/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py +++ b/flink-ml-python/pyflink/ml/tests/test_ml_lib_completeness.py @@ -103,6 +103,12 @@ def module(self): from pyflink.ml import classification return classification + def exclude_java_stage(self) -> List[str]: + # TODO: Add python support for LogisticRegressionWithFtrl. + return [ + "logisticregression.LogisticRegressionWithFtrl", + ] + class ClusteringCompletenessTest(CompletenessTest, MLLibTest): diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java new file mode 100644 index 000000000..43ad621e5 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.feature;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/** A labeled data point with weight whose features use longs as indices and doubles as values. */
+public class LabeledLargePointWithWeight {
+    public Tuple2 features;
+
+    public double label;
+
+    public double weight;
+
+    public LabeledLargePointWithWeight(
+            Tuple2 features, double label, double weight) {
+        this.features = features;
+        this.label = label;
+        this.weight = weight;
+    }
+
+    /** The no-arg constructor that makes this class a POJO, so Flink's serializer can be used. */
+    public LabeledLargePointWithWeight() {}
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
index 8de3a44d4..f28231704 100644
--- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java
@@ -44,6 +44,17 @@ public static double getDouble(byte[] b, int off) {
         return Double.longBitsToDouble(getLong(b, off));
     }
 
+    public static int getInt(byte[] b, int off) {
+        return ((b[off + 3] & 0xFF))
+                + ((b[off + 2] & 0xFF) << 8)
+                + ((b[off + 1] & 0xFF) << 16)
+                + ((b[off]) << 24);
+    }
+
+    public static char getChar(byte[] b, int off) {
+        return (char) ((b[off + 1] & 0xFF) + (b[off] << 8));
+    }
+
     /*
      * Methods for packing primitive values into byte arrays starting at given
      * offsets.
@@ -63,4 +74,16 @@ public static void putLong(byte[] b, int off, long val) {
     public static void putDouble(byte[] b, int off, double val) {
         putLong(b, off, Double.doubleToLongBits(val));
     }
+
+    public static void putInt(byte[] b, int off, int val) {
+        b[off + 3] = (byte) (val);
+        b[off + 2] = (byte) (val >>> 8);
+        b[off + 1] = (byte) (val >>> 16);
+        b[off] = (byte) (val >>> 24);
+    }
+
+    public static void putChar(byte[] b, int off, char val) {
+        b[off + 1] = (byte) (val);
+        b[off] = (byte) (val >>> 8);
+    }
 }
diff --git a/flink-ml-uber/pom.xml b/flink-ml-uber/pom.xml
index f56b6faf3..5cded2fee 100644
--- a/flink-ml-uber/pom.xml
+++ b/flink-ml-uber/pom.xml
@@ -94,6 +94,7 @@ under the License.
org.apache.flink:flink-ml-lib org.apache.flink:flink-ml-benchmark dev.ludovic.netlib:blas + fastutil:fastutil From e70c3fe46173a6ac86d2d165ef8bb7b5f93f3a86 Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Mon, 29 May 2023 10:13:34 +0800 Subject: [PATCH 05/18] Average the gradient from workers --- .../LogisticRegressionWithFtrl.java | 10 ++++++- .../flink/ml/common/ps/RangePartitioner.java | 2 +- .../apache/flink/ml/common/updater/FTRL.java | 7 +++-- .../LogisticRegressionWithFtrlTest.java | 26 +++++++++++-------- 4 files changed, 30 insertions(+), 15 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java index 7a56d9edc..f8c0c4343 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -38,6 +38,7 @@ import org.apache.flink.ml.common.ps.training.TrainingContext; import org.apache.flink.ml.common.ps.training.TrainingUtils; import org.apache.flink.ml.common.updater.FTRL; +import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -150,7 +151,13 @@ public LogisticRegressionModel fit(Table... inputs) { .setTerminationCriteria( (SerializableFunction) o -> o.iterationId >= getMaxIter()); - FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); + FTRL ftrl = + new FTRL( + getAlpha(), + getBeta(), + getReg(), + getElasticNet(), + trainData.getParallelism()); DataStream> rawModelData = TrainingUtils.train( @@ -274,6 +281,7 @@ private double[] computeGradient( for (int i = 0; i < sortedBatchIndices.length; i++) { cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]); } + BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues)); return cumGradientValues; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java index d16ef9896..f550fa7f2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java @@ -75,7 +75,7 @@ private static class RequestsIterator implements Iterator boundaryState; private ListState modelDataState; - public FTRL(double alpha, double beta, double lambda1, double lambda2) { + public FTRL(double alpha, double beta, double lambda1, double lambda2, int numWorkers) { this.alpha = alpha; this.beta = beta; this.lambda1 = lambda1; this.lambda2 = lambda2; + this.numWorkers = numWorkers; } @Override @@ -70,7 +73,7 @@ public void open(long startFeatureIndex, long endFeatureIndex) { public void handlePush(long[] keys, double[] values) { for (int i = 0; i < keys.length; i++) { int index = (int) (keys[i] - startIndex); - double gi = values[i]; + double gi = values[i] / numWorkers; updateModelOnOneDim(gi, index, weight); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java index 93e9a9a14..1fca23dda 100644 --- 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -69,8 +69,9 @@ public class LogisticRegressionWithFtrlTest { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); private final double[] expectedCoefficient = - new double[] {0.3140991, -0.6776634, -0.5825635, -0.4035519}; - + new double[] {0.5287258, -1.2163098, -1.0710997, -0.8591691}; + private static final int MAX_ITER = 100; + private static final int NUM_SERVERS = 2; private static final double TOLERANCE = 1e-7; private static final List trainRows = @@ -100,14 +101,16 @@ public class LogisticRegressionWithFtrlTest { Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {15, 4}), 1., 5.)); private StreamTableEnvironment tEnv; + private StreamExecutionEnvironment env; private Table trainTable; private Table testTable; private DataFrame testDataFrame; @Before public void before() { - StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + env = TestUtils.getExecutionEnvironment(); tEnv = StreamTableEnvironment.create(env); + trainTable = tEnv.fromDataStream( env.fromCollection( @@ -227,16 +230,17 @@ public void testOutputSchema() { @Test @SuppressWarnings("unchecked") public void testGetModelData() throws Exception { - int numServers = 2; + // Fix the parallelism as one for stability tests. + env.setParallelism(1); LogisticRegressionWithFtrl logisticRegressionWithFtrl = - new LogisticRegressionWithFtrl().setNumServers(numServers).setNumServerCores(1); + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); List modelData = IteratorUtils.toList( LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); - assertEquals(numServers, modelData.size()); + assertEquals(NUM_SERVERS, modelData.size()); modelData.sort(Comparator.comparingLong(o -> o.startIndex)); @@ -252,7 +256,7 @@ public void testGetModelData() throws Exception { @Test public void testFitAndPredict() throws Exception { LogisticRegressionWithFtrl logisticRegressionWithFtrl = - new LogisticRegressionWithFtrl().setNumServers(2); + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); Table output = logisticRegressionWithFtrl.fit(trainTable).transform(testTable)[0]; verifyPredictionResult( output, @@ -264,7 +268,7 @@ public void testFitAndPredict() throws Exception { @Test public void testSaveLoadAndPredict() throws Exception { LogisticRegressionWithFtrl logisticRegressionWithFtrl = - new LogisticRegressionWithFtrl().setNumServers(2); + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); logisticRegressionWithFtrl = TestUtils.saveAndReload( tEnv, @@ -292,7 +296,7 @@ public void testSaveLoadAndPredict() throws Exception { @Test public void testSetModelData() throws Exception { LogisticRegressionWithFtrl logisticRegressionWithFtrl = - new LogisticRegressionWithFtrl().setNumServers(2); + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); LogisticRegressionModel newModel = new LogisticRegressionModel(); @@ -309,7 +313,7 @@ public void testSetModelData() throws Exception { @Test public void testSaveLoadServableAndPredict() throws Exception { LogisticRegressionWithFtrl logisticRegressionWithFtrl 
= - new LogisticRegressionWithFtrl().setNumServers(2); + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); LogisticRegressionModelServable servable = @@ -330,7 +334,7 @@ public void testSaveLoadServableAndPredict() throws Exception { @Test public void testSetModelDataToServable() throws Exception { LogisticRegressionWithFtrl logisticRegressionWithFtrl = - new LogisticRegressionWithFtrl().setNumServers(2); + new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS); LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable); List serializedModelData = IteratorUtils.toList( From 396632133c5e67478448bc23a3fb3977d01ab082 Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Mon, 29 May 2023 18:06:53 +0800 Subject: [PATCH 06/18] Support pull/push value as array --- .../LogisticRegressionWithFtrl.java | 20 +- .../LogisticRegressionWithFtrlParams.java | 15 - .../flink/ml/common/ps/RangePartitioner.java | 21 +- .../flink/ml/common/ps/ServerOperator.java | 269 ++++++++++-------- .../flink/ml/common/ps/WorkerOperator.java | 2 +- ...PushM.java => InitializeModelAsZeroM.java} | 20 +- .../{IndicesToPullM.java => PullIndexM.java} | 18 +- .../{ValuesPulledM.java => PulledValueM.java} | 17 +- .../message/{KVsToPushM.java => PushKvM.java} | 17 +- .../ps/training/IterationStageList.java | 2 +- .../{TrainingContext.java => MLSession.java} | 2 +- .../ml/common/ps/training/TrainingUtils.java | 17 +- .../ml/common/{ => ps}/updater/FTRL.java | 7 +- .../common/{ => ps}/updater/ModelUpdater.java | 4 +- .../LogisticRegressionWithFtrlTest.java | 10 +- 15 files changed, 234 insertions(+), 207 deletions(-) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/{ZerosToPushM.java => InitializeModelAsZeroM.java} (80%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/{IndicesToPullM.java => PullIndexM.java} (81%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/{ValuesPulledM.java => PulledValueM.java} (83%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/{KVsToPushM.java => PushKvM.java} (83%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/{TrainingContext.java => MLSession.java} (96%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/{ => ps}/updater/FTRL.java (95%) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/{ => ps}/updater/ModelUpdater.java (95%) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java index f8c0c4343..de81651ff 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -137,14 +137,14 @@ public LogisticRegressionModel fit(Table... 
inputs) { IterationStageList iterationStages = new IterationStageList<>(trainingContext); iterationStages - .addTrainingStage(new ComputeIndices()) - .addTrainingStage( + .addStage(new ComputeIndices()) + .addStage( new PullStage( (SerializableSupplier) () -> trainingContext.pullIndices, (SerializableConsumer) x -> trainingContext.pulledValues = x)) - .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) - .addTrainingStage( + .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) + .addStage( new PushStage( (SerializableSupplier) () -> trainingContext.pushIndices, (SerializableSupplier) () -> trainingContext.pushValues)) @@ -160,13 +160,7 @@ public LogisticRegressionModel fit(Table... inputs) { trainData.getParallelism()); DataStream> rawModelData = - TrainingUtils.train( - modelDim, - trainData, - ftrl, - iterationStages, - getNumServers(), - getNumServerCores()); + TrainingUtils.train(modelDim, trainData, ftrl, iterationStages, getNumServers()); final long modelVersion = 0L; @@ -341,8 +335,8 @@ public void setWorldInfo(int workerId, int numWorkers) { } @Override - public void setTrainData(ResettableIterator trainData) { - this.trainData = (ResettableIterator) trainData; + public void setInputData(ResettableIterator inputData) { + this.trainData = (ResettableIterator) inputData; } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java index be00aeab4..4b2bd72a9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrlParams.java @@ -51,13 +51,6 @@ public interface LogisticRegressionWithFtrlParams 1, ParamValidators.gtEq(1)); - Param NUM_SERVER_CORES = - new IntParam( - "numServerCores", - "number of cores that a server can use.", - 1, - ParamValidators.gtEq(1)); - Param ALPHA = new DoubleParam( "alpha", @@ -81,14 +74,6 @@ default T setNumServers(Integer value) { return set(NUM_SERVERS, value); } - default int getNumServerCores() { - return get(NUM_SERVER_CORES); - } - - default T setNumServerCores(int value) { - return set(NUM_SERVER_CORES, value); - } - default double getAlpha() { return get(ALPHA); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java index f550fa7f2..79fd24a72 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java @@ -67,6 +67,12 @@ private static class RequestsIterator implements Iterator next() { double[] splitValues = values == null ? null : new double[0]; if (s < e) { splitIndices = Arrays.copyOfRange(indices, s, e); - splitValues = values == null ? null : Arrays.copyOfRange(values, s, e); + splitValues = + values == null + ? 
null + : Arrays.copyOfRange( + values, s * numValuesPerKey, e * numValuesPerKey); } s = e; serverId++; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index cbdcc1d90..90fed895e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -44,11 +44,13 @@ import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; +import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -60,7 +62,7 @@ * *

    *
• The server operator deals with the message from workers and decide when to process the
- *       received message. (i.e., synchronous vs. asynchronous).
+ *       received message.
  • The server operator calls {@link ModelUpdater#handlePush(long[], double[])} and {@link * ModelUpdater#handlePull(long[])} to process the messages in detail. *
  • The server operator ensures that {@link ModelUpdater} is robust to failures. @@ -79,55 +81,59 @@ public class ServerOperator extends AbstractStreamOperator> modelOutputTag; - + /** Index of the server task. */ private int serverId = -1; - /** - * Lock for output records to downstream operators. Note that we use multiple threads to deal - * with push/pull requests for better performance. + * Thread pool to answer push/pull requests, to decouple the network traffic and computation + * logic. */ - private final SerializableObject lock = new SerializableObject(); - /** Number of threads to answer push/pull requests. */ - private final int numServerCores; - /** Thread pool to answer push/pull requests. */ - private transient ExecutorService fixedThreadPool; + private transient ExecutorService singleThreadExecutor; /** The future objects of thread calls in one epoch. */ private final List> futuresInEpoch = new ArrayList<>(); - /** The accumulated push request from workers by threadId. */ - private final ConcurrentHashMap accumulatedKvsByThreadId; - /** The accumulated results of Kvs. */ - private final Long2DoubleOpenHashMap accumulatedKvs; - /** The state for accumulated Kvs. */ - private ListState accumulatedKvsState; + /** The merger for push requests. */ + private final PushRequestMerger pushRequestMerger; /** The pending pull requests. */ private ListState pendingPulls; public ServerOperator( - ModelUpdater modelUpdater, - OutputTag> modelOutputTag, - int numServerCores) { + ModelUpdater modelUpdater, OutputTag> modelOutputTag) { this.modelUpdater = modelUpdater; this.modelOutputTag = modelOutputTag; - this.numServerCores = numServerCores; - this.accumulatedKvsByThreadId = new ConcurrentHashMap<>(); - this.accumulatedKvs = new Long2DoubleOpenHashMap(); + this.pushRequestMerger = new PushRequestMerger(); } @Override public void open() throws Exception { super.open(); - serverId = getRuntimeContext().getIndexOfThisSubtask(); - fixedThreadPool = Executors.newFixedThreadPool(numServerCores); + this.serverId = getRuntimeContext().getIndexOfThisSubtask(); + this.singleThreadExecutor = Executors.newSingleThreadExecutor(); } @Override public void processElement(StreamRecord> element) throws Exception { byte[] request = element.getValue().f1; MessageType type = MessageUtils.getMessageType(request); - if (type == MessageType.INDICES_TO_PULL) { - pendingPulls.add(request); - } else { - processPushRequest(request); + switch (type) { + case INDICES_TO_PULL: + pendingPulls.add(request); + break; + case ZEROS_TO_PUSH: + ZerosToPushM zerosToPush = ZerosToPushM.fromBytes(request); + Preconditions.checkState(serverId == zerosToPush.serverId); + + long start = zerosToPush.startIndex; + long end = zerosToPush.endIndex; + if (zerosToPush.workerId == 0) { + modelUpdater.open(start, end); + } + break; + case KVS_TO_PUSH: + futuresInEpoch.add( + singleThreadExecutor.submit( + () -> pushRequestMerger.processPushRequest(request))); + break; + default: + throw new UnsupportedOperationException("Unsupported message type: " + type + "."); } } @@ -141,12 +147,11 @@ public void onEpochWatermarkIncremented( } futuresInEpoch.clear(); - Iterator kvsFromAllThreads = - accumulatedKvsByThreadId.values().iterator(); - if (kvsFromAllThreads.hasNext()) { - Tuple2 kvs = mergeKvsFromAllThreads(kvsFromAllThreads); + if (epochWatermark > 0) { + Tuple2 kvs = pushRequestMerger.toKvArrays(); + pushRequestMerger.accumulatedKvsForMatrix.clear(); + pushRequestMerger.accumulatedKvsForVector.clear(); modelUpdater.handlePush(kvs.f0, 
kvs.f1); - accumulatedKvs.clear(); } Iterator pullsIterator = pendingPulls.get().iterator(); @@ -154,7 +159,7 @@ public void onEpochWatermarkIncremented( // The last iteration contains no pulls. while (pullsIterator.hasNext()) { byte[] pull = pullsIterator.next(); - futuresInEpoch.add(fixedThreadPool.submit(() -> processPullRequest(pull))); + futuresInEpoch.add(singleThreadExecutor.submit(() -> processPullRequest(pull))); } } for (Future future : futuresInEpoch) { @@ -184,25 +189,7 @@ public void initializeState(StateInitializationContext context) throws Exception "pendingPulls", PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); modelUpdater.initializeState(context); - - accumulatedKvsState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "accumulatedKvs", - PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); - - byte[] accumulatedKvsInBytes = - OperatorStateUtils.getUniqueElement(accumulatedKvsState, "accumulatedKvs") - .orElse(null); - if (accumulatedKvsInBytes != null) { - Tuple2 kvs = - MessageUtils.readLongDoubleArray(accumulatedKvsInBytes, 0); - accumulatedKvs.clear(); - for (int i = 0; i < kvs.f0.length; i++) { - accumulatedKvs.put(kvs.f0[i], kvs.f1[i]); - } - } + pushRequestMerger.initializeState(context); } @Override @@ -215,51 +202,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } futuresInEpoch.clear(); modelUpdater.snapshotState(context); - - // Snapshots the pending pushes. - Tuple2 kvs = - mergeKvsFromAllThreads(accumulatedKvsByThreadId.values().iterator()); - accumulatedKvsState.clear(); - if (kvs.f0.length > 0) { - byte[] bytes = new byte[MessageUtils.getLongDoubleArraySizeInBytes(kvs)]; - MessageUtils.writeLongDoubleArray(kvs, bytes, 0); - accumulatedKvsState.add(bytes); - } - } - - private void processPushRequest(byte[] pushRpc) { - MessageType type = MessageUtils.getMessageType(pushRpc); - if (type == MessageType.ZEROS_TO_PUSH) { - ZerosToPushM zerosToPush = ZerosToPushM.fromBytes(pushRpc); - Preconditions.checkState(serverId == zerosToPush.serverId); - - long start = zerosToPush.startIndex; - long end = zerosToPush.endIndex; - if (zerosToPush.workerId == 0) { - modelUpdater.open(start, end); - } - } else if (type == MessageType.KVS_TO_PUSH) { - futuresInEpoch.add(fixedThreadPool.submit(() -> processPushedKvs(pushRpc))); - } else { - throw new UnsupportedOperationException("Unsupported message type: " + type + "."); - } - } - - private Object processPushedKvs(byte[] pushKv) { - KVsToPushM kvsToPush = KVsToPushM.fromBytes(pushKv); - Preconditions.checkState(kvsToPush.serverId == serverId); - long threadId = Thread.currentThread().getId(); - accumulatedKvsByThreadId.putIfAbsent(threadId, new Long2DoubleOpenHashMap()); - Long2DoubleOpenHashMap tmpGrad = accumulatedKvsByThreadId.get(threadId); - - Tuple2 pushedGrad = kvsToPush.kvs; - long[] indices = pushedGrad.f0; - double[] values = pushedGrad.f1; - for (int i = 0; i < indices.length; i++) { - tmpGrad.merge(indices[i], values[i], Double::sum); - } - - return new Object(); + pushRequestMerger.snapshotState(context); } private Object processPullRequest(byte[] bytesData) { @@ -272,30 +215,120 @@ private Object processPullRequest(byte[] bytesData) { StreamRecord> record = new StreamRecord<>(Tuple2.of(workerId, pulledModelM.toBytes())); - // Holds the lock for output. 
- synchronized (lock) { - output.collect(record); - } + output.collect(record); return new Object(); } - private Tuple2 mergeKvsFromAllThreads( - Iterator kvsFromAllThreads) { - while (kvsFromAllThreads.hasNext()) { - Long2DoubleOpenHashMap kv = kvsFromAllThreads.next(); - for (Map.Entry entry : kv.entrySet()) { - accumulatedKvs.merge(entry.getKey(), entry.getValue(), Double::sum); + /** Utility class to merge the push request from different workers. */ + private static class PushRequestMerger implements Serializable { + /** The accumulated kv if the push request is for a vector. */ + private final Long2DoubleOpenHashMap accumulatedKvsForVector; + /** The accumulated kv if the push request is for a matrix. */ + private final Map accumulatedKvsForMatrix; + /** The state for accumulated kv. */ + private ListState accumulatedKvsState; + + public PushRequestMerger() { + this.accumulatedKvsForVector = new Long2DoubleOpenHashMap(); + this.accumulatedKvsForMatrix = new HashMap<>(); + } + + private Object processPushRequest(byte[] pushKv) { + KVsToPushM kvsToPush = KVsToPushM.fromBytes(pushKv); + Tuple2 pushedKvs = kvsToPush.kvs; + long[] keys = pushedKvs.f0; + double[] values = pushedKvs.f1; + + if (values.length == keys.length) { + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForVector.merge(keys[i], values[i], Double::sum); + } + } else { + int valuesPerKey = values.length / keys.length; + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForMatrix.putIfAbsent(keys[i], new double[valuesPerKey]); + double[] partialValue = accumulatedKvsForMatrix.get(keys[i]); + for (int j = 0; j < valuesPerKey; j++) { + partialValue[j] += values[i * valuesPerKey + j]; + } + } + } + return new Object(); + } + + /** Transforms the processed push request to kv arrays. 
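+         *
+         * <p>For the matrix case, values are flattened per key: e.g., merged entries
+         * {2 -> [a, b]} and {7 -> [c, d]} come out as keys {2, 7} with values {a, b, c, d}
+         * (illustrative; the map iteration order determines the actual key order).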
*/ + private Tuple2 toKvArrays() { + long[] indices = new long[0]; + double[] values = new double[0]; + if (accumulatedKvsForVector.size() != 0) { + indices = new long[accumulatedKvsForVector.size()]; + values = new double[indices.length]; + + int idx = 0; + for (Map.Entry entry : accumulatedKvsForVector.entrySet()) { + indices[idx] = entry.getKey(); + values[idx] = entry.getValue(); + idx++; + } + } else if (accumulatedKvsForMatrix.size() != 0) { + indices = new long[accumulatedKvsForMatrix.size()]; + int numValuesPerKey = + accumulatedKvsForMatrix.entrySet().iterator().next().getValue().length; + values = new double[indices.length * numValuesPerKey]; + int idx = 0; + for (Map.Entry entry : accumulatedKvsForMatrix.entrySet()) { + indices[idx] = entry.getKey(); + System.arraycopy( + entry.getValue(), 0, values, idx * numValuesPerKey, numValuesPerKey); + idx++; + } + } + return Tuple2.of(indices, values); + } + + private void initializeState(StateInitializationContext context) throws Exception { + accumulatedKvsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "accumulatedKvs", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + + byte[] accumulatedKvsInBytes = + OperatorStateUtils.getUniqueElement(accumulatedKvsState, "accumulatedKvs") + .orElse(null); + + if (accumulatedKvsInBytes != null) { + Tuple2 kvs = + MessageUtils.readLongDoubleArray(accumulatedKvsInBytes, 0); + long[] keys = kvs.f0; + double[] values = kvs.f1; + int numValuesPerKey = values.length / keys.length; + if (numValuesPerKey == 1) { + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForVector.put(keys[i], values[i]); + } + } else { + for (int i = 0; i < keys.length; i++) { + accumulatedKvsForMatrix.put( + keys[i], + Arrays.copyOfRange( + values, + i * numValuesPerKey, + i * numValuesPerKey + numValuesPerKey)); + } + } } - kv.clear(); } - long[] indices = new long[accumulatedKvs.size()]; - double[] values = new double[indices.length]; - int idx = 0; - for (Map.Entry entry : accumulatedKvs.entrySet()) { - indices[idx] = entry.getKey(); - values[idx] = entry.getValue(); - idx++; + + private void snapshotState(StateSnapshotContext context) throws Exception { + Tuple2 kvs = toKvArrays(); + accumulatedKvsState.clear(); + if (kvs.f0.length > 0) { + byte[] bytes = new byte[MessageUtils.getLongDoubleArraySizeInBytes(kvs)]; + MessageUtils.writeLongDoubleArray(kvs, bytes, 0); + accumulatedKvsState.add(bytes); + } } - return Tuple2.of(indices, values); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java index 0e4ef5fbc..c601356bc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -121,7 +121,7 @@ public void onEpochWatermarkIncremented( modelDim = Bits.getLong(feedback, 0); serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers)); serverAgent.zeros(); - iterationStages.context.setTrainData(new ResettableTrainDataIterator<>(trainDataState)); + iterationStages.context.setInputData(new ResettableTrainDataIterator<>(trainDataState)); nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java similarity index 80% 
rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java index d226efad2..859eddcbb 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java @@ -21,6 +21,8 @@ import org.apache.flink.ml.util.Bits; import org.apache.flink.util.Preconditions; +import static org.apache.flink.ml.common.ps.message.MessageType.INITIALIZE_MODEL_AS_ZERO; + /** * Message sent by worker to server that initializes the model as a dense array with defined range. */ @@ -30,8 +32,6 @@ public class ZerosToPushM implements Message { public final long startIndex; public final long endIndex; - public static final MessageType MESSAGE_TYPE = MessageType.ZEROS_TO_PUSH; - public ZerosToPushM(int workerId, int serverId, long startIndex, long endIndex) { this.workerId = workerId; this.serverId = serverId; @@ -39,19 +39,19 @@ public ZerosToPushM(int workerId, int serverId, long startIndex, long endIndex) this.endIndex = endIndex; } - public static ZerosToPushM fromBytes(byte[] bytesData) { + public static ZerosToPushM fromBytes(byte[] bytes) { int offset = 0; - char type = Bits.getChar(bytesData, offset); + char type = Bits.getChar(bytes, offset); offset += Character.BYTES; - Preconditions.checkState(type == MESSAGE_TYPE.type); + Preconditions.checkState(type == INITIALIZE_MODEL_AS_ZERO.type); - int workerId = Bits.getInt(bytesData, offset); + int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - int serverId = Bits.getInt(bytesData, offset); + int serverId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - long startIndex = Bits.getLong(bytesData, offset); + long startIndex = Bits.getLong(bytes, offset); offset += Long.BYTES; - long endIndex = Bits.getLong(bytesData, offset); + long endIndex = Bits.getLong(bytes, offset); return new ZerosToPushM(workerId, serverId, startIndex, endIndex); } @@ -60,7 +60,7 @@ public byte[] toBytes() { int numBytes = Character.BYTES + Integer.BYTES + Integer.BYTES + Long.BYTES + Long.BYTES; byte[] buffer = new byte[numBytes]; int offset = 0; - Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + Bits.putChar(buffer, offset, INITIALIZE_MODEL_AS_ZERO.type); offset += Character.BYTES; Bits.putInt(buffer, offset, this.workerId); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java similarity index 81% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java index c6742b41f..a87b84583 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java @@ -21,31 +21,31 @@ import org.apache.flink.ml.util.Bits; import org.apache.flink.util.Preconditions; +import static org.apache.flink.ml.common.ps.message.MessageType.PULL_INDEX; + /** The indices one worker needs to pull from servers. 
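+ *
+ * <p>Wire layout, as written by {@code toBytes()}: the message type (char), the server ID
+ * (int), the worker ID (int), and finally the pull indices as a long array.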
*/ public class IndicesToPullM implements Message { public final int serverId; public final int workerId; public final long[] indicesToPull; - public static final MessageType MESSAGE_TYPE = MessageType.INDICES_TO_PULL; - public IndicesToPullM(int serverId, int workerId, long[] indicesToPull) { this.serverId = serverId; this.workerId = workerId; this.indicesToPull = indicesToPull; } - public static IndicesToPullM fromBytes(byte[] bytesData) { + public static IndicesToPullM fromBytes(byte[] bytes) { int offset = 0; - char type = Bits.getChar(bytesData, offset); + char type = Bits.getChar(bytes, offset); offset += Character.BYTES; - Preconditions.checkState(type == MESSAGE_TYPE.type); + Preconditions.checkState(type == PULL_INDEX.type); - int psId = Bits.getInt(bytesData, offset); + int psId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - int workerId = Bits.getInt(bytesData, offset); + int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - long[] toPullIndices = MessageUtils.readLongArray(bytesData, offset); + long[] toPullIndices = MessageUtils.readLongArray(bytes, offset); return new IndicesToPullM(psId, workerId, toPullIndices); } @@ -58,7 +58,7 @@ public byte[] toBytes() { byte[] buffer = new byte[numBytes]; int offset = 0; - Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + Bits.putChar(buffer, offset, PULL_INDEX.type); offset += Character.BYTES; Bits.putInt(buffer, offset, this.serverId); offset += Integer.BYTES; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java similarity index 83% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java index 61bf6900e..ea44076a1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ValuesPulledM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java @@ -21,12 +21,13 @@ import org.apache.flink.ml.util.Bits; import org.apache.flink.util.Preconditions; +import static org.apache.flink.ml.common.ps.message.MessageType.PULLED_VALUE; + /** The values pulled from servers. 
*/ public class ValuesPulledM implements Message { public final int serverId; public final int workerId; public final double[] valuesPulled; - public static final MessageType MESSAGE_TYPE = MessageType.VALUES_PULLED; public ValuesPulledM(int serverId, int workerId, double[] valuesPulled) { this.serverId = serverId; @@ -34,17 +35,17 @@ public ValuesPulledM(int serverId, int workerId, double[] valuesPulled) { this.valuesPulled = valuesPulled; } - public static ValuesPulledM fromBytes(byte[] bytesData) { + public static ValuesPulledM fromBytes(byte[] bytes) { int offset = 0; - char type = Bits.getChar(bytesData, offset); + char type = Bits.getChar(bytes, offset); offset += Character.BYTES; - Preconditions.checkState(type == MESSAGE_TYPE.type); + Preconditions.checkState(type == PULLED_VALUE.type); - int psId = Bits.getInt(bytesData, offset); + int psId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - int workerId = Bits.getInt(bytesData, offset); + int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - double[] pulledValues = MessageUtils.readDoubleArray(bytesData, offset); + double[] pulledValues = MessageUtils.readDoubleArray(bytes, offset); return new ValuesPulledM(psId, workerId, pulledValues); } @@ -57,7 +58,7 @@ public byte[] toBytes() { + MessageUtils.getDoubleArraySizeInBytes(valuesPulled); byte[] buffer = new byte[numBytes]; int offset = 0; - Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + Bits.putChar(buffer, offset, PULLED_VALUE.type); offset += Character.BYTES; Bits.putInt(buffer, offset, this.serverId); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java similarity index 83% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java index 58be79fdc..11eb25a6a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/KVsToPushM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java @@ -22,12 +22,13 @@ import org.apache.flink.ml.util.Bits; import org.apache.flink.util.Preconditions; +import static org.apache.flink.ml.common.ps.message.MessageType.PUSH_KV; + /** The sparse key-values to push from workers to servers. 
*/ public class KVsToPushM implements Message { public final int serverId; public final int workerId; public final Tuple2 kvs; - public static final MessageType MESSAGE_TYPE = MessageType.KVS_TO_PUSH; public KVsToPushM(int workerId, int serverId, Tuple2 kvs) { this.workerId = workerId; @@ -35,17 +36,17 @@ public KVsToPushM(int workerId, int serverId, Tuple2 kvs) { this.kvs = kvs; } - public static KVsToPushM fromBytes(byte[] bytesData) { + public static KVsToPushM fromBytes(byte[] bytes) { int offset = 0; - char type = Bits.getChar(bytesData, offset); + char type = Bits.getChar(bytes, offset); offset += Character.BYTES; - Preconditions.checkState(type == MESSAGE_TYPE.type); + Preconditions.checkState(type == PUSH_KV.type); - int workerId = Bits.getInt(bytesData, offset); + int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - int psId = Bits.getInt(bytesData, offset); + int psId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - Tuple2 grad = MessageUtils.readLongDoubleArray(bytesData, offset); + Tuple2 grad = MessageUtils.readLongDoubleArray(bytes, offset); return new KVsToPushM(workerId, psId, grad); } @@ -59,7 +60,7 @@ public byte[] toBytes() { byte[] buffer = new byte[numBytes]; int offset = 0; - Bits.putChar(buffer, offset, MESSAGE_TYPE.type); + Bits.putChar(buffer, offset, PUSH_KV.type); offset += Character.BYTES; Bits.putInt(buffer, offset, this.workerId); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java index f6e34095d..08c2f8417 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java @@ -45,7 +45,7 @@ public void setTerminationCriteria(SerializableFunction shouldTermin } /** Adds an iteration stage into the stage list. */ - public IterationStageList addTrainingStage(IterationStage stage) { + public IterationStageList addStage(IterationStage stage) { stageList.add(stage); return this; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java similarity index 96% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java index e12864b77..9c4f7ee28 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java @@ -40,7 +40,7 @@ default void setIterationId(int iterationId) {} default void setWorldInfo(int workerId, int numWorkers) {} /** Sets the training data. */ - default void setTrainData(ResettableIterator trainData) {} + default void setInputData(ResettableIterator inputData) {} /** Recover from state. 
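The message classes above all share one wire framing: a 2-byte char type tag, two int ids (server and worker), then a length-prefixed payload written through the Bits helpers. A minimal sketch of the same round trip in plain java.nio, assuming the big-endian layout of java.io.Bits (the constant and method names here are illustrative, not the patch's API):

    import java.nio.ByteBuffer;

    public final class MessageFramingSketch {
        static final char PULLED_VALUE_TAG = 2;

        static byte[] encode(int serverId, int workerId, double[] values) {
            ByteBuffer buf = ByteBuffer.allocate( // ByteBuffer is big-endian by default
                    Character.BYTES + 2 * Integer.BYTES + Integer.BYTES + values.length * Double.BYTES);
            buf.putChar(PULLED_VALUE_TAG).putInt(serverId).putInt(workerId).putInt(values.length);
            for (double v : values) {
                buf.putDouble(v);
            }
            return buf.array();
        }

        static double[] decode(byte[] bytes) {
            ByteBuffer buf = ByteBuffer.wrap(bytes);
            if (buf.getChar() != PULLED_VALUE_TAG) { // fail fast on a mismatched tag
                throw new IllegalStateException("unexpected message type");
            }
            buf.getInt(); // serverId
            buf.getInt(); // workerId
            double[] values = new double[buf.getInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = buf.getDouble();
            }
            return values;
        }

        public static void main(String[] args) {
            System.out.println(decode(encode(0, 1, new double[] {1.0, 2.0})).length); // 2
        }
    }

Putting the tag first is what allows a receiver to dispatch on the message type without deserializing the payload.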
*/ default void initializeState(StateInitializationContext context) throws Exception {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java index 20b187358..5a94ae0d1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -52,6 +52,8 @@ public final class TrainingUtils { * @param modelDim dimension of the input model. * @param trainData the training data. * @param iterationStages the iterative training logic. + * @param modelUpdater the logic to update model on servers. + * @param numServers number of servers. * @return the fitted model data. */ public static DataStream> train( @@ -59,9 +61,7 @@ public static DataStream> train( DataStream trainData, ModelUpdater modelUpdater, IterationStageList iterationStages, - int numServers, - int numServerCores) { - // TODO: Support different types for model data types. + int numServers) { // TODO: Support incremental training for multiple models. // TODO: Support user defined model partitioner. @@ -81,8 +81,7 @@ public static DataStream> train( ReplayableDataStreamList.notReplay( trainData.rebalance().map(x -> x, trainData.getType())), IterationConfig.newBuilder().build(), - new TrainIterationBody( - modelUpdater, iterationStages, numServers, numServerCores)); + new TrainIterationBody(modelUpdater, iterationStages, numServers)); return resultList.get(0); } @@ -92,17 +91,14 @@ private static class TrainIterationBody implements IterationBody { private final ModelUpdater modelUpdater; private final IterationStageList iterationStages; private final int numServers; - private final int numServerCores; public TrainIterationBody( ModelUpdater modelUpdater, IterationStageList iterationStages, - int numServers, - int numServerCores) { + int numServers) { this.iterationStages = iterationStages; this.modelUpdater = modelUpdater; this.numServers = numServers; - this.numServerCores = numServerCores; } @Override @@ -138,8 +134,7 @@ public IterationBodyResult process( new TupleTypeInfo<>( Types.INT, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO), - new ServerOperator( - modelUpdater, modelDataOutputTag, numServerCores)); + new ServerOperator(modelUpdater, modelDataOutputTag)); messageToWorker.setParallelism(numServers); DataStream combinedMessageToWorker = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java similarity index 95% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java index b800e9152..f38440274 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java @@ -30,7 +30,12 @@ import java.util.Iterator; import java.util.List; -/** The FTRL model updater. */ +/** + * FTRL (Follow-the-regularized-leader) is an optimization algorithm which is widely deployed by online learning. + * + *
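Concretely, the optimizer keeps two accumulators per coordinate and applies L1/L2 shrinkage on every step. The sketch below restates the per-coordinate FTRL-Proximal update from the McMahan et al. paper cited just below; the field and parameter names are hypothetical, and this is the published rule rather than the exact code of this class:

    public final class FtrlProximalSketch {
        // Hypothetical hyper-parameters: alpha/beta shape the per-coordinate
        // learning rate, lambda1/lambda2 are the L1/L2 regularization strengths.
        private final double alpha = 0.1, beta = 1.0, lambda1 = 0.01, lambda2 = 0.01;
        private final double[] z, n, weights;

        public FtrlProximalSketch(int dim) {
            z = new double[dim];
            n = new double[dim];
            weights = new double[dim];
        }

        /** Applies one FTRL-Proximal step to coordinate i given its fresh gradient g. */
        public void update(int i, double g) {
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * weights[i];
            n[i] += g * g;
            if (Math.abs(z[i]) <= lambda1) {
                weights[i] = 0.0; // the L1 term keeps weak coordinates exactly at zero
            } else {
                weights[i] =
                        -(z[i] - Math.signum(z[i]) * lambda1)
                                / ((beta + Math.sqrt(n[i])) / alpha + lambda2);
            }
        }
    }

The sparsity induced by the L1 threshold is one reason FTRL pairs well with the sparse push/pull machinery in this patch.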

    See H. Brendan McMahan et al., Ad click + * * prediction: a view from the trenches. + */ public class FTRL implements ModelUpdater { private final double alpha; private final double beta; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java similarity index 95% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java index fc4c4af8f..3aafae47b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java @@ -32,7 +32,7 @@ */ public interface ModelUpdater extends Serializable { - /** Initialize the model data. */ + /** Initializes the model data. */ void open(long startFeatureIndex, long endFeatureIndex); /** Applies the push to update the model data, e.g., using gradient to update model. */ @@ -44,7 +44,7 @@ public interface ModelUpdater extends Serializable { /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */ Iterator> getModelPieces(); - /** Recover the model data from state. */ + /** Recovers the model data from state. */ void initializeState(StateInitializationContext context) throws Exception; /** Snapshots the model data to state. */ diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java index 1fca23dda..c0abea005 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -68,8 +68,7 @@ public class LogisticRegressionWithFtrlTest { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); - private final double[] expectedCoefficient = - new double[] {0.5287258, -1.2163098, -1.0710997, -0.8591691}; + private final double[] expectedCoefficient = new double[] {0.52, -1.21, -1.07, -0.85}; private static final int MAX_ITER = 100; private static final int NUM_SERVERS = 2; private static final double TOLERANCE = 1e-7; @@ -168,7 +167,6 @@ public void testParam() { assertEquals(0.1, logisticRegressionWithFtrl.getBeta(), TOLERANCE); assertEquals(0L, logisticRegressionWithFtrl.getModelDim()); assertEquals(1, logisticRegressionWithFtrl.getNumServers()); - assertEquals(1, logisticRegressionWithFtrl.getNumServerCores()); logisticRegressionWithFtrl .setFeaturesCol("test_features") @@ -185,8 +183,7 @@ public void testParam() { .setAlpha(0.2) .setBeta(0.2) .setModelDim(10000000L) - .setNumServers(4) - .setNumServerCores(2); + .setNumServers(4); assertEquals("test_features", logisticRegressionWithFtrl.getFeaturesCol()); assertEquals("test_label", logisticRegressionWithFtrl.getLabelCol()); assertEquals("test_weight", logisticRegressionWithFtrl.getWeightCol()); @@ -203,7 +200,6 @@ public void testParam() { assertEquals(0.2, logisticRegressionWithFtrl.getBeta(), TOLERANCE); assertEquals(10000000L, logisticRegressionWithFtrl.getModelDim()); assertEquals(4, logisticRegressionWithFtrl.getNumServers()); - assertEquals(2, logisticRegressionWithFtrl.getNumServerCores()); } @Test @@ -250,7 +246,7 @@ public void testGetModelData() throws Exception { double[] pieceCoeff = 
modelPiece.coefficient.values; System.arraycopy(pieceCoeff, 0, collectedCoefficient, startIndex, pieceCoeff.length); } - assertArrayEquals(expectedCoefficient, collectedCoefficient, 1e-7); + assertArrayEquals(expectedCoefficient, collectedCoefficient, 0.1); } @Test From b04cd849ebfed8bdd724612805a6d0a96b34a188 Mon Sep 17 00:00:00 2001 From: Zhangzp Date: Tue, 30 May 2023 16:52:55 +0800 Subject: [PATCH 07/18] resolve comments --- .../LogisticRegression.java | 3 +- .../LogisticRegressionModel.java | 2 +- .../LogisticRegressionWithFtrl.java | 230 ++---------------- .../ml/common/ps/MirrorWorkerOperator.java | 34 +-- .../flink/ml/common/ps/RangePartitioner.java | 16 +- .../flink/ml/common/ps/ServerAgent.java | 27 +- .../flink/ml/common/ps/ServerOperator.java | 56 ++--- .../flink/ml/common/ps/WorkerOperator.java | 38 +-- .../ps/message/InitializeModelAsZeroM.java | 12 +- .../ml/common/ps/message/MessageType.java | 26 +- .../ml/common/ps/message/MessageUtils.java | 54 ++-- .../ml/common/ps/message/PullIndexM.java | 20 +- .../ml/common/ps/message/PulledValueM.java | 18 +- .../flink/ml/common/ps/message/PushKvM.java | 12 +- .../common/ps/training/ComputeGradients.java | 91 +++++++ .../ml/common/ps/training/ComputeIndices.java | 59 +++++ .../ps/training/IterationStageList.java | 8 +- .../ml/common/ps/training/MLSession.java | 10 +- .../ml/common/ps/training/MLSessionImpl.java | 53 ++++ .../ps/training/MiniBatchMLSession.java | 104 ++++++++ .../ml/common/ps/training/ProcessStage.java | 6 +- .../ml/common/ps/training/TrainingUtils.java | 9 +- .../flink/ml/common/ps/updater/FTRL.java | 17 +- .../ml/common/ps/updater/ModelUpdater.java | 13 +- .../LogisticRegressionWithFtrlTest.java | 8 +- .../LogisticRegressionModelData.java | 29 +++ .../LogisticRegressionModelServable.java | 36 +-- 27 files changed, 559 insertions(+), 432 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 2ff61c7c5..e7a896059 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -86,8 +86,7 @@ public LogisticRegressionModel fit(Table... inputs) { "Multinomial classification is not supported yet. 
Supported options: [auto, binomial]."); } Vector features = - ((Vector) dataPoint.getField(getFeaturesCol())) - .toDense(); + ((Vector) dataPoint.getField(getFeaturesCol())); return new LabeledPointWithWeight(features, label, weight); }); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index 248cf9438..1f7176d48 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -155,7 +155,7 @@ public Row map(Row dataPoint) { servable = new LogisticRegressionModelServable(modelData.get(0)); } else { LogisticRegressionModelData mergedModel = - LogisticRegressionModelServable.mergePieces(modelData); + LogisticRegressionModelData.mergeSegments(modelData); servable = new LogisticRegressionModelServable(mergedModel); } ParamUtils.updateExistingParams(servable, params); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java index de81651ff..2c30740b3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -20,8 +20,6 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; @@ -29,23 +27,19 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; -import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.common.ps.training.ComputeGradients; +import org.apache.flink.ml.common.ps.training.ComputeIndices; import org.apache.flink.ml.common.ps.training.IterationStageList; -import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.MiniBatchMLSession; import org.apache.flink.ml.common.ps.training.PullStage; import org.apache.flink.ml.common.ps.training.PushStage; import org.apache.flink.ml.common.ps.training.SerializableConsumer; -import org.apache.flink.ml.common.ps.training.TrainingContext; import org.apache.flink.ml.common.ps.training.TrainingUtils; -import org.apache.flink.ml.common.updater.FTRL; -import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.common.ps.updater.FTRL; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; -import org.apache.flink.runtime.state.StateInitializationContext; -import org.apache.flink.runtime.state.StateSnapshotContext; -import org.apache.flink.runtime.util.ResettableIterator; import org.apache.flink.streaming.api.datastream.DataStream; import 
org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -55,16 +49,8 @@ import org.apache.flink.util.function.SerializableFunction; import org.apache.flink.util.function.SerializableSupplier; -import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; -import it.unimi.dsi.fastutil.longs.LongOpenHashSet; -import org.apache.commons.collections.IteratorUtils; - import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; -import java.util.Iterator; -import java.util.List; import java.util.Map; /** @@ -131,25 +117,27 @@ public LogisticRegressionModel fit(Table... inputs) { .map((MapFunction) value -> value + 1); } - LogisticRegressionWithFtrlTrainingContext trainingContext = - new LogisticRegressionWithFtrlTrainingContext(getParamMap()); + MiniBatchMLSession mlSession = + new MiniBatchMLSession<>( + getGlobalBatchSize(), + TypeInformation.of(LabeledLargePointWithWeight.class)); - IterationStageList iterationStages = - new IterationStageList<>(trainingContext); + IterationStageList> iterationStages = + new IterationStageList<>(mlSession); iterationStages .addStage(new ComputeIndices()) .addStage( new PullStage( - (SerializableSupplier) () -> trainingContext.pullIndices, - (SerializableConsumer) - x -> trainingContext.pulledValues = x)) + (SerializableSupplier) () -> mlSession.pullIndices, + (SerializableConsumer) x -> mlSession.pulledValues = x)) .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) .addStage( new PushStage( - (SerializableSupplier) () -> trainingContext.pushIndices, - (SerializableSupplier) () -> trainingContext.pushValues)) + (SerializableSupplier) () -> mlSession.pushIndices, + (SerializableSupplier) () -> mlSession.pushValues)) .setTerminationCriteria( - (SerializableFunction) + (SerializableFunction< + MiniBatchMLSession, Boolean>) o -> o.iterationId >= getMaxIter()); FTRL ftrl = new FTRL( @@ -194,189 +182,3 @@ public Map, Object> getParamMap() { return paramMap; } } - -/** - * An iteration stage that samples a batch of training data and computes the indices needed to - * compute gradients. - */ -class ComputeIndices extends ProcessStage { - - @Override - public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception { - context.readInNextBatchData(); - context.pullIndices = computeIndices(context.batchData); - } - - public static long[] computeIndices(List dataPoints) { - LongOpenHashSet indices = new LongOpenHashSet(); - for (LabeledLargePointWithWeight dataPoint : dataPoints) { - long[] notZeros = dataPoint.features.f0; - for (long index : notZeros) { - indices.add(index); - } - } - - long[] sortedIndices = new long[indices.size()]; - Iterator iterator = indices.iterator(); - int i = 0; - while (iterator.hasNext()) { - sortedIndices[i++] = iterator.next(); - } - Arrays.sort(sortedIndices); - return sortedIndices; - } -} - -/** - * An iteration stage that uses the pulled model values and sampled batch data to compute the - * gradients. 
- */ -class ComputeGradients extends ProcessStage { - private final LossFunc lossFunc; - - public ComputeGradients(LossFunc lossFunc) { - this.lossFunc = lossFunc; - } - - @Override - public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException { - long[] indices = ComputeIndices.computeIndices(context.batchData); - double[] pulledModelValues = context.pulledValues; - double[] gradients = computeGradient(context.batchData, indices, pulledModelValues); - - context.pushIndices = indices; - context.pushValues = gradients; - } - - private double[] computeGradient( - List batchData, - long[] sortedBatchIndices, - double[] pulledModelValues) { - Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length); - for (int i = 0; i < sortedBatchIndices.length; i++) { - coefficient.put(sortedBatchIndices[i], pulledModelValues[i]); - } - Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length); - - for (LabeledLargePointWithWeight dataPoint : batchData) { - double dot = dot(dataPoint.features, coefficient); - double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight; - - long[] featureIndices = dataPoint.features.f0; - double[] featureValues = dataPoint.features.f1; - double z; - for (int i = 0; i < featureIndices.length; i++) { - long currentIndex = featureIndices[i]; - z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.); - cumGradients.put(currentIndex, z); - } - } - double[] cumGradientValues = new double[sortedBatchIndices.length]; - for (int i = 0; i < sortedBatchIndices.length; i++) { - cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]); - } - BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues)); - return cumGradientValues; - } - - private static double dot( - Tuple2 features, Long2DoubleOpenHashMap coefficient) { - double dot = 0; - for (int i = 0; i < features.f0.length; i++) { - dot += features.f1[i] * coefficient.get(features.f0[i]); - } - return dot; - } -} - -/** The context information of local computing process. */ -class LogisticRegressionWithFtrlTrainingContext - implements TrainingContext, - LogisticRegressionWithFtrlParams { - /** Parameters of LogisticRegressionWithFtrl. */ - private final Map, Object> paramMap; - /** Current iteration id. */ - int iterationId; - /** The local batch size. */ - private int localBatchSize = -1; - /** The training data. */ - private ResettableIterator trainData; - /** The batch of training data for computing gradients. */ - List batchData; - - private ListState batchDataState; - - /** The placeholder for indices to pull for each iteration. */ - long[] pullIndices; - /** The placeholder for the pulled values for each iteration. */ - double[] pulledValues; - /** The placeholder for indices to push for each iteration. */ - long[] pushIndices; - /** The placeholder for values to push for each iteration. 
*/ - double[] pushValues; - - public LogisticRegressionWithFtrlTrainingContext(Map, Object> paramMap) { - this.paramMap = paramMap; - } - - @Override - public void setIterationId(int iterationId) { - this.iterationId = iterationId; - } - - @Override - public void setWorldInfo(int workerId, int numWorkers) { - int globalBatchSize = getGlobalBatchSize(); - this.localBatchSize = globalBatchSize / numWorkers; - if (globalBatchSize % numWorkers > workerId) { - localBatchSize++; - } - this.batchData = new ArrayList<>(localBatchSize); - } - - @Override - public void setInputData(ResettableIterator inputData) { - this.trainData = (ResettableIterator) inputData; - } - - @Override - public void initializeState(StateInitializationContext context) throws Exception { - batchDataState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "batchDataState", - TypeInformation.of(LabeledLargePointWithWeight.class))); - - Iterator batchDataIterator = batchDataState.get().iterator(); - if (batchDataIterator.hasNext()) { - batchData = IteratorUtils.toList(batchDataIterator); - } - } - - @Override - public void snapshotState(StateSnapshotContext context) throws Exception { - batchDataState.clear(); - if (batchData.size() > 0) { - batchDataState.addAll(batchData); - } - } - - @Override - public Map, Object> getParamMap() { - return paramMap; - } - - /** Reads in next batch of training data. */ - public void readInNextBatchData() throws IOException { - batchData.clear(); - int i = 0; - while (i < localBatchSize && trainData.hasNext()) { - batchData.add(trainData.next()); - i++; - } - if (!trainData.hasNext()) { - trainData.reset(); - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java index 40d055da4..32a601d0e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.ml.common.ps.message.PulledValueM; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -38,8 +38,8 @@ /** * Merges the message from different servers for one pull request. * - *

    Note that for each single-thread worker, there are at exactly #numServers pieces for each pull - * request in the feedback edge. + *

    Note that for each single-thread worker, there are exactly #numServers segments for each + * pull request in the feedback edge. */ public class MirrorWorkerOperator extends AbstractStreamOperator implements OneInputStreamOperator, byte[]> { @@ -47,7 +47,7 @@ public class MirrorWorkerOperator extends AbstractStreamOperator private int workerId; /** The received messages from servers for the current pull request. */ - private List messageReceived; + private List messageReceived; private ListState messageReceivedState; @@ -64,28 +64,28 @@ public void open() throws Exception { @Override public void processElement(StreamRecord> element) throws Exception { Preconditions.checkState(element.getValue().f0 == workerId); - ValuesPulledM pulledModelM = ValuesPulledM.fromBytes(element.getValue().f1); - messageReceived.add(pulledModelM); + PulledValueM pulledValueM = PulledValueM.fromBytes(element.getValue().f1); + messageReceived.add(pulledValueM); trySendingPulls(numServers); } - private void trySendingPulls(int numPieces) { - if (messageReceived.size() == numPieces) { - Comparator comparator = Comparator.comparingInt(o -> o.serverId); + private void trySendingPulls(int numSegments) { + if (messageReceived.size() == numSegments) { + Comparator comparator = Comparator.comparingInt(o -> o.serverId); messageReceived.sort(comparator); int size = 0; - for (ValuesPulledM pulledModelM : messageReceived) { - size += pulledModelM.valuesPulled.length; + for (PulledValueM pulledValueM : messageReceived) { + size += pulledValueM.values.length; } double[] answer = new double[size]; int offset = 0; - for (ValuesPulledM pulledModelM : messageReceived) { - double[] values = pulledModelM.valuesPulled; + for (PulledValueM pulledValueM : messageReceived) { + double[] values = pulledValueM.values; System.arraycopy(values, 0, answer, offset, values.length); offset += values.length; } - ValuesPulledM pulledModelM = new ValuesPulledM(-1, workerId, answer); - output.collect(new StreamRecord<>(pulledModelM.toBytes())); + PulledValueM pulledValueM = new PulledValueM(-1, workerId, answer); + output.collect(new StreamRecord<>(pulledValueM.toBytes())); messageReceived.clear(); } } @@ -104,7 +104,7 @@ public void initializeState(StateInitializationContext context) throws Exception Iterator iterator = messageReceivedState.get().iterator(); if (iterator.hasNext()) { while (iterator.hasNext()) { - messageReceived.add(ValuesPulledM.fromBytes(iterator.next())); + messageReceived.add(PulledValueM.fromBytes(iterator.next())); } } } @@ -114,7 +114,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); messageReceivedState.clear(); if (messageReceived.size() > 0) { - for (ValuesPulledM valuesPulled : messageReceived) { + for (PulledValueM valuesPulled : messageReceived) { messageReceivedState.add(valuesPulled.toBytes()); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java index 79fd24a72..e8de5e920 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java @@ -26,7 +26,11 @@ import java.util.Arrays; import java.util.Iterator; -/** Range partitioner for model data. */ +/** + * Range partitioner for model data. It partitions the model data for each dimension according to + * the dimension id.
The model data for each dimension could be a double or several doubles. Note + * that the model data for all dimensions should share the same size. + */ public class RangePartitioner { public final long dim; public final int numServers; @@ -54,9 +58,9 @@ public RangePartitioner(long dim, int numServers) { * Splits the push/pull request according to the given sorted indices and the corresponding * values. * - * @param indices Sorted indices of push/pull request. - * @param values The push values if not null. - * @return The split requests for each server task. + * @param indices sorted indices of push/pull request. + * @param values the push values if not null. + * @return the split requests for each server. */ public Iterator> splitRequest( long[] indices, @Nullable double[] values) { @@ -89,7 +93,9 @@ public RequestsIterator( numValuesPerKey = values.length / indices.length; Preconditions.checkArgument( numValuesPerKey * indices.length == values.length, - "The size of values cannot be divided by size of keys."); + String.format( + "The size of values [%d] cannot be divided by size of keys [%d].", + values.length, indices.length)); } else { numValuesPerKey = 1; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java index 7a6b5dea1..cc6d7ae1f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -20,9 +20,9 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.ml.common.ps.message.IndicesToPullM; -import org.apache.flink.ml.common.ps.message.KVsToPushM; -import org.apache.flink.ml.common.ps.message.ZerosToPushM; +import org.apache.flink.ml.common.ps.message.InitializeModelAsZeroM; +import org.apache.flink.ml.common.ps.message.PullIndexM; +import org.apache.flink.ml.common.ps.message.PushKvM; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -30,9 +30,9 @@ /** ServerAgent resides on each worker. It serves as an agent for workers to talk with servers. */ public class ServerAgent { - /** Id of the worker that this agent resides on. */ + /** Index of the worker that this agent resides on. */ private final int workerId; - + /** Partitioner of the model data that this ServerAgent maintains. */ private RangePartitioner partitioner; /** The collector on this worker. */ private final Output>> output; @@ -52,19 +52,20 @@ public void pushKVs(long[] indices, double[] values) { partitioner.splitRequest(indices, values); while (requests.hasNext()) { Tuple3 request = requests.next(); - KVsToPushM kvToPush = - new KVsToPushM(workerId, request.f0, Tuple2.of(request.f1, request.f2)); - output.collect(new StreamRecord<>(Tuple2.of(request.f0, kvToPush.toBytes()))); + PushKvM pushKvM = new PushKvM(workerId, request.f0, Tuple2.of(request.f1, request.f2)); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, pushKvM.toBytes()))); } } /** Sends a request to servers to initialize the values stored as zeros. 
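To make the range split concrete: assuming the even partitioning that the constructor fields (dim, numServers, ranges) suggest, with dim = 10 and numServers = 3 the boundaries come out as [0, 4, 7, 10], and a sorted request is then cut at those boundaries, one slice per server. The snippet below illustrates that scheme without calling the RangePartitioner class itself:

    import java.util.Arrays;

    public final class RangeSplitSketch {
        public static void main(String[] args) {
            long dim = 10;
            int numServers = 3;

            // Even split; the first dim % numServers servers take one extra key.
            long[] ranges = new long[numServers + 1];
            long shard = dim / numServers;
            long remainder = dim % numServers;
            for (int i = 0; i < numServers; i++) {
                ranges[i + 1] = ranges[i] + shard + (i < remainder ? 1 : 0);
            }
            System.out.println(Arrays.toString(ranges)); // [0, 4, 7, 10]

            // Cut a sorted pull request at the range boundaries.
            long[] indices = {1, 3, 4, 8, 9};
            int start = 0;
            for (int s = 0; s < numServers; s++) {
                int end = start;
                while (end < indices.length && indices[end] < ranges[s + 1]) {
                    end++;
                }
                System.out.println(
                        "server " + s + " <- " + Arrays.toString(Arrays.copyOfRange(indices, start, end)));
                start = end;
            }
        }
    }

A push request carries a values array alongside the indices; the same boundary scan applies, with each values block sliced at numValuesPerKey granularity.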
*/ - public void zeros() { + public void initializeModelAsZeros() { for (int serverId = 0; serverId < partitioner.numServers; serverId++) { long start = partitioner.ranges[serverId]; long end = partitioner.ranges[serverId + 1]; - ZerosToPushM zerosToPush = new ZerosToPushM(workerId, serverId, start, end); - output.collect(new StreamRecord<>(Tuple2.of(serverId, zerosToPush.toBytes()))); + InitializeModelAsZeroM initializeModelAsZeroM = + new InitializeModelAsZeroM(workerId, serverId, start, end); + output.collect( + new StreamRecord<>(Tuple2.of(serverId, initializeModelAsZeroM.toBytes()))); } } @@ -74,8 +75,8 @@ public void pull(long[] indices) { partitioner.splitRequest(indices, null); while (requests.hasNext()) { Tuple3 request = requests.next(); - IndicesToPullM indicesToPullM = new IndicesToPullM(request.f0, workerId, request.f1); - output.collect(new StreamRecord<>(Tuple2.of(request.f0, indicesToPullM.toBytes()))); + PullIndexM pullIndexM = new PullIndexM(request.f0, workerId, request.f1); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, pullIndexM.toBytes()))); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index 90fed895e..a9b9772da 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -25,13 +25,13 @@ import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; -import org.apache.flink.ml.common.ps.message.IndicesToPullM; -import org.apache.flink.ml.common.ps.message.KVsToPushM; +import org.apache.flink.ml.common.ps.message.InitializeModelAsZeroM; import org.apache.flink.ml.common.ps.message.MessageType; import org.apache.flink.ml.common.ps.message.MessageUtils; -import org.apache.flink.ml.common.ps.message.ValuesPulledM; -import org.apache.flink.ml.common.ps.message.ZerosToPushM; -import org.apache.flink.ml.common.updater.ModelUpdater; +import org.apache.flink.ml.common.ps.message.PullIndexM; +import org.apache.flink.ml.common.ps.message.PulledValueM; +import org.apache.flink.ml.common.ps.message.PushKvM; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -40,7 +40,6 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; -import org.apache.flink.util.SerializableObject; import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; @@ -65,9 +64,9 @@ * received message. *

  • The server operator calls {@link ModelUpdater#handlePush(long[], double[])} and {@link * ModelUpdater#handlePull(long[])} to process the messages in detail. - *
  • The server operator ensures that {@link ModelUpdater} is robust to failures. + *
  • The server operator triggers checkpointing of {@link ModelUpdater}. *
  • The server operator outputs the final model parameters by calling {@link - * ModelUpdater#getModelPieces()}. + * ModelUpdater#getModelSegments()}. *
* *

TODO: Add support for asynchronous operations on servers. @@ -114,11 +113,11 @@ public void processElement(StreamRecord> element) throws byte[] request = element.getValue().f1; MessageType type = MessageUtils.getMessageType(request); switch (type) { - case INDICES_TO_PULL: + case PULL_INDEX: pendingPulls.add(request); break; - case ZEROS_TO_PUSH: - ZerosToPushM zerosToPush = ZerosToPushM.fromBytes(request); + case INITIALIZE_MODEL_AS_ZERO: + InitializeModelAsZeroM zerosToPush = InitializeModelAsZeroM.fromBytes(request); Preconditions.checkState(serverId == zerosToPush.serverId); long start = zerosToPush.startIndex; @@ -127,7 +126,7 @@ public void processElement(StreamRecord> element) throws modelUpdater.open(start, end); } break; - case KVS_TO_PUSH: + case PUSH_KV: futuresInEpoch.add( singleThreadExecutor.submit( () -> pushRequestMerger.processPushRequest(request))); @@ -148,6 +147,7 @@ public void onEpochWatermarkIncremented( futuresInEpoch.clear(); if (epochWatermark > 0) { + // The first iteration contains no push kvs, but model initialization request. Tuple2 kvs = pushRequestMerger.toKvArrays(); pushRequestMerger.accumulatedKvsForMatrix.clear(); pushRequestMerger.accumulatedKvsForVector.clear(); @@ -172,10 +172,10 @@ public void onEpochWatermarkIncremented( @Override public void onIterationTerminated( Context context, Collector> collector) { - Iterator> modelPieces = modelUpdater.getModelPieces(); - while (modelPieces.hasNext()) { - Tuple3 modelPiece = modelPieces.next(); - output.collect(modelOutputTag, new StreamRecord<>(modelPiece)); + Iterator> modelSegments = modelUpdater.getModelSegments(); + while (modelSegments.hasNext()) { + Tuple3 modelSegment = modelSegments.next(); + output.collect(modelOutputTag, new StreamRecord<>(modelSegment)); } } @@ -206,14 +206,14 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } private Object processPullRequest(byte[] bytesData) { - IndicesToPullM sparsePullModeM = IndicesToPullM.fromBytes(bytesData); - Preconditions.checkState(serverId == sparsePullModeM.serverId); - int workerId = sparsePullModeM.workerId; - long[] indices = sparsePullModeM.indicesToPull; + PullIndexM pullIndexM = PullIndexM.fromBytes(bytesData); + Preconditions.checkState(serverId == pullIndexM.serverId); + int workerId = pullIndexM.workerId; + long[] indices = pullIndexM.indices; double[] pulledValues = modelUpdater.handlePull(indices); - ValuesPulledM pulledModelM = new ValuesPulledM(serverId, workerId, pulledValues); + PulledValueM pulledValueM = new PulledValueM(serverId, workerId, pulledValues); StreamRecord> record = - new StreamRecord<>(Tuple2.of(workerId, pulledModelM.toBytes())); + new StreamRecord<>(Tuple2.of(workerId, pulledValueM.toBytes())); output.collect(record); return new Object(); @@ -234,10 +234,10 @@ public PushRequestMerger() { } private Object processPushRequest(byte[] pushKv) { - KVsToPushM kvsToPush = KVsToPushM.fromBytes(pushKv); - Tuple2 pushedKvs = kvsToPush.kvs; - long[] keys = pushedKvs.f0; - double[] values = pushedKvs.f1; + PushKvM pushKvM = PushKvM.fromBytes(pushKv); + Tuple2 pushKvs = pushKvM.kvs; + long[] keys = pushKvs.f0; + double[] values = pushKvs.f1; if (values.length == keys.length) { for (int i = 0; i < keys.length; i++) { @@ -300,7 +300,7 @@ private void initializeState(StateInitializationContext context) throws Exceptio if (accumulatedKvsInBytes != null) { Tuple2 kvs = - MessageUtils.readLongDoubleArray(accumulatedKvsInBytes, 0); + MessageUtils.getLongDoubleArray(accumulatedKvsInBytes, 0); long[] keys = 
kvs.f0; double[] values = kvs.f1; int numValuesPerKey = values.length / keys.length; @@ -326,7 +326,7 @@ private void snapshotState(StateSnapshotContext context) throws Exception { accumulatedKvsState.clear(); if (kvs.f0.length > 0) { byte[] bytes = new byte[MessageUtils.getLongDoubleArraySizeInBytes(kvs)]; - MessageUtils.writeLongDoubleArray(kvs, bytes, 0); + MessageUtils.putLongDoubleArray(kvs, bytes, 0); accumulatedKvsState.add(bytes); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java index c601356bc..dc69974a0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -26,13 +26,13 @@ import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; -import org.apache.flink.ml.common.ps.message.ValuesPulledM; +import org.apache.flink.ml.common.ps.message.PulledValueM; import org.apache.flink.ml.common.ps.training.IterationStage; import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.MLSession; import org.apache.flink.ml.common.ps.training.ProcessStage; import org.apache.flink.ml.common.ps.training.PullStage; import org.apache.flink.ml.common.ps.training.PushStage; -import org.apache.flink.ml.common.ps.training.TrainingContext; import org.apache.flink.ml.util.Bits; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -53,15 +53,15 @@ * *

    *
  • Caches the training data. - *
  • Initializes the {@link TrainingContext}. + *
  • Initializes the {@link MLSession}. *
  • Splits the {@link IterationStageList} by {@link PullStage} into multiple sequences and map * it into flink-ml-iterations. *
  • Executes the process function in each {@link ProcessStage}. *
  • Executes the push/pull request in {@link PushStage} and {@link PullStage} and talk to - * servers, by reading/writing {@link TrainingContext}. + * servers, by reading/writing {@link MLSession}. *
*/ -public class WorkerOperator +public class WorkerOperator extends AbstractStreamOperator> implements TwoInputStreamOperator>, IterationListener> { @@ -72,7 +72,7 @@ public class WorkerOperator private ServerAgent serverAgent; /** The user defined iteration logic. */ - private final IterationStageList iterationStages; + private final IterationStageList iterationStages; /** * Iteration id in terms of {@link IterationStageList}. When we finished processing all stages @@ -100,7 +100,7 @@ public class WorkerOperator private ListState modelDimState; - public WorkerOperator(IterationStageList iterationStages, int numServers) { + public WorkerOperator(IterationStageList iterationStages, int numServers) { this.iterationStages = iterationStages; this.numServers = numServers; } @@ -110,7 +110,7 @@ public void open() { int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); int workerId = getRuntimeContext().getIndexOfThisSubtask(); this.serverAgent = new ServerAgent(workerId, output); - iterationStages.context.setWorldInfo(workerId, numTasks); + iterationStages.session.setWorldInfo(workerId, numTasks); } @Override @@ -120,8 +120,8 @@ public void onEpochWatermarkIncremented( if (epochWatermark == 0) { modelDim = Bits.getLong(feedback, 0); serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers)); - serverAgent.zeros(); - iterationStages.context.setInputData(new ResettableTrainDataIterator<>(trainDataState)); + serverAgent.initializeModelAsZeros(); + iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState)); nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); } } @@ -141,12 +141,12 @@ public void processElement1(StreamRecord
streamRecord) throws Exception { public void processElement2(StreamRecord streamRecord) throws Exception { feedback = streamRecord.getValue(); if (modelDim > 0) { - // Decodes the pulled method and put it in training context. + // Decodes the pulled method and put it in ml session. PullStage pullStage = (PullStage) iterationStages.stageList.get(nextStageToExecute); - ValuesPulledM valuesPulledMessage = ValuesPulledM.fromBytes(streamRecord.getValue()); + PulledValueM valuesPulledMessage = PulledValueM.fromBytes(streamRecord.getValue()); Preconditions.checkState( getRuntimeContext().getIndexOfThisSubtask() == valuesPulledMessage.workerId); - pullStage.valuesConsumer.accept(valuesPulledMessage.valuesPulled); + pullStage.valuesConsumer.accept(valuesPulledMessage.values); nextStageToExecute++; nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); @@ -197,7 +197,7 @@ public void initializeState(StateInitializationContext context) throws Exception serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers)); } - iterationStages.context.initializeState(context); + iterationStages.session.initializeState(context); } @Override @@ -216,7 +216,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception { iterationIdState.add(iterationId); trainDataState.snapshotState(context); - iterationStages.context.snapshotState(context); + iterationStages.session.snapshotState(context); } /** @@ -229,12 +229,12 @@ public void snapshotState(StateSnapshotContext context) throws Exception { */ @SuppressWarnings("unchecked") private int processTrainingStage( - int nextStageToExecute, IterationStageList iterationStages) throws Exception { + int nextStageToExecute, IterationStageList iterationStages) throws Exception { while (true) { if (nextStageToExecute >= iterationStages.stageList.size()) { iterationId++; - iterationStages.context.setIterationId(iterationId); - if (iterationStages.shouldTerminate.apply(iterationStages.context)) { + iterationStages.session.setIterationId(iterationId); + if (iterationStages.shouldTerminate.apply(iterationStages.session)) { return -1; } nextStageToExecute -= iterationStages.stageList.size(); @@ -252,7 +252,7 @@ private int processTrainingStage( serverAgent.pushKVs(pushStage.keysSupplier.get(), pushStage.valuesSupplier.get()); nextStageToExecute++; } else if (stage instanceof ProcessStage) { - ((ProcessStage) stage).process(iterationStages.context); + ((ProcessStage) stage).process(iterationStages.session); nextStageToExecute++; } else { throw new IllegalStateException( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java index 859eddcbb..3c2bee6d7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java @@ -23,23 +23,21 @@ import static org.apache.flink.ml.common.ps.message.MessageType.INITIALIZE_MODEL_AS_ZERO; -/** - * Message sent by worker to server that initializes the model as a dense array with defined range. - */ -public class ZerosToPushM implements Message { +/** Message sent by worker to server that initializes the model as zeros with defined range. 
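The processTrainingStage loop above is easiest to read as a small state machine: ProcessStage and PushStage complete synchronously, a PullStage suspends execution until the servers reply, and finishing the list bumps the iteration id and checks termination. A stripped-down sketch of that control flow, with hypothetical stage types and the termination check elided:

    import java.util.List;

    final class StageLoopSketch {
        interface Stage {}
        static final class Process implements Stage {}
        static final class Push implements Stage {}
        static final class Pull implements Stage {}

        /**
         * Runs stages starting at next and returns the index of the Pull stage it
         * stopped at, so the caller can resume there once the pulled values arrive.
         * Assumes the stage list contains at least one Pull stage.
         */
        static int runUntilPull(List<Stage> stages, int next) {
            while (true) {
                if (next >= stages.size()) {
                    next -= stages.size(); // one full pass over the list == one iteration
                }
                Stage stage = stages.get(next);
                if (stage instanceof Pull) {
                    return next; // the pull request goes out here; resume on the reply
                }
                next++; // Process and Push run inline
            }
        }
    }

This also explains why processElement2 increments nextStageToExecute itself: the operator parks on the PullStage and only moves past it once the feedback record delivers the pulled values.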
*/ +public class InitializeModelAsZeroM implements Message { public final int workerId; public final int serverId; public final long startIndex; public final long endIndex; - public ZerosToPushM(int workerId, int serverId, long startIndex, long endIndex) { + public InitializeModelAsZeroM(int workerId, int serverId, long startIndex, long endIndex) { this.workerId = workerId; this.serverId = serverId; this.startIndex = startIndex; this.endIndex = endIndex; } - public static ZerosToPushM fromBytes(byte[] bytes) { + public static InitializeModelAsZeroM fromBytes(byte[] bytes) { int offset = 0; char type = Bits.getChar(bytes, offset); offset += Character.BYTES; @@ -52,7 +50,7 @@ public static ZerosToPushM fromBytes(byte[] bytes) { long startIndex = Bits.getLong(bytes, offset); offset += Long.BYTES; long endIndex = Bits.getLong(bytes, offset); - return new ZerosToPushM(workerId, serverId, startIndex, endIndex); + return new InitializeModelAsZeroM(workerId, serverId, startIndex, endIndex); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java index b6e9a6afd..79da886d6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java @@ -20,10 +20,20 @@ /** Message Type between workers and servers. */ public enum MessageType { - ZEROS_TO_PUSH((char) 0), - INDICES_TO_PULL((char) 1), - VALUES_PULLED((char) 2), - KVS_TO_PUSH((char) 3); + /** Message sent from workers to servers, which initializes the model on servers as zero. */ + INITIALIZE_MODEL_AS_ZERO((char) 0), + /** Message sent from workers to servers, which specifies the indices of model to pull. */ + PULL_INDEX((char) 1), + /** + * Message sent from server to workers, which specifies the values of the model pulled from + * servers. + */ + PULLED_VALUE((char) 2), + /** + * Message sent from workers to servers, which specifies the indices and values of the model to + * push to servers. + */ + PUSH_KV((char) 3); public final char type; @@ -34,13 +44,13 @@ public enum MessageType { public static MessageType valueOf(char value) { switch (value) { case (char) 0: - return MessageType.ZEROS_TO_PUSH; + return MessageType.INITIALIZE_MODEL_AS_ZERO; case (char) 1: - return MessageType.INDICES_TO_PULL; + return MessageType.PULL_INDEX; case ((char) 2): - return MessageType.VALUES_PULLED; + return MessageType.PULLED_VALUE; case ((char) 3): - return MessageType.KVS_TO_PUSH; + return MessageType.PUSH_KV; default: throw new UnsupportedOperationException(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java index 01c4fa4aa..d2a628870 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java @@ -25,33 +25,33 @@ public class MessageUtils { /** Retrieves the message type from the byte array. */ - public static MessageType getMessageType(byte[] bytesData) { - char type = Bits.getChar(bytesData, 0); + public static MessageType getMessageType(byte[] bytes) { + char type = Bits.getChar(bytes, 0); return MessageType.valueOf(type); } - /** Reads a long array from the byte array starting from the given offset. 
*/ - public static long[] readLongArray(byte[] bytesData, int offset) { - int size = Bits.getInt(bytesData, offset); + /** Gets a long array from the byte array starting from the given offset. */ + public static long[] getLongArray(byte[] bytes, int offset) { + int size = Bits.getInt(bytes, offset); offset += Integer.BYTES; long[] result = new long[size]; for (int i = 0; i < size; i++) { - result[i] = Bits.getLong(bytesData, offset); + result[i] = Bits.getLong(bytes, offset); offset += Long.BYTES; } return result; } /** - * Writes a long array to the byte array starting from the given offset. + * Puts a long array to the byte array starting from the given offset. * * @return the next position to write on. */ - public static int writeLongArray(long[] array, byte[] bytesData, int offset) { - Bits.putInt(bytesData, offset, array.length); + public static int putLongArray(long[] array, byte[] bytes, int offset) { + Bits.putInt(bytes, offset, array.length); offset += Integer.BYTES; for (int i = 0; i < array.length; i++) { - Bits.putLong(bytesData, offset, array[i]); + Bits.putLong(bytes, offset, array[i]); offset += Long.BYTES; } return offset; @@ -62,28 +62,28 @@ public static int getLongArraySizeInBytes(long[] array) { return Integer.BYTES + array.length * Long.BYTES; } - /** Reads a double array from the byte array starting from the given offset. */ - public static double[] readDoubleArray(byte[] bytesData, int offset) { - int size = Bits.getInt(bytesData, offset); + /** Gets a double array from the byte array starting from the given offset. */ + public static double[] getDoubleArray(byte[] bytes, int offset) { + int size = Bits.getInt(bytes, offset); offset += Integer.BYTES; double[] result = new double[size]; for (int i = 0; i < size; i++) { - result[i] = Bits.getDouble(bytesData, offset); + result[i] = Bits.getDouble(bytes, offset); offset += Long.BYTES; } return result; } /** - * Writes a double array to the byte array starting from the given offset. + * Puts a double array to the byte array starting from the given offset. * * @return the next position to write on. */ - public static int writeDoubleArray(double[] array, byte[] bytesData, int offset) { - Bits.putInt(bytesData, offset, array.length); + public static int putDoubleArray(double[] array, byte[] bytes, int offset) { + Bits.putInt(bytes, offset, array.length); offset += Integer.BYTES; for (int i = 0; i < array.length; i++) { - Bits.putDouble(bytesData, offset, array[i]); + Bits.putDouble(bytes, offset, array[i]); offset += Double.BYTES; } return offset; @@ -94,23 +94,23 @@ public static int getDoubleArraySizeInBytes(double[] array) { return Integer.BYTES + array.length * Long.BYTES; } - /** Reads a long-double array from the byte array starting from the given offset. */ - public static Tuple2 readLongDoubleArray(byte[] bytesData, int offset) { - long[] indices = readLongArray(bytesData, offset); + /** Gets a long-double array from the byte array starting from the given offset. */ + public static Tuple2 getLongDoubleArray(byte[] bytes, int offset) { + long[] indices = getLongArray(bytes, offset); offset += getLongArraySizeInBytes(indices); - double[] values = readDoubleArray(bytesData, offset); + double[] values = getDoubleArray(bytes, offset); return Tuple2.of(indices, values); } /** - * Writes a long-double to the byte array starting from the given offset. + * Puts a long-double array to the byte array starting from the given offset. * * @return the next position to write on. 
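The get/put pairs above all write the same shape: a 4-byte int length prefix followed by the raw 8-byte elements, so a long-double pair occupies (4 + 8n) + (4 + 8m) bytes, matching getLongDoubleArraySizeInBytes. A standalone round trip of that layout in java.nio (big-endian, as the Bits helpers are assumed to be; this sketch does not use the patch's API):

    import java.nio.ByteBuffer;
    import java.util.Arrays;

    public final class LongDoubleArraySketch {
        public static void main(String[] args) {
            long[] keys = {2L, 5L};
            double[] vals = {0.5, -1.5};

            // int length prefix + 8-byte elements, once per array.
            int size = (Integer.BYTES + keys.length * Long.BYTES)
                    + (Integer.BYTES + vals.length * Double.BYTES);
            ByteBuffer buf = ByteBuffer.allocate(size); // big-endian by default

            buf.putInt(keys.length);
            for (long k : keys) {
                buf.putLong(k);
            }
            buf.putInt(vals.length);
            for (double v : vals) {
                buf.putDouble(v);
            }

            buf.flip();
            long[] keys2 = new long[buf.getInt()];
            for (int i = 0; i < keys2.length; i++) {
                keys2[i] = buf.getLong();
            }
            double[] vals2 = new double[buf.getInt()];
            for (int i = 0; i < vals2.length; i++) {
                vals2[i] = buf.getDouble();
            }
            System.out.println(Arrays.toString(keys2) + " / " + Arrays.toString(vals2));
        }
    }

Since Long.BYTES == Double.BYTES == 8, the size arithmetic is identical for both element types, which is why one offset-advancing pattern can be reused for long and double arrays.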
*/ - public static int writeLongDoubleArray( - Tuple2 longDoubleArray, byte[] bytesData, int offset) { - offset = writeLongArray(longDoubleArray.f0, bytesData, offset); - offset = writeDoubleArray(longDoubleArray.f1, bytesData, offset); + public static int putLongDoubleArray( + Tuple2 longDoubleArray, byte[] bytes, int offset) { + offset = putLongArray(longDoubleArray.f0, bytes, offset); + offset = putDoubleArray(longDoubleArray.f1, bytes, offset); return offset; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java index a87b84583..bf6d4caa6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java @@ -24,18 +24,18 @@ import static org.apache.flink.ml.common.ps.message.MessageType.PULL_INDEX; /** The indices one worker needs to pull from servers. */ -public class IndicesToPullM implements Message { +public class PullIndexM implements Message { public final int serverId; public final int workerId; - public final long[] indicesToPull; + public final long[] indices; - public IndicesToPullM(int serverId, int workerId, long[] indicesToPull) { + public PullIndexM(int serverId, int workerId, long[] indices) { this.serverId = serverId; this.workerId = workerId; - this.indicesToPull = indicesToPull; + this.indices = indices; } - public static IndicesToPullM fromBytes(byte[] bytes) { + public static PullIndexM fromBytes(byte[] bytes) { int offset = 0; char type = Bits.getChar(bytes, offset); offset += Character.BYTES; @@ -45,16 +45,14 @@ public static IndicesToPullM fromBytes(byte[] bytes) { offset += Integer.BYTES; int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - long[] toPullIndices = MessageUtils.readLongArray(bytes, offset); - return new IndicesToPullM(psId, workerId, toPullIndices); + long[] indices = MessageUtils.getLongArray(bytes, offset); + return new PullIndexM(psId, workerId, indices); } @Override public byte[] toBytes() { int numBytes = - Character.BYTES - + Integer.BYTES * 2 - + MessageUtils.getLongArraySizeInBytes(indicesToPull); + Character.BYTES + Integer.BYTES * 2 + MessageUtils.getLongArraySizeInBytes(indices); byte[] buffer = new byte[numBytes]; int offset = 0; @@ -64,7 +62,7 @@ public byte[] toBytes() { offset += Integer.BYTES; Bits.putInt(buffer, offset, this.workerId); offset += Integer.BYTES; - MessageUtils.writeLongArray(this.indicesToPull, buffer, offset); + MessageUtils.putLongArray(this.indices, buffer, offset); return buffer; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java index ea44076a1..96ab8b072 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java @@ -24,18 +24,18 @@ import static org.apache.flink.ml.common.ps.message.MessageType.PULLED_VALUE; /** The values pulled from servers. 
*/ -public class ValuesPulledM implements Message { +public class PulledValueM implements Message { public final int serverId; public final int workerId; - public final double[] valuesPulled; + public final double[] values; - public ValuesPulledM(int serverId, int workerId, double[] valuesPulled) { + public PulledValueM(int serverId, int workerId, double[] values) { this.serverId = serverId; this.workerId = workerId; - this.valuesPulled = valuesPulled; + this.values = values; } - public static ValuesPulledM fromBytes(byte[] bytes) { + public static PulledValueM fromBytes(byte[] bytes) { int offset = 0; char type = Bits.getChar(bytes, offset); offset += Character.BYTES; @@ -45,8 +45,8 @@ public static ValuesPulledM fromBytes(byte[] bytes) { offset += Integer.BYTES; int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - double[] pulledValues = MessageUtils.readDoubleArray(bytes, offset); - return new ValuesPulledM(psId, workerId, pulledValues); + double[] values = MessageUtils.getDoubleArray(bytes, offset); + return new PulledValueM(psId, workerId, values); } @Override @@ -55,7 +55,7 @@ public byte[] toBytes() { Character.BYTES + Integer.BYTES + Integer.BYTES - + MessageUtils.getDoubleArraySizeInBytes(valuesPulled); + + MessageUtils.getDoubleArraySizeInBytes(values); byte[] buffer = new byte[numBytes]; int offset = 0; Bits.putChar(buffer, offset, PULLED_VALUE.type); @@ -65,7 +65,7 @@ public byte[] toBytes() { offset += Integer.BYTES; Bits.putInt(buffer, offset, this.workerId); offset += Integer.BYTES; - MessageUtils.writeDoubleArray(valuesPulled, buffer, offset); + MessageUtils.putDoubleArray(values, buffer, offset); return buffer; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java index 11eb25a6a..b3162cbe9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java @@ -25,18 +25,18 @@ import static org.apache.flink.ml.common.ps.message.MessageType.PUSH_KV; /** The sparse key-values to push from workers to servers. 
*/ -public class KVsToPushM implements Message { +public class PushKvM implements Message { public final int serverId; public final int workerId; public final Tuple2 kvs; - public KVsToPushM(int workerId, int serverId, Tuple2 kvs) { + public PushKvM(int workerId, int serverId, Tuple2 kvs) { this.workerId = workerId; this.serverId = serverId; this.kvs = kvs; } - public static KVsToPushM fromBytes(byte[] bytes) { + public static PushKvM fromBytes(byte[] bytes) { int offset = 0; char type = Bits.getChar(bytes, offset); offset += Character.BYTES; @@ -46,8 +46,8 @@ public static KVsToPushM fromBytes(byte[] bytes) { offset += Integer.BYTES; int psId = Bits.getInt(bytes, offset); offset += Integer.BYTES; - Tuple2 grad = MessageUtils.readLongDoubleArray(bytes, offset); - return new KVsToPushM(workerId, psId, grad); + Tuple2 grad = MessageUtils.getLongDoubleArray(bytes, offset); + return new PushKvM(workerId, psId, grad); } @Override @@ -67,7 +67,7 @@ public byte[] toBytes() { offset += Integer.BYTES; Bits.putInt(buffer, offset, this.serverId); offset += Integer.BYTES; - MessageUtils.writeLongDoubleArray(kvs, buffer, offset); + MessageUtils.putLongDoubleArray(kvs, buffer, offset); return buffer; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java new file mode 100644 index 000000000..ebeca86d7 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.Vectors; + +import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; + +import java.io.IOException; +import java.util.List; + +/** An iteration stage that uses the pulled model values and batch data to compute the gradients. 
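The stage below operates purely on the sparse (indices, values) representation: the pulled model slice is loaded into a hash map, each data point adds multiplier * featureValue into the cumulative gradient of its non-zero indices, and the accumulated values are finally scaled by 1 / batchSize. A minimal sketch of that accumulation for a single data point, using java.util.HashMap in place of fastutil and a least-squares multiplier standing in for the LossFunc; all names in the sketch are illustrative:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class SparseGradientSketch {
    public static void main(String[] args) {
        // Pulled model slice: sorted indices and their coefficient values.
        long[] modelIndices = {1L, 3L, 8L};
        double[] modelValues = {0.5, -0.2, 0.1};
        Map<Long, Double> model = new HashMap<>();
        for (int i = 0; i < modelIndices.length; i++) {
            model.put(modelIndices[i], modelValues[i]);
        }

        // One sparse data point: parallel (indices, values) arrays, label and weight.
        long[] featureIndices = {1L, 8L};
        double[] featureValues = {2.0, 1.0};
        double label = 1.0;
        double weight = 1.0;

        // dot(features, model) over the non-zero entries only.
        double dot = 0;
        for (int i = 0; i < featureIndices.length; i++) {
            dot += featureValues[i] * model.get(featureIndices[i]);
        }

        // Least-squares multiplier (dot - label); a LossFunc would supply this.
        double multiplier = (dot - label) * weight;

        // Accumulate gradient contributions on the non-zero indices.
        Map<Long, Double> cumGradients = new HashMap<>();
        for (int i = 0; i < featureIndices.length; i++) {
            cumGradients.merge(featureIndices[i], featureValues[i] * multiplier, Double::sum);
        }

        // Dense gradient aligned with modelIndices; a real batch would divide by its size.
        double[] gradient = new double[modelIndices.length];
        for (int i = 0; i < modelIndices.length; i++) {
            gradient[i] = cumGradients.getOrDefault(modelIndices[i], 0.0);
        }
        System.out.println(Arrays.toString(gradient));
    }
}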
*/ +public class ComputeGradients + extends ProcessStage> { + private final LossFunc lossFunc; + + public ComputeGradients(LossFunc lossFunc) { + this.lossFunc = lossFunc; + } + + @Override + public void process(MiniBatchMLSession session) + throws IOException { + long[] indices = ComputeIndices.getSortedIndices(session.batchData); + double[] modelValues = session.pulledValues; + double[] gradients = computeGradient(session.batchData, Tuple2.of(indices, modelValues)); + + session.pushIndices = indices; + session.pushValues = gradients; + } + + private double[] computeGradient( + List batchData, Tuple2 modelData) { + long[] modelIndices = modelData.f0; + double[] modelValues = modelData.f1; + Long2DoubleOpenHashMap modelInMap = new Long2DoubleOpenHashMap(modelIndices.length); + for (int i = 0; i < modelIndices.length; i++) { + modelInMap.put(modelIndices[i], modelValues[i]); + } + Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(modelIndices.length); + + for (LabeledLargePointWithWeight dataPoint : batchData) { + double dot = dot(dataPoint.features, modelInMap); + double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight; + + long[] featureIndices = dataPoint.features.f0; + double[] featureValues = dataPoint.features.f1; + double z; + for (int i = 0; i < featureIndices.length; i++) { + long currentIndex = featureIndices[i]; + z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.); + cumGradients.put(currentIndex, z); + } + } + double[] cumGradientValues = new double[modelIndices.length]; + for (int i = 0; i < modelIndices.length; i++) { + cumGradientValues[i] = cumGradients.get(modelIndices[i]); + } + BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues)); + return cumGradientValues; + } + + private static double dot( + Tuple2 features, Long2DoubleOpenHashMap coefficient) { + double dot = 0; + for (int i = 0; i < features.f0.length; i++) { + dot += features.f1[i] * coefficient.get(features.f0[i]); + } + return dot; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java new file mode 100644 index 000000000..5cf868f21 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; + +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * An iteration stage that samples a batch of training data and computes the indices needed to + * compute gradients. + */ +public class ComputeIndices extends ProcessStage> { + + @Override + public void process(MiniBatchMLSession context) throws Exception { + context.readInNextBatchData(); + context.pullIndices = getSortedIndices(context.batchData); + } + + public static long[] getSortedIndices(List dataPoints) { + LongOpenHashSet indices = new LongOpenHashSet(); + for (LabeledLargePointWithWeight dataPoint : dataPoints) { + long[] notZeros = dataPoint.features.f0; + for (long index : notZeros) { + indices.add(index); + } + } + + long[] sortedIndices = new long[indices.size()]; + Iterator iterator = indices.iterator(); + int i = 0; + while (iterator.hasNext()) { + sortedIndices[i++] = iterator.next(); + } + Arrays.sort(sortedIndices); + return sortedIndices; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java index 08c2f8417..1cffbcaa4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java @@ -29,14 +29,14 @@ * A list of iteration stages to express the logic of an iterative machine learning training * process. */ -public class IterationStageList implements Serializable { - public final T context; +public class IterationStageList implements Serializable { + public final T session; public Function shouldTerminate; public List stageList; - public IterationStageList(T context) { + public IterationStageList(T session) { this.stageList = new ArrayList<>(); - this.context = context; + this.session = session; } /** Sets the criteria of termination. */ diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java index 9c4f7ee28..f19f035af 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java @@ -25,14 +25,14 @@ import java.io.Serializable; /** - * Stores the context information that is alive during the training process. Note that the context + * Stores the session information that is alive during the training process. Note that the session * information will be updated by each {@link IterationStage}. * - *
<p>
Note that subclasses should take care of the snapshot of object stored in {@link - * TrainingContext} if the object satisfies that: the write-process is followed by an {@link - * PullStage}, which is later again read by other stages. + *
<p>
Note that subclasses should take care of snapshotting any object stored in the {@link MLSession} if + that object is written before a {@link PullStage} and read again by stages after it. */ -public interface TrainingContext extends Serializable { +public interface MLSession extends Serializable { /** Sets the current iteration ID. */ default void setIterationId(int iterationId) {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java new file mode 100644 index 000000000..13cc70e08 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.runtime.util.ResettableIterator; + +/** + * The default implementation of {@link MLSession}. + * + * @param
<DT>
Data type of input data. + */ +public class MLSessionImpl<DT>
implements MLSession { + /** Current iteration id. */ + public int iterationId; + /** Index of this worker. */ + public int workerId; + /** Number of workers in total for this distributed ML job. */ + public int numWorkers; + /** The input data. */ + public ResettableIterator
<DT> inputData; + + @Override + public void setIterationId(int iterationId) { + this.iterationId = iterationId; + } + + @Override + public void setWorldInfo(int workerId, int numWorkers) { + this.workerId = workerId; + this.numWorkers = numWorkers; + } + + @Override + public void setInputData(ResettableIterator<?> inputData) { + this.inputData = (ResettableIterator<DT>
) inputData; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java new file mode 100644 index 000000000..afeaa4cb5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** The ML session for machine learning algorithms that adopts mini-batch training. */ +public class MiniBatchMLSession
<DT> extends MLSessionImpl<DT>
{ + + /** The placeholder for indices to pull for each iteration. */ + public long[] pullIndices; + /** The placeholder for the pulled values for each iteration. */ + public double[] pulledValues; + /** The placeholder for indices to push for each iteration. */ + public long[] pushIndices; + /** The placeholder for values to push for each iteration. */ + public double[] pushValues; + + /** The batch of training data for computing gradients. */ + public List
<DT> batchData; + + private ListState<DT>
batchDataState; + /** Global batch size. */ + private final int globalBatchSize; + /** The local batch size. */ + private int localBatchSize; + /** Type information of the input data. */ + private final TypeInformation<DT>
typeInformation; + + public MiniBatchMLSession(int globalBatchSize, TypeInformation<DT>
typeInformation) { + this.globalBatchSize = globalBatchSize; + this.typeInformation = typeInformation; + } + + @Override + public void setWorldInfo(int workerId, int numWorkers) { + super.setWorldInfo(workerId, numWorkers); + this.localBatchSize = globalBatchSize / numWorkers; + if (globalBatchSize % numWorkers > workerId) { + localBatchSize++; + } + this.batchData = new ArrayList<>(localBatchSize); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + batchDataState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("batchDataState", typeInformation)); + + Iterator<DT>
batchDataIterator = batchDataState.get().iterator(); + if (batchDataIterator.hasNext()) { + batchData = IteratorUtils.toList(batchDataIterator); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + batchDataState.clear(); + if (batchData.size() > 0) { + batchDataState.addAll(batchData); + } + } + + /** Reads in next batch of training data. */ + public void readInNextBatchData() throws IOException { + batchData.clear(); + int i = 0; + while (i < localBatchSize && inputData.hasNext()) { + batchData.add(inputData.next()); + i++; + } + if (!inputData.hasNext()) { + inputData.reset(); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java index 2469b1eeb..8a2daa751 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProcessStage.java @@ -20,14 +20,14 @@ /** * A local computation stage of the training process. The input and output of {@link ProcessStage} - * can be accessed via {@link TrainingContext}. + * can be accessed via {@link MLSession}. * * @param Type of the training data. */ -public abstract class ProcessStage implements IterationStage { +public abstract class ProcessStage implements IterationStage { /** * Does a local computation logic using the information from context. Example stages could be * computing gradients. */ - public abstract void process(T context) throws Exception; + public abstract void process(T session) throws Exception; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java index 5a94ae0d1..baf810ea4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -36,7 +36,7 @@ import org.apache.flink.ml.common.ps.MirrorWorkerOperator; import org.apache.flink.ml.common.ps.ServerOperator; import org.apache.flink.ml.common.ps.WorkerOperator; -import org.apache.flink.ml.common.updater.ModelUpdater; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; import org.apache.flink.ml.util.Bits; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; @@ -60,10 +60,9 @@ public static DataStream> train( DataStream modelDim, DataStream trainData, ModelUpdater modelUpdater, - IterationStageList iterationStages, + IterationStageList iterationStages, int numServers) { // TODO: Support incremental training for multiple models. - // TODO: Support user defined model partitioner. DataStream variableStream = modelDim.broadcast() @@ -89,12 +88,12 @@ public static DataStream> train( /** The iteration implementation for training process. 
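Taken together, a worker's iteration is declared as an IterationStageList over a shared session: sample a batch and collect its indices, pull those indices, compute gradients, push them back. A sketch of the wiring, assuming IterationStageList exposes an addStage(...) method and a setTerminationCriteria(...) setter backing the shouldTerminate field (neither is shown in this hunk), and that PushStage takes a key supplier and a value supplier as WorkerOperator consumes them:

package org.apache.flink.ml.common.ps.training;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;

public class StageWiringSketch {
    static IterationStageList<MiniBatchMLSession<LabeledLargePointWithWeight>> buildStages(
            int globalBatchSize,
            TypeInformation<LabeledLargePointWithWeight> typeInfo,
            int maxIter) {
        MiniBatchMLSession<LabeledLargePointWithWeight> session =
                new MiniBatchMLSession<>(globalBatchSize, typeInfo);
        IterationStageList<MiniBatchMLSession<LabeledLargePointWithWeight>> stages =
                new IterationStageList<>(session);
        // Sample the next mini-batch and compute the sorted model indices it touches.
        stages.addStage(new ComputeIndices());
        // Pull the current values of those indices from the servers.
        stages.addStage(
                new PullStage(() -> session.pullIndices, values -> session.pulledValues = values));
        // Compute gradients locally from the batch and the pulled values.
        stages.addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE));
        // Push the gradients back to the servers.
        stages.addStage(new PushStage(() -> session.pushIndices, () -> session.pushValues));
        stages.setTerminationCriteria(s -> s.iterationId >= maxIter);
        return stages;
    }
}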
*/ private static class TrainIterationBody implements IterationBody { private final ModelUpdater modelUpdater; - private final IterationStageList iterationStages; + private final IterationStageList iterationStages; private final int numServers; public TrainIterationBody( ModelUpdater modelUpdater, - IterationStageList iterationStages, + IterationStageList iterationStages, int numServers) { this.iterationStages = iterationStages; this.modelUpdater = modelUpdater; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java index f38440274..c3fa92b2d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.updater; +package org.apache.flink.ml.common.ps.updater; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; @@ -31,10 +31,11 @@ import java.util.List; /** - * FTRL (Follow-the-regularized-leader) is an optimization algorithm which is widely deployed by online learning. + * FTRL (Follow-the-regularized-leader) is an optimization algorithm which is widely deployed by + * online learning. * - *
<p>
See H. Brendan McMahan et al., Ad click - * * prediction: a view from the trenches. + *
<p>
See H. Brendan McMahan et al., Ad click * + * prediction: a view from the trenches. */ public class FTRL implements ModelUpdater { private final double alpha; @@ -108,10 +109,10 @@ public double[] handlePull(long[] keys) { } @Override - public Iterator> getModelPieces() { - List> modelPieces = new ArrayList<>(); - modelPieces.add(Tuple3.of(startIndex, endIndex, weight)); - return modelPieces.iterator(); + public Iterator> getModelSegments() { + List> modelSegments = new ArrayList<>(); + modelSegments.add(Tuple3.of(startIndex, endIndex, weight)); + return modelSegments.iterator(); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java index 3aafae47b..feb04e185 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.ml.common.updater; +package org.apache.flink.ml.common.ps.updater; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.runtime.state.StateInitializationContext; @@ -28,7 +28,8 @@ /** * A model updater that could be used to handle push/pull request from workers. * - *
<p>
Note that model updater should also ensure that model data is robust to failures. + *
<p>
Note that model updater should also ensure that model data is robust to failures, by writing + * model data to snapshots. */ public interface ModelUpdater extends Serializable { @@ -41,8 +42,12 @@ public interface ModelUpdater extends Serializable { /** Applies the pull and return the retrieved model data. */ double[] handlePull(long[] keys); - /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */ - Iterator> getModelPieces(); + /** + * Returns model segments with the format of (startFeatureIdx, endFeatureIdx, modelValues). The + * model segments are continuously updated/retrieved by push/pull(i.e., `handlePush` and + * `handlePull`). + */ + Iterator> getModelSegments(); /** Recovers the model data from state. */ void initializeState(StateInitializationContext context) throws Exception; diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java index c0abea005..0ac5ae74a 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -241,10 +241,10 @@ public void testGetModelData() throws Exception { modelData.sort(Comparator.comparingLong(o -> o.startIndex)); double[] collectedCoefficient = new double[4]; - for (LogisticRegressionModelData modelPiece : modelData) { - int startIndex = (int) modelPiece.startIndex; - double[] pieceCoeff = modelPiece.coefficient.values; - System.arraycopy(pieceCoeff, 0, collectedCoefficient, startIndex, pieceCoeff.length); + for (LogisticRegressionModelData modelSegment : modelData) { + int startIndex = (int) modelSegment.startIndex; + double[] segment = modelSegment.coefficient.values; + System.arraycopy(segment, 0, collectedCoefficient, startIndex, segment.length); } assertArrayEquals(expectedCoefficient, collectedCoefficient, 0.1); } diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java index 32934ea30..ebaec117b 100644 --- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java @@ -23,10 +23,12 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.util.Preconditions; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.List; /** Model data of {@link LogisticRegressionModelServable}. */ public class LogisticRegressionModelData { @@ -88,4 +90,31 @@ static LogisticRegressionModelData decode(InputStream inputStream) throws IOExce return new LogisticRegressionModelData(coefficient, startIndex, endIndex, modelVersion); } + + @VisibleForTesting + public static LogisticRegressionModelData mergeSegments( + List segments) { + long dim = 0; + for (LogisticRegressionModelData segment : segments) { + dim = Math.max(dim, segment.endIndex); + } + // TODO: Add distributed inference for very large models. 
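+ // The merged coefficients are stored in one DenseVector backed by a Java double[],
+ // which is indexed by int; hence the dimension check below.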
+ Preconditions.checkState( + dim < Integer.MAX_VALUE, + "The dimension of logistic regression model is larger than INT.MAX. Please consider using distributed inference."); + int intDim = (int) dim; + DenseVector mergedCoefficient = new DenseVector(intDim); + for (LogisticRegressionModelData segment : segments) { + int startIndex = (int) segment.startIndex; + int endIndex = (int) segment.endIndex; + System.arraycopy( + segment.coefficient.values, + 0, + mergedCoefficient.values, + startIndex, + endIndex - startIndex); + } + return new LogisticRegressionModelData( + mergedCoefficient, 0, mergedCoefficient.size(), segments.get(0).modelVersion); + } } diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java index c2e14029d..6662b6ccf 100644 --- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java @@ -18,7 +18,6 @@ package org.apache.flink.ml.classification.logisticregression; -import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseVector; @@ -82,49 +81,22 @@ public DataFrame transform(DataFrame input) { public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs) throws IOException { Preconditions.checkArgument(modelDataInputs.length == 1); - List modelPieces = new ArrayList<>(); + List modelSegments = new ArrayList<>(); while (true) { try { - LogisticRegressionModelData piece = + LogisticRegressionModelData segment = LogisticRegressionModelData.decode(modelDataInputs[0]); - modelPieces.add(piece); + modelSegments.add(segment); } catch (IOException e) { // Reached the end of model stream. break; } } - modelData = mergePieces(modelPieces); + modelData = LogisticRegressionModelData.mergeSegments(modelSegments); return this; } - @VisibleForTesting - public static LogisticRegressionModelData mergePieces( - List pieces) { - long dim = 0; - for (LogisticRegressionModelData piece : pieces) { - dim = Math.max(dim, piece.endIndex); - } - // TODO: Add distributed inference for very large models. - Preconditions.checkState( - dim < Integer.MAX_VALUE, - "The dimension of logistic regression model is larger than INT.MAX. 
Please consider using distributed inference."); - int intDim = (int) dim; - DenseVector mergedCoefficient = new DenseVector(intDim); - for (LogisticRegressionModelData piece : pieces) { - int startIndex = (int) piece.startIndex; - int endIndex = (int) piece.endIndex; - System.arraycopy( - piece.coefficient.values, - 0, - mergedCoefficient.values, - startIndex, - endIndex - startIndex); - } - return new LogisticRegressionModelData( - mergedCoefficient, 0, mergedCoefficient.size(), pieces.get(0).modelVersion); - } - public static LogisticRegressionModelServable load(String path) throws IOException { LogisticRegressionModelServable servable = ServableReadWriteUtils.loadServableParam( From d5dd3a9b5061fa44997581acc0d64cc53ac7c1ac Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Mon, 5 Jun 2023 10:53:03 +0800 Subject: [PATCH 08/18] add allreduce stage impl --- .../flink/ml/common/ps/RangePartitioner.java | 4 +- .../flink/ml/common/ps/ServerAgent.java | 56 +++++++--- .../flink/ml/common/ps/ServerOperator.java | 102 +++++++++++++++--- .../flink/ml/common/ps/WorkerOperator.java | 39 +++++-- .../ml/common/ps/message/AllReduceM.java | 72 +++++++++++++ .../ml/common/ps/message/MessageType.java | 6 +- .../ml/common/ps/message/PulledValueM.java | 2 +- .../ml/common/ps/training/AllReduceStage.java | 55 ++++++++++ .../ml/common/ps/training/PullStage.java | 7 +- .../ps/training/SerializableBiFunction.java | 6 ++ .../ml/common/ps/training/TrainingUtils.java | 3 +- 11 files changed, 309 insertions(+), 43 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java index e8de5e920..2bfc255e6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java @@ -84,8 +84,8 @@ private static class RequestsIterator implements Iterator>> output; - public ServerAgent(int workerId, Output>> output) { + ServerAgent(int workerId, Output>> output) { this.workerId = workerId; this.output = output; } - public void setPartitioner(RangePartitioner partitioner) { + void setPartitioner(RangePartitioner partitioner) { this.partitioner = partitioner; } - /** Pushes a key-value arrays to servers. */ - public void pushKVs(long[] indices, double[] values) { - Iterator> requests = - partitioner.splitRequest(indices, values); - while (requests.hasNext()) { - Tuple3 request = requests.next(); - PushKvM pushKvM = new PushKvM(workerId, request.f0, Tuple2.of(request.f1, request.f2)); - output.collect(new StreamRecord<>(Tuple2.of(request.f0, pushKvM.toBytes()))); - } - } - /** Sends a request to servers to initialize the values stored as zeros. */ - public void initializeModelAsZeros() { + void initializeModelAsZeros() { for (int serverId = 0; serverId < partitioner.numServers; serverId++) { long start = partitioner.ranges[serverId]; long end = partitioner.ranges[serverId + 1]; @@ -69,8 +60,19 @@ public void initializeModelAsZeros() { } } + /** Pushes a key-value arrays to servers. 
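The agent relies on RangePartitioner's ranges array: the key space [0, dim) is cut into numServers contiguous ranges, with ranges[s] through ranges[s + 1] owned by server s, and a request's sorted keys are routed to the owning server. A small self-contained sketch of that routing; the evenly sized ranges and all names here are illustrative assumptions:

import java.util.Arrays;

public class RangeRoutingSketch {
    public static void main(String[] args) {
        long dim = 10;
        int numServers = 3;
        // ranges[s]..ranges[s + 1] is the key range owned by server s.
        long[] ranges = new long[numServers + 1];
        for (int s = 0; s <= numServers; s++) {
            ranges[s] = dim * s / numServers;
        }
        System.out.println(Arrays.toString(ranges)); // [0, 3, 6, 10]

        // Route the sorted keys of one request to their owning servers.
        long[] sortedKeys = {1L, 2L, 5L, 9L};
        int serverId = 0;
        for (long key : sortedKeys) {
            while (key >= ranges[serverId + 1]) {
                serverId++;
            }
            System.out.println("key " + key + " -> server " + serverId);
        }
    }
}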
*/ + void push(long[] indices, double[] values) { + Iterator> requests = + partitioner.splitRequest(indices, values); + while (requests.hasNext()) { + Tuple3 request = requests.next(); + PushKvM pushKvM = new PushKvM(workerId, request.f0, Tuple2.of(request.f1, request.f2)); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, pushKvM.toBytes()))); + } + } + /** Pulls the values from servers with the specified indices. */ - public void pull(long[] indices) { + void pull(long[] indices) { Iterator> requests = partitioner.splitRequest(indices, null); while (requests.hasNext()) { @@ -79,4 +81,28 @@ public void pull(long[] indices) { output.collect(new StreamRecord<>(Tuple2.of(request.f0, pullIndexM.toBytes()))); } } + + /** + * Pushes the values to servers to apply all reduce operation. + * + *
<p>
Note that the values pushed by this function are not going to update the model, but just + * perform an all reduce operation. + */ + void allReducePush(double[] values) { + final int MIN_MESSAGE_SIZE = 1024; + int numServers = partitioner.numServers; + int messageSize = Math.max(MIN_MESSAGE_SIZE, values.length / numServers + 1); + for (int serverId = 0; serverId < numServers; serverId++) { + int s = Math.min(serverId * messageSize, values.length); + int e = Math.min(s + messageSize, values.length); + double[] segment; + if (s == e) { + segment = new double[0]; + } else { + segment = Arrays.copyOfRange(values, s, e); + } + AllReduceM allReduceM = new AllReduceM(serverId, workerId, segment); + output.collect(new StreamRecord<>(Tuple2.of(serverId, allReduceM.toBytes()))); + } + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index a9b9772da..31fd23441 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -25,6 +25,7 @@ import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.AllReduceM; import org.apache.flink.ml.common.ps.message.InitializeModelAsZeroM; import org.apache.flink.ml.common.ps.message.MessageType; import org.apache.flink.ml.common.ps.message.MessageUtils; @@ -69,6 +70,10 @@ * ModelUpdater#getModelSegments()}. * * + *
<p>
Moreover, it accepts all-reduce requests from workers and returns the reduced result to all + workers. Note that the input of the all-reduce operation is not used to update the model in {@link + ModelUpdater}. + + *
<p>
TODO: Add support for asynchronous operations on servers. * *
<p>
TODO: Add support for maintaining multiple parameters on servers. @@ -76,6 +81,8 @@ public class ServerOperator extends AbstractStreamOperator> implements OneInputStreamOperator, Tuple2>, IterationListener> { + /** Number of workers to communicate with. */ + private final int numWorkers; /** The logic to answer push/pull request from workers. */ private final ModelUpdater modelUpdater; /** Format of model data: start index, end index, dense double array. */ @@ -91,14 +98,20 @@ public class ServerOperator extends AbstractStreamOperator> futuresInEpoch = new ArrayList<>(); /** The merger for push requests. */ private final PushRequestMerger pushRequestMerger; + /** The merger for all reduce requests. */ + private final AllReduceMerger allReduceMerger; /** The pending pull requests. */ private ListState pendingPulls; public ServerOperator( - ModelUpdater modelUpdater, OutputTag> modelOutputTag) { + int numWorkers, + ModelUpdater modelUpdater, + OutputTag> modelOutputTag) { + this.numWorkers = numWorkers; this.modelUpdater = modelUpdater; this.modelOutputTag = modelOutputTag; this.pushRequestMerger = new PushRequestMerger(); + this.allReduceMerger = new AllReduceMerger(); } @Override @@ -117,12 +130,13 @@ public void processElement(StreamRecord> element) throws pendingPulls.add(request); break; case INITIALIZE_MODEL_AS_ZERO: - InitializeModelAsZeroM zerosToPush = InitializeModelAsZeroM.fromBytes(request); - Preconditions.checkState(serverId == zerosToPush.serverId); + InitializeModelAsZeroM initializeModelAsZeroM = + InitializeModelAsZeroM.fromBytes(request); + Preconditions.checkState(serverId == initializeModelAsZeroM.serverId); - long start = zerosToPush.startIndex; - long end = zerosToPush.endIndex; - if (zerosToPush.workerId == 0) { + long start = initializeModelAsZeroM.startIndex; + long end = initializeModelAsZeroM.endIndex; + if (initializeModelAsZeroM.workerId == 0) { modelUpdater.open(start, end); } break; @@ -131,6 +145,11 @@ public void processElement(StreamRecord> element) throws singleThreadExecutor.submit( () -> pushRequestMerger.processPushRequest(request))); break; + case ALL_REDUCE_VALUE: + futuresInEpoch.add( + singleThreadExecutor.submit( + () -> allReduceMerger.processAllReduceRequest(request))); + break; default: throw new UnsupportedOperationException("Unsupported message type: " + type + "."); } @@ -146,26 +165,45 @@ public void onEpochWatermarkIncremented( } futuresInEpoch.clear(); - if (epochWatermark > 0) { - // The first iteration contains no push kvs, but model initialization request. - Tuple2 kvs = pushRequestMerger.toKvArrays(); - pushRequestMerger.accumulatedKvsForMatrix.clear(); - pushRequestMerger.accumulatedKvsForVector.clear(); + // Processes the pushes first. + Tuple2 kvs = pushRequestMerger.toKvArrays(); + pushRequestMerger.accumulatedKvsForMatrix.clear(); + pushRequestMerger.accumulatedKvsForVector.clear(); + if (kvs.f0.length > 0) { + // There are pushes at this epoch. modelUpdater.handlePush(kvs.f0, kvs.f1); } Iterator pullsIterator = pendingPulls.get().iterator(); if (pullsIterator.hasNext()) { - // The last iteration contains no pulls. + // This is a pull stage. while (pullsIterator.hasNext()) { byte[] pull = pullsIterator.next(); futuresInEpoch.add(singleThreadExecutor.submit(() -> processPullRequest(pull))); } } + if (allReduceMerger.reducedResult != null) { + // This is an all reduce stage. 
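+ // The reduced array was accumulated element-wise by AllReduceMerger as each
+ // worker's AllReduceM arrived; the loop below re-uses one PulledValueM to send
+ // the same result to every worker. Note that the shared pulledValueM is mutated
+ // per iteration, so a submitted task must serialize it before the next mutation.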
+ PulledValueM pulledValueM = + new PulledValueM(serverId, -1, allReduceMerger.reducedResult); + for (int workerId = 0; workerId < numWorkers; workerId++) { + int finalWorkerId = workerId; + pulledValueM.workerId = finalWorkerId; + futuresInEpoch.add( + singleThreadExecutor.submit( + () -> + output.collect( + new StreamRecord<>( + Tuple2.of( + finalWorkerId, + pulledValueM.toBytes()))))); + } + } for (Future future : futuresInEpoch) { future.get(); } pendingPulls.clear(); + allReduceMerger.reducedResult = null; futuresInEpoch.clear(); } @@ -190,6 +228,7 @@ public void initializeState(StateInitializationContext context) throws Exception PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); modelUpdater.initializeState(context); pushRequestMerger.initializeState(context); + allReduceMerger.initializeState(context); } @Override @@ -203,6 +242,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception { futuresInEpoch.clear(); modelUpdater.snapshotState(context); pushRequestMerger.snapshotState(context); + allReduceMerger.snapshotState(context); } private Object processPullRequest(byte[] bytesData) { @@ -331,4 +371,42 @@ private void snapshotState(StateSnapshotContext context) throws Exception { } } } + + private static class AllReduceMerger implements Serializable { + private double[] reducedResult; + private ListState reducedResultState; + + private void processAllReduceRequest(byte[] request) { + AllReduceM allReduceM = AllReduceM.fromBytes(request); + double[] receivedValues = allReduceM.values; + if (reducedResult == null) { + reducedResult = receivedValues; + } else { + Preconditions.checkArgument(reducedResult.length == receivedValues.length); + for (int i = 0; i < reducedResult.length; i++) { + reducedResult[i] += receivedValues[i]; + } + } + } + + private void initializeState(StateInitializationContext context) throws Exception { + reducedResultState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor( + "reducedResultState", + PrimitiveArrayTypeInfo + .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); + reducedResult = + OperatorStateUtils.getUniqueElement(reducedResultState, "reducedResultState") + .orElse(null); + } + + private void snapshotState(StateSnapshotContext context) throws Exception { + reducedResultState.clear(); + if (reducedResult != null) { + reducedResultState.add(reducedResult); + } + } + } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java index dc69974a0..7fc0011bb 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -27,6 +27,7 @@ import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.ps.message.PulledValueM; +import org.apache.flink.ml.common.ps.training.AllReduceStage; import org.apache.flink.ml.common.ps.training.IterationStage; import org.apache.flink.ml.common.ps.training.IterationStageList; import org.apache.flink.ml.common.ps.training.MLSession; @@ -43,6 +44,7 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; +import java.util.Arrays; import java.util.Iterator; /** @@ -141,12 +143,28 @@ public void processElement1(StreamRecord
<DT>
streamRecord) throws Exception { public void processElement2(StreamRecord streamRecord) throws Exception { feedback = streamRecord.getValue(); if (modelDim > 0) { - // Decodes the pulled method and put it in ml session. - PullStage pullStage = (PullStage) iterationStages.stageList.get(nextStageToExecute); - PulledValueM valuesPulledMessage = PulledValueM.fromBytes(streamRecord.getValue()); - Preconditions.checkState( - getRuntimeContext().getIndexOfThisSubtask() == valuesPulledMessage.workerId); - pullStage.valuesConsumer.accept(valuesPulledMessage.values); + // Decodes the pulled values and puts it in ml session. + IterationStage stage = iterationStages.stageList.get(nextStageToExecute); + if (stage instanceof PullStage) { + PullStage pullStage = (PullStage) stage; + PulledValueM valuesPulledMessage = PulledValueM.fromBytes(streamRecord.getValue()); + Preconditions.checkState( + getRuntimeContext().getIndexOfThisSubtask() + == valuesPulledMessage.workerId); + pullStage.valuesConsumer.accept(valuesPulledMessage.values); + } else if (stage instanceof AllReduceStage) { + AllReduceStage allReduceStage = (AllReduceStage) stage; + PulledValueM pulledValueM = PulledValueM.fromBytes(streamRecord.getValue()); + Preconditions.checkState( + getRuntimeContext().getIndexOfThisSubtask() == pulledValueM.workerId); + System.out.println( + "Worker received allreduce result: " + + Arrays.toString(pulledValueM.values)); + allReduceStage.valuesConsumer.accept(pulledValueM.values); + } else { + throw new IllegalStateException( + String.format("Illegal stage type: %s", stage.getClass().getSimpleName())); + } nextStageToExecute++; nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); @@ -247,10 +265,17 @@ private int processTrainingStage( PullStage pullStage = ((PullStage) stage); serverAgent.pull(pullStage.keysSupplier.get()); return nextStageToExecute; + } else if (stage instanceof AllReduceStage) { + // We are not incrementing nextStageToExecute here, since we will need to pull + // values from servers. + AllReduceStage allReduceStage = (AllReduceStage) stage; + serverAgent.allReducePush(allReduceStage.valuesSupplier.get()); + return nextStageToExecute; } else if (stage instanceof PushStage) { PushStage pushStage = (PushStage) stage; - serverAgent.pushKVs(pushStage.keysSupplier.get(), pushStage.valuesSupplier.get()); + serverAgent.push(pushStage.keysSupplier.get(), pushStage.valuesSupplier.get()); nextStageToExecute++; + } else if (stage instanceof ProcessStage) { ((ProcessStage) stage).process(iterationStages.session); nextStageToExecute++; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java new file mode 100644 index 000000000..d53af554e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.message; + +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +import static org.apache.flink.ml.common.ps.message.MessageType.ALL_REDUCE_VALUE; + +/** The message to apply all-reduce among workers. */ +public class AllReduceM implements Message { + public final int serverId; + public final int workerId; + public final double[] values; + + public AllReduceM(int serverId, int workerId, double[] values) { + this.serverId = serverId; + this.workerId = workerId; + this.values = values; + } + + public static AllReduceM fromBytes(byte[] bytes) { + int offset = 0; + char type = Bits.getChar(bytes, offset); + offset += Character.BYTES; + Preconditions.checkState(type == ALL_REDUCE_VALUE.type); + + int psId = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + int workerId = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + double[] values = MessageUtils.getDoubleArray(bytes, offset); + return new AllReduceM(psId, workerId, values); + } + + @Override + public byte[] toBytes() { + int numBytes = + Character.BYTES + + Integer.BYTES + + Integer.BYTES + + MessageUtils.getDoubleArraySizeInBytes(values); + byte[] buffer = new byte[numBytes]; + int offset = 0; + Bits.putChar(buffer, offset, ALL_REDUCE_VALUE.type); + offset += Character.BYTES; + + Bits.putInt(buffer, offset, this.serverId); + offset += Integer.BYTES; + Bits.putInt(buffer, offset, this.workerId); + offset += Integer.BYTES; + MessageUtils.putDoubleArray(values, buffer, offset); + + return buffer; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java index 79da886d6..9df4e599d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java @@ -33,7 +33,9 @@ public enum MessageType { * Message sent from workers to servers, which specifies the indices and values of the model to * push to servers. */ - PUSH_KV((char) 3); + PUSH_KV((char) 3), + /** Message to apply all-reduce among workers. */ + ALL_REDUCE_VALUE((char) 4); public final char type; @@ -51,6 +53,8 @@ public static MessageType valueOf(char value) { return MessageType.PULLED_VALUE; case ((char) 3): return MessageType.PUSH_KV; + case ((char) 4): + return MessageType.ALL_REDUCE_VALUE; default: throw new UnsupportedOperationException(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java index 96ab8b072..5457f1a57 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java @@ -26,7 +26,7 @@ /** The values pulled from servers. 
*/ public class PulledValueM implements Message { public final int serverId; - public final int workerId; + public int workerId; public final double[] values; public PulledValueM(int serverId, int workerId, double[] values) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java new file mode 100644 index 000000000..c9153862e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.util.Preconditions; + +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** A communication stage that conducts all-reduce on the given double array. */ +public final class AllReduceStage implements IterationStage { + public final Supplier valuesSupplier; + public final Consumer valuesConsumer; + public final BiFunction valuesAggregator; + + public AllReduceStage( + Supplier valuesSupplier, + Consumer valuesConsumer, + BiFunction valuesAggregator) { + this.valuesSupplier = valuesSupplier; + this.valuesConsumer = valuesConsumer; + this.valuesAggregator = valuesAggregator; + } + + public AllReduceStage(Supplier valuesSupplier, Consumer valuesConsumer) { + this( + valuesSupplier, + valuesConsumer, + (SerializableBiFunction) + (array1, array2) -> { + Preconditions.checkState(array1.length == array2.length); + for (int i = 0; i < array1.length; i++) { + array2[i] += array1[i]; + } + return array2; + }); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java index 585b23296..fec86d87e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PullStage.java @@ -18,9 +18,8 @@ package org.apache.flink.ml.common.ps.training; -import org.apache.flink.util.function.SerializableSupplier; - import java.util.function.Consumer; +import java.util.function.Supplier; /** * A communication stage that pulls data from servers using keys as {@code @@ -28,10 +27,10 @@ * PullStage#valuesConsumer#accept()}. 
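A typical use of AllReduceStage is to aggregate a small per-worker statistic, such as a summed local loss consulted by the termination criterion, e.g. new AllReduceStage(() -> session.localLoss, values -> session.localLoss = values), where localLoss is an assumed double[] field on the session. The two-argument constructor above falls back to an element-wise sum, which behaves like this self-contained sketch:

import java.util.function.BiFunction;

public class AllReduceSumSketch {
    public static void main(String[] args) {
        // Same shape as the default aggregator: accumulate a into b and return b.
        BiFunction<double[], double[], double[]> sum =
                (a, b) -> {
                    for (int i = 0; i < a.length; i++) {
                        b[i] += a[i];
                    }
                    return b;
                };
        double[] worker0 = {0.25, 1.0};
        double[] worker1 = {0.75, 3.0};
        double[] reduced = sum.apply(worker0, worker1);
        System.out.println(reduced[0] + ", " + reduced[1]); // 1.0, 4.0
    }
}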
*/ public final class PullStage implements IterationStage { - public final SerializableSupplier keysSupplier; + public final Supplier keysSupplier; public final Consumer valuesConsumer; - public PullStage(SerializableSupplier keysSupplier, Consumer valuesConsumer) { + public PullStage(Supplier keysSupplier, Consumer valuesConsumer) { this.keysSupplier = keysSupplier; this.valuesConsumer = valuesConsumer; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java new file mode 100644 index 000000000..e191a38a6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java @@ -0,0 +1,6 @@ +package org.apache.flink.ml.common.ps.training; + +import java.io.Serializable; +import java.util.function.BiFunction; + +public interface SerializableBiFunction extends BiFunction, Serializable {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java index baf810ea4..3dff96868 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -133,7 +133,8 @@ public IterationBodyResult process( new TupleTypeInfo<>( Types.INT, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO), - new ServerOperator(modelUpdater, modelDataOutputTag)); + new ServerOperator( + numWorkers, modelUpdater, modelDataOutputTag)); messageToWorker.setParallelism(numServers); DataStream combinedMessageToWorker = From 5c1d7aa8626a8b5708129b996c9eaab0bd02f182 Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Tue, 6 Jun 2023 10:25:02 +0800 Subject: [PATCH 09/18] Reorganize Vectors and add SparseLongDoubleVector --- .../docs/operators/classification/knn.md | 2 +- .../operators/classification/linearsvc.md | 2 +- .../classification/logisticregression.md | 6 +- .../operators/classification/naivebayes.md | 2 +- .../clustering/agglomerativeclustering.md | 2 +- .../docs/operators/clustering/kmeans.md | 6 +- .../docs/operators/feature/countvectorizer.md | 2 +- docs/content/docs/operators/feature/dct.md | 2 +- .../operators/feature/elementwiseproduct.md | 2 +- .../docs/operators/feature/featurehasher.md | 2 +- .../docs/operators/feature/hashingtf.md | 2 +- docs/content/docs/operators/feature/idf.md | 2 +- .../docs/operators/feature/interaction.md | 2 +- .../operators/feature/kbinsdiscretizer.md | 2 +- .../docs/operators/feature/maxabsscaler.md | 2 +- .../docs/operators/feature/minhashlsh.md | 6 +- .../docs/operators/feature/minmaxscaler.md | 2 +- .../docs/operators/feature/normalizer.md | 2 +- .../docs/operators/feature/onehotencoder.md | 2 +- .../operators/feature/onlinestandardscaler.md | 10 +- .../operators/feature/polynomialexpansion.md | 2 +- .../docs/operators/feature/robustscaler.md | 2 +- .../docs/operators/feature/standardscaler.md | 2 +- .../feature/univariatefeatureselector.md | 2 +- .../feature/variancethresholdselector.md | 2 +- .../docs/operators/feature/vectorassembler.md | 2 +- .../docs/operators/feature/vectorslicer.md | 2 +- docs/content/docs/operators/functions.md | 4 +- .../operators/regression/linearregression.md | 2 +- .../java/build-your-own-project.md | 2 +- .../clustering/KMeansModelDataGenerator.java | 10 +- .../common/DenseVectorArrayGenerator.java | 10 +- 
.../common/DenseVectorGenerator.java | 5 +- .../LabeledPointWithWeightGenerator.java | 4 +- .../flink/ml/benchmark/DataGeneratorTest.java | 16 +- .../org/apache/flink/ml/api/StageTest.java | 4 +- .../ml/common/datastream/TableUtilsTest.java | 16 +- .../org/apache/flink/ml/util/TestUtils.java | 18 +- .../ml/examples/ArrayToVectorExample.java | 4 +- .../ml/examples/VectorToArrayExample.java | 3 +- .../examples/classification/KnnExample.java | 5 +- .../classification/LinearSVCExample.java | 9 +- .../LogisticRegressionExample.java | 8 +- .../classification/NaiveBayesExample.java | 5 +- .../OnlineLogisticRegressionExample.java | 12 +- .../AgglomerativeClusteringExample.java | 8 +- .../ml/examples/clustering/KMeansExample.java | 7 +- .../clustering/OnlineKMeansExample.java | 15 +- .../feature/CountVectorizerExample.java | 5 +- .../flink/ml/examples/feature/DCTExample.java | 8 +- .../feature/ElementwiseProductExample.java | 8 +- .../feature/FeatureHasherExample.java | 5 +- .../ml/examples/feature/HashingTFExample.java | 5 +- .../flink/ml/examples/feature/IDFExample.java | 8 +- .../examples/feature/InteractionExample.java | 5 +- .../feature/KBinsDiscretizerExample.java | 8 +- .../examples/feature/MaxAbsScalerExample.java | 8 +- .../examples/feature/MinHashLSHExample.java | 17 +- .../examples/feature/MinMaxScalerExample.java | 8 +- .../examples/feature/NormalizerExample.java | 6 +- .../feature/OneHotEncoderExample.java | 6 +- .../feature/OnlineStandardScalerExample.java | 16 +- .../feature/PolynomialExpansionExample.java | 8 +- .../examples/feature/RobustScalerExample.java | 8 +- .../feature/StandardScalerExample.java | 8 +- .../UnivariateFeatureSelectorExample.java | 10 +- .../VarianceThresholdSelectorExample.java | 10 +- .../feature/VectorAssemblerExample.java | 5 +- .../examples/feature/VectorSlicerExample.java | 7 +- .../regression/LinearRegressionExample.java | 5 +- .../java/org/apache/flink/ml/Functions.java | 26 +-- .../flink/ml/classification/knn/Knn.java | 38 ++-- .../flink/ml/classification/knn/KnnModel.java | 12 +- .../ml/classification/knn/KnnModelData.java | 28 +-- .../classification/linearsvc/LinearSVC.java | 14 +- .../linearsvc/LinearSVCModel.java | 15 +- .../linearsvc/LinearSVCModelData.java | 20 ++- .../LogisticRegression.java | 15 +- .../LogisticRegressionModel.java | 11 +- .../LogisticRegressionModelDataUtil.java | 4 +- .../OnlineLogisticRegression.java | 67 +++---- .../OnlineLogisticRegressionModel.java | 13 +- .../classification/naivebayes/NaiveBayes.java | 19 +- .../naivebayes/NaiveBayesModel.java | 13 +- .../naivebayes/NaiveBayesModelData.java | 36 ++-- .../flink/ml/clustering/kmeans/KMeans.java | 83 +++++---- .../ml/clustering/kmeans/KMeansModel.java | 7 +- .../ml/clustering/kmeans/KMeansModelData.java | 40 +++-- .../ml/clustering/kmeans/OnlineKMeans.java | 52 +++--- .../clustering/kmeans/OnlineKMeansModel.java | 8 +- .../common/lossfunc/BinaryLogisticLoss.java | 8 +- .../flink/ml/common/lossfunc/HingeLoss.java | 8 +- .../ml/common/lossfunc/LeastSquareLoss.java | 8 +- .../flink/ml/common/lossfunc/LossFunc.java | 8 +- .../flink/ml/common/optimizer/Optimizer.java | 6 +- .../common/optimizer/RegularizationUtils.java | 4 +- .../apache/flink/ml/common/optimizer/SGD.java | 31 ++-- .../flink/ml/common/util/VectorUtils.java | 20 +-- .../BinaryClassificationEvaluator.java | 6 +- .../flink/ml/feature/binarizer/Binarizer.java | 33 ++-- .../countvectorizer/CountVectorizerModel.java | 9 +- .../org/apache/flink/ml/feature/dct/DCT.java | 13 +- .../ElementwiseProduct.java | 10 +- 
.../ElementwiseProductParams.java | 8 +- .../feature/featurehasher/FeatureHasher.java | 4 +- .../flink/ml/feature/hashingtf/HashingTF.java | 5 +- .../org/apache/flink/ml/feature/idf/IDF.java | 36 ++-- .../apache/flink/ml/feature/idf/IDFModel.java | 9 +- .../flink/ml/feature/idf/IDFModelData.java | 17 +- .../ml/feature/interaction/Interaction.java | 26 +-- .../kbinsdiscretizer/KBinsDiscretizer.java | 32 ++-- .../KBinsDiscretizerModel.java | 12 +- .../org/apache/flink/ml/feature/lsh/LSH.java | 4 +- .../apache/flink/ml/feature/lsh/LSHModel.java | 32 ++-- .../flink/ml/feature/lsh/LSHModelData.java | 8 +- .../ml/feature/lsh/MinHashLSHModelData.java | 12 +- .../ml/feature/maxabsscaler/MaxAbsScaler.java | 31 ++-- .../maxabsscaler/MaxAbsScalerModel.java | 10 +- .../maxabsscaler/MaxAbsScalerModelData.java | 18 +- .../ml/feature/minmaxscaler/MinMaxScaler.java | 53 +++--- .../minmaxscaler/MinMaxScalerModel.java | 22 +-- .../minmaxscaler/MinMaxScalerModelData.java | 23 +-- .../ml/feature/normalizer/Normalizer.java | 6 +- .../onehotencoder/OneHotEncoderModel.java | 5 +- .../PolynomialExpansion.java | 23 +-- .../ml/feature/robustscaler/RobustScaler.java | 22 ++- .../robustscaler/RobustScalerModel.java | 13 +- .../robustscaler/RobustScalerModelData.java | 26 +-- .../standardscaler/OnlineStandardScaler.java | 32 ++-- .../OnlineStandardScalerModel.java | 17 +- .../standardscaler/StandardScaler.java | 68 +++---- .../standardscaler/StandardScalerModel.java | 14 +- .../StandardScalerModelData.java | 22 +-- .../UnivariateFeatureSelectorModel.java | 6 +- .../VarianceThresholdSelector.java | 46 ++--- .../VarianceThresholdSelectorModel.java | 6 +- .../vectorassembler/VectorAssembler.java | 40 ++--- .../feature/vectorindexer/VectorIndexer.java | 6 +- .../vectorindexer/VectorIndexerModel.java | 6 +- .../ml/feature/vectorslicer/VectorSlicer.java | 20 +-- .../linearregression/LinearRegression.java | 14 +- .../LinearRegressionModel.java | 12 +- .../LinearRegressionModelData.java | 18 +- .../flink/ml/stats/anovatest/ANOVATest.java | 32 ++-- .../flink/ml/stats/chisqtest/ChiSqTest.java | 16 +- .../flink/ml/stats/fvaluetest/FValueTest.java | 104 ++++++----- .../org/apache/flink/ml/FunctionsTest.java | 10 +- .../flink/ml/classification/KnnTest.java | 18 +- .../ml/classification/LinearSVCTest.java | 24 ++- .../LogisticRegressionTest.java | 40 +++-- .../LogisticRegressionWithFtrlTest.java | 20 ++- .../ml/classification/NaiveBayesTest.java | 29 +-- .../OnlineLogisticRegressionTest.java | 130 ++++++++------ .../AgglomerativeClusteringTest.java | 85 +++++---- .../flink/ml/clustering/KMeansTest.java | 35 ++-- .../flink/ml/clustering/OnlineKMeansTest.java | 56 +++--- .../lossfunc/BinaryLogisticLossTest.java | 6 +- .../ml/common/lossfunc/HingeLossTest.java | 6 +- .../common/lossfunc/LeastSquareLossTest.java | 6 +- .../optimizer/RegularizationUtilsTest.java | 7 +- .../flink/ml/common/util/VectorUtilsTest.java | 8 +- .../BinaryClassificationEvaluatorTest.java | 4 +- .../flink/ml/feature/BinarizerTest.java | 10 +- .../flink/ml/feature/CountVectorizerTest.java | 22 +-- .../org/apache/flink/ml/feature/DCTTest.java | 17 +- .../ml/feature/ElementwiseProductTest.java | 18 +- .../flink/ml/feature/FeatureHasherTest.java | 6 +- .../org/apache/flink/ml/feature/IDFTest.java | 13 +- .../flink/ml/feature/InteractionTest.java | 24 +-- .../ml/feature/KBinsDiscretizerTest.java | 5 +- .../flink/ml/feature/MaxAbsScalerTest.java | 14 +- .../flink/ml/feature/MinHashLSHTest.java | 34 ++-- .../flink/ml/feature/MinMaxScalerTest.java | 28 +-- 
.../flink/ml/feature/NormalizerTest.java | 12 +- .../flink/ml/feature/OneHotEncoderTest.java | 28 +-- .../ml/feature/OnlineStandardScalerTest.java | 31 ++-- .../ml/feature/PolynomialExpansionTest.java | 14 +- .../flink/ml/feature/RobustScalerTest.java | 37 ++-- .../flink/ml/feature/StandardScalerTest.java | 24 +-- .../UnivariateFeatureSelectorTest.java | 10 +- .../VarianceThresholdSelectorTest.java | 7 +- .../flink/ml/feature/VectorAssemblerTest.java | 22 ++- .../flink/ml/feature/VectorSlicerTest.java | 12 +- .../ml/regression/LinearRegressionTest.java | 14 +- .../apache/flink/ml/stats/ANOVATestTest.java | 10 +- .../apache/flink/ml/stats/FValueTestTest.java | 10 +- .../feature/LabeledPointWithWeight.java | 10 +- .../java/org/apache/flink/ml/linalg/BLAS.java | 102 ++++++----- ...eVector.java => DenseIntDoubleVector.java} | 32 ++-- .../flink/ml/linalg/IntDoubleVector.java | 41 +++++ .../flink/ml/linalg/LongDoubleVector.java | 40 +++++ ...Vector.java => SparseIntDoubleVector.java} | 32 ++-- .../ml/linalg/SparseLongDoubleVector.java | 167 ++++++++++++++++++ .../org/apache/flink/ml/linalg/Vector.java | 20 +-- .../flink/ml/linalg/VectorWithNorm.java | 6 +- .../org/apache/flink/ml/linalg/Vectors.java | 8 +- ...va => DenseIntDoubleVectorSerializer.java} | 41 ++--- ...java => DenseIntDoubleVectorTypeInfo.java} | 22 +-- ... DenseIntDoubleVectorTypeInfoFactory.java} | 10 +- ...a => SparseIntDoubleVectorSerializer.java} | 35 ++-- ...ava => SparseIntDoubleVectorTypeInfo.java} | 21 +-- ...SparseIntDoubleVectorTypeInfoFactory.java} | 10 +- .../SparseLongDoubleVectorSerializer.java | 155 ++++++++++++++++ .../SparseLongDoubleVectorTypeInfo.java | 87 +++++++++ ...SparseLongDoubleVectorTypeInfoFactory.java | 40 +++++ .../ml/linalg/typeinfo/VectorSerializer.java | 63 ++++--- .../ml/linalg/typeinfo/VectorTypeInfo.java | 3 +- .../typeinfo/VectorWithNormSerializer.java | 15 +- .../apache/flink/ml/param/VectorParam.java | 22 +-- .../org/apache/flink/ml/linalg/BLASTest.java | 44 +++-- .../flink/ml/linalg/DenseVectorTest.java | 10 +- .../flink/ml/linalg/SparseVectorTest.java | 24 +-- .../flink/ml/linalg/VectorWithNormTest.java | 5 +- .../LogisticRegressionModelData.java | 18 +- .../LogisticRegressionModelServable.java | 12 +- 215 files changed, 2448 insertions(+), 1589 deletions(-) rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/{DenseVector.java => DenseIntDoubleVector.java} (71%) create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/{SparseVector.java => SparseIntDoubleVector.java} (86%) create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/{DenseVectorSerializer.java => DenseIntDoubleVectorSerializer.java} (74%) rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/{DenseVectorTypeInfo.java => DenseIntDoubleVectorTypeInfo.java} (71%) rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/{DenseVectorTypeInfoFactory.java => DenseIntDoubleVectorTypeInfoFactory.java} (81%) rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/{SparseVectorSerializer.java => SparseIntDoubleVectorSerializer.java} (76%) rename 
flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/{SparseVectorTypeInfo.java => SparseIntDoubleVectorTypeInfo.java} (70%) rename flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/{SparseVectorTypeInfoFactory.java => SparseIntDoubleVectorTypeInfoFactory.java} (80%) create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java create mode 100644 flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java diff --git a/docs/content/docs/operators/classification/knn.md b/docs/content/docs/operators/classification/knn.md index 0724f2daf..f62af4de7 100644 --- a/docs/content/docs/operators/classification/knn.md +++ b/docs/content/docs/operators/classification/knn.md @@ -67,7 +67,7 @@ Below are the parameters required by `KnnModel`. ```java import org.apache.flink.ml.classification.knn.Knn; import org.apache.flink.ml.classification.knn.KnnModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/classification/linearsvc.md b/docs/content/docs/operators/classification/linearsvc.md index 5b134a995..02e5f912a 100644 --- a/docs/content/docs/operators/classification/linearsvc.md +++ b/docs/content/docs/operators/classification/linearsvc.md @@ -77,7 +77,7 @@ Below are the parameters required by `LinearSVCModel`. ```java import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/classification/logisticregression.md b/docs/content/docs/operators/classification/logisticregression.md index edd9f8d33..c68f9ec2b 100644 --- a/docs/content/docs/operators/classification/logisticregression.md +++ b/docs/content/docs/operators/classification/logisticregression.md @@ -74,7 +74,7 @@ Below are the parameters required by `LogisticRegressionModel`. 
```java import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -251,9 +251,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; import org.apache.flink.ml.examples.util.PeriodicSourceFunction; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; diff --git a/docs/content/docs/operators/classification/naivebayes.md b/docs/content/docs/operators/classification/naivebayes.md index 3fe9beb8e..b6900bafc 100644 --- a/docs/content/docs/operators/classification/naivebayes.md +++ b/docs/content/docs/operators/classification/naivebayes.md @@ -66,7 +66,7 @@ Below are parameters required by `NaiveBayesModel`. ```java import org.apache.flink.ml.classification.naivebayes.NaiveBayes; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/clustering/agglomerativeclustering.md b/docs/content/docs/operators/clustering/agglomerativeclustering.md index 9ded65cca..254f92bfb 100644 --- a/docs/content/docs/operators/clustering/agglomerativeclustering.md +++ b/docs/content/docs/operators/clustering/agglomerativeclustering.md @@ -69,7 +69,7 @@ format of the merging information is import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering; import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/clustering/kmeans.md b/docs/content/docs/operators/clustering/kmeans.md index eeeff603a..cb6dae84f 100644 --- a/docs/content/docs/operators/clustering/kmeans.md +++ b/docs/content/docs/operators/clustering/kmeans.md @@ -67,7 +67,7 @@ Below are the parameters required by `KMeansModel`. 
```java import org.apache.flink.ml.clustering.kmeans.KMeans; import org.apache.flink.ml.clustering.kmeans.KMeansModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -228,9 +228,9 @@ import org.apache.flink.ml.clustering.kmeans.KMeansModelData; import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; import org.apache.flink.ml.examples.util.PeriodicSourceFunction; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.source.SourceFunction; diff --git a/docs/content/docs/operators/feature/countvectorizer.md b/docs/content/docs/operators/feature/countvectorizer.md index b6658c06c..68905ac02 100644 --- a/docs/content/docs/operators/feature/countvectorizer.md +++ b/docs/content/docs/operators/feature/countvectorizer.md @@ -72,7 +72,7 @@ Below are the parameters required by `CountVectorizerModel`. ```java import org.apache.flink.ml.feature.countvectorizer.CountVectorizer; import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/dct.md b/docs/content/docs/operators/feature/dct.md index 356260be5..fde4b1e9b 100644 --- a/docs/content/docs/operators/feature/dct.md +++ b/docs/content/docs/operators/feature/dct.md @@ -60,7 +60,7 @@ that the transform matrix is unitary (aka scaled DCT-II). ```java import org.apache.flink.ml.feature.dct.DCT; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/elementwiseproduct.md b/docs/content/docs/operators/feature/elementwiseproduct.md index 0021c15b4..4461262de 100644 --- a/docs/content/docs/operators/feature/elementwiseproduct.md +++ b/docs/content/docs/operators/feature/elementwiseproduct.md @@ -58,7 +58,7 @@ scaling vector, the transformer will throw an IllegalArgumentException. 
```java import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/featurehasher.md b/docs/content/docs/operators/feature/featurehasher.md index e804d9ae6..af93911fb 100644 --- a/docs/content/docs/operators/feature/featurehasher.md +++ b/docs/content/docs/operators/feature/featurehasher.md @@ -69,7 +69,7 @@ For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for det ```java import org.apache.flink.ml.feature.featurehasher.FeatureHasher; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/hashingtf.md b/docs/content/docs/operators/feature/hashingtf.md index d340d9096..45a27af59 100644 --- a/docs/content/docs/operators/feature/hashingtf.md +++ b/docs/content/docs/operators/feature/hashingtf.md @@ -63,7 +63,7 @@ the output values are accumulated by default. ```java import org.apache.flink.ml.feature.hashingtf.HashingTF; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/idf.md b/docs/content/docs/operators/feature/idf.md index ca26c286a..2b478184f 100644 --- a/docs/content/docs/operators/feature/idf.md +++ b/docs/content/docs/operators/feature/idf.md @@ -73,7 +73,7 @@ Below are the parameters required by `IDFModel`. ```java import org.apache.flink.ml.feature.idf.IDF; import org.apache.flink.ml.feature.idf.IDFModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/interaction.md b/docs/content/docs/operators/feature/interaction.md index 1b9bca1bc..a313b1246 100644 --- a/docs/content/docs/operators/feature/interaction.md +++ b/docs/content/docs/operators/feature/interaction.md @@ -62,7 +62,7 @@ be Vector(3, 6, 4, 8). ```java import org.apache.flink.ml.feature.interaction.Interaction; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/kbinsdiscretizer.md b/docs/content/docs/operators/feature/kbinsdiscretizer.md index b26e7804c..047851742 100644 --- a/docs/content/docs/operators/feature/kbinsdiscretizer.md +++ b/docs/content/docs/operators/feature/kbinsdiscretizer.md @@ -70,7 +70,7 @@ Below are the parameters required by `KBinsDiscretizerModel`. 
import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel; import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/maxabsscaler.md b/docs/content/docs/operators/feature/maxabsscaler.md index 1c3d4fd91..c5ba4c81b 100644 --- a/docs/content/docs/operators/feature/maxabsscaler.md +++ b/docs/content/docs/operators/feature/maxabsscaler.md @@ -59,7 +59,7 @@ It does not shift/center the data and thus does not destroy any sparsity. ```java import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler; import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/minhashlsh.md b/docs/content/docs/operators/feature/minhashlsh.md index 5c3aee92a..0c7494cd7 100644 --- a/docs/content/docs/operators/feature/minhashlsh.md +++ b/docs/content/docs/operators/feature/minhashlsh.md @@ -75,9 +75,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.ml.feature.lsh.MinHashLSH; import org.apache.flink.ml.feature.lsh.MinHashLSHModel; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/minmaxscaler.md b/docs/content/docs/operators/feature/minmaxscaler.md index b4427b629..ed0e953eb 100644 --- a/docs/content/docs/operators/feature/minmaxscaler.md +++ b/docs/content/docs/operators/feature/minmaxscaler.md @@ -59,7 +59,7 @@ MinMaxScaler is an algorithm that rescales feature values to a common range ```java import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/normalizer.md b/docs/content/docs/operators/feature/normalizer.md index 65e3760a4..7562d33fa 100644 --- a/docs/content/docs/operators/feature/normalizer.md +++ b/docs/content/docs/operators/feature/normalizer.md @@ -57,7 +57,7 @@ A Transformer that normalizes a vector to have unit norm using the given p-norm. 
```java import org.apache.flink.ml.feature.normalizer.Normalizer; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/onehotencoder.md b/docs/content/docs/operators/feature/onehotencoder.md index 32ff2ffa0..3b706445b 100644 --- a/docs/content/docs/operators/feature/onehotencoder.md +++ b/docs/content/docs/operators/feature/onehotencoder.md @@ -64,7 +64,7 @@ vector column for each input column. ```java import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; diff --git a/docs/content/docs/operators/feature/onlinestandardscaler.md b/docs/content/docs/operators/feature/onlinestandardscaler.md index ef998f1f8..d158d16e5 100644 --- a/docs/content/docs/operators/feature/onlinestandardscaler.md +++ b/docs/content/docs/operators/feature/onlinestandardscaler.md @@ -91,9 +91,9 @@ import org.apache.flink.api.common.time.Time; import org.apache.flink.ml.common.window.EventTimeTumblingWindows; import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler; import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.DataTypes; @@ -202,13 +202,13 @@ t_env = StreamTableEnvironment.create(env) # Generates input data. dense_vector_serializer = get_gateway().jvm.org.apache.flink.table.types.logical.RawType( - get_gateway().jvm.org.apache.flink.ml.linalg.DenseVector(0).getClass(), - get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer() + get_gateway().jvm.org.apache.flink.ml.linalg.DenseIntDoubleVector(0).getClass(), + get_gateway().jvm.org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer() ).getSerializerString() schema = Schema.new_builder() .column("ts", "TIMESTAMP_LTZ(3)") - .column("input", "RAW('org.apache.flink.ml.linalg.DenseVector', '{serializer}')" + .column("input", "RAW('org.apache.flink.ml.linalg.DenseIntDoubleVector', '{serializer}')" .format(serializer=dense_vector_serializer)) .watermark("ts", "ts - INTERVAL '1' SECOND") .build() diff --git a/docs/content/docs/operators/feature/polynomialexpansion.md b/docs/content/docs/operators/feature/polynomialexpansion.md index a7f5ee899..ab5cc89a9 100644 --- a/docs/content/docs/operators/feature/polynomialexpansion.md +++ b/docs/content/docs/operators/feature/polynomialexpansion.md @@ -63,7 +63,7 @@ http://en.wikipedia.org/wiki/Polynomial_expansion. 
```java import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/robustscaler.md b/docs/content/docs/operators/feature/robustscaler.md index a37085dd6..69c50740f 100644 --- a/docs/content/docs/operators/feature/robustscaler.md +++ b/docs/content/docs/operators/feature/robustscaler.md @@ -87,7 +87,7 @@ Below are the parameters required by `RobustScalerModel`. ```java import org.apache.flink.ml.feature.robustscaler.RobustScaler; import org.apache.flink.ml.feature.robustscaler.RobustScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/standardscaler.md b/docs/content/docs/operators/feature/standardscaler.md index 55879d35c..c313a400d 100644 --- a/docs/content/docs/operators/feature/standardscaler.md +++ b/docs/content/docs/operators/feature/standardscaler.md @@ -59,7 +59,7 @@ the mean and scaling each dimension to unit variance. ```java import org.apache.flink.ml.feature.standardscaler.StandardScaler; import org.apache.flink.ml.feature.standardscaler.StandardScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/univariatefeatureselector.md b/docs/content/docs/operators/feature/univariatefeatureselector.md index 873919e21..577fd58a8 100644 --- a/docs/content/docs/operators/feature/univariatefeatureselector.md +++ b/docs/content/docs/operators/feature/univariatefeatureselector.md @@ -105,7 +105,7 @@ Below are the parameters required by `UnivariateFeatureSelectorModel`. ```java import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector; import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/variancethresholdselector.md b/docs/content/docs/operators/feature/variancethresholdselector.md index 5cd9a2f14..474276a1f 100644 --- a/docs/content/docs/operators/feature/variancethresholdselector.md +++ b/docs/content/docs/operators/feature/variancethresholdselector.md @@ -69,7 +69,7 @@ Below are the parameters required by `VarianceThresholdSelectorModel`. 
```java import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector; import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/vectorassembler.md b/docs/content/docs/operators/feature/vectorassembler.md index 2877e419c..df8ea5216 100644 --- a/docs/content/docs/operators/feature/vectorassembler.md +++ b/docs/content/docs/operators/feature/vectorassembler.md @@ -69,7 +69,7 @@ the strategy specified by the {@link HasHandleInvalid} parameter as follows: ```java import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/feature/vectorslicer.md b/docs/content/docs/operators/feature/vectorslicer.md index c7965e0c0..8b6f0c11d 100644 --- a/docs/content/docs/operators/feature/vectorslicer.md +++ b/docs/content/docs/operators/feature/vectorslicer.md @@ -61,7 +61,7 @@ it throws an IllegalArgumentException. ```java import org.apache.flink.ml.feature.vectorslicer.VectorSlicer; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/docs/content/docs/operators/functions.md b/docs/content/docs/operators/functions.md index 96968eae8..32bfd71fa 100644 --- a/docs/content/docs/operators/functions.md +++ b/docs/content/docs/operators/functions.md @@ -38,7 +38,7 @@ of double arrays. {{< tab "Java">}} ```java -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -145,7 +145,7 @@ DenseVector instances. {{< tab "Java">}} ```java -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; diff --git a/docs/content/docs/operators/regression/linearregression.md b/docs/content/docs/operators/regression/linearregression.md index 3b9717be5..1996ffdfe 100644 --- a/docs/content/docs/operators/regression/linearregression.md +++ b/docs/content/docs/operators/regression/linearregression.md @@ -72,7 +72,7 @@ Below are the parameters required by `LinearRegressionModel`. 
{{< tab "Java">}} ```java -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.regression.linearregression.LinearRegression; import org.apache.flink.ml.regression.linearregression.LinearRegressionModel; diff --git a/docs/content/docs/try-flink-ml/java/build-your-own-project.md b/docs/content/docs/try-flink-ml/java/build-your-own-project.md index 713bb9241..78a1752ce 100644 --- a/docs/content/docs/try-flink-ml/java/build-your-own-project.md +++ b/docs/content/docs/try-flink-ml/java/build-your-own-project.md @@ -176,7 +176,7 @@ package myflinkml; import org.apache.flink.ml.clustering.kmeans.KMeans; import org.apache.flink.ml.clustering.kmeans.KMeansModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java index 4a272e54f..6ca9f1027 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/clustering/KMeansModelDataGenerator.java @@ -23,8 +23,8 @@ import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator; import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.table.api.DataTypes; @@ -77,8 +77,8 @@ public Table[] getData(StreamTableEnvironment tEnv) { * information. 
*/ public static class GenerateWeightsFunction extends ScalarFunction { - public DenseVector eval(DenseVector[] centroids) { - return new DenseVector(centroids.length); + public DenseIntDoubleVector eval(DenseIntDoubleVector[] centroids) { + return new DenseIntDoubleVector(centroids.length); } @Override @@ -87,7 +87,7 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) { .outputTypeStrategy( callContext -> Optional.of( - DataTypes.of(DenseVectorTypeInfo.INSTANCE) + DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE) .toDataType(typeFactory))) .build(); } diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java index 0ab32cbb9..ef9c15baf 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.types.Row; import org.apache.flink.util.Preconditions; @@ -43,9 +43,9 @@ protected RowGenerator[] getRowGenerators() { new RowGenerator(getNumValues(), getSeed()) { @Override protected Row getRow() { - DenseVector[] result = new DenseVector[arraySize]; + DenseIntDoubleVector[] result = new DenseIntDoubleVector[arraySize]; for (int i = 0; i < arraySize; i++) { - result[i] = new DenseVector(vectorDim); + result[i] = new DenseIntDoubleVector(vectorDim); for (int j = 0; j < vectorDim; j++) { result[i].values[j] = random.nextDouble(); } @@ -58,7 +58,9 @@ protected Row getRow() { @Override protected RowTypeInfo getRowTypeInfo() { return new RowTypeInfo( - new TypeInformation[] {TypeInformation.of(DenseVector[].class)}, + new TypeInformation[] { + TypeInformation.of(DenseIntDoubleVector[].class) + }, columnNames[0]); } } diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java index 2923129f2..fcba2c51e 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.types.Row; import org.apache.flink.util.Preconditions; @@ -52,7 +52,8 @@ protected Row getRow() { @Override protected RowTypeInfo getRowTypeInfo() { return new RowTypeInfo( - new TypeInformation[] {DenseVectorTypeInfo.INSTANCE}, columnNames[0]); + new TypeInformation[] {DenseIntDoubleVectorTypeInfo.INSTANCE}, + columnNames[0]); } } }; diff --git 
a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java index dfc7b157d..4f4cdf383 100644 --- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java +++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/LabeledPointWithWeightGenerator.java @@ -23,7 +23,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.IntParam; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; @@ -113,7 +113,7 @@ protected Row getRow() { protected RowTypeInfo getRowTypeInfo() { return new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, colNames[0]); } diff --git a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java index edb53197f..6c5fe64e0 100644 --- a/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java +++ b/flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java @@ -25,7 +25,7 @@ import org.apache.flink.ml.benchmark.datagenerator.common.LabeledPointWithWeightGenerator; import org.apache.flink.ml.benchmark.datagenerator.common.RandomStringArrayGenerator; import org.apache.flink.ml.benchmark.datagenerator.common.RandomStringGenerator; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; @@ -68,9 +68,10 @@ public void testDenseVectorGenerator() { it.hasNext(); ) { Row row = it.next(); assertEquals(1, row.getArity()); - DenseVector vector = (DenseVector) row.getField(generator.getColNames()[0][0]); + DenseIntDoubleVector vector = + (DenseIntDoubleVector) row.getField(generator.getColNames()[0][0]); assertNotNull(vector); - assertEquals(vector.size(), generator.getVectorDim()); + assertEquals(vector.size().intValue(), generator.getVectorDim()); count++; } assertEquals(generator.getNumValues(), count); @@ -90,11 +91,12 @@ public void testDenseVectorArrayGenerator() { it.hasNext(); ) { Row row = it.next(); assertEquals(1, row.getArity()); - DenseVector[] vectors = (DenseVector[]) row.getField(generator.getColNames()[0][0]); + DenseIntDoubleVector[] vectors = + (DenseIntDoubleVector[]) row.getField(generator.getColNames()[0][0]); assertNotNull(vectors); assertEquals(generator.getArraySize(), vectors.length); - for (DenseVector vector : vectors) { - assertEquals(vector.size(), generator.getVectorDim()); + for (DenseIntDoubleVector vector : vectors) { + assertEquals(vector.size().intValue(), generator.getVectorDim()); } count++; } @@ -119,7 +121,7 @@ public void testLabeledPointWithWeightGenerator() { it.hasNext(); ) { Row row = it.next(); count++; - DenseVector features = 
(DenseVector) row.getField(featuresCol); + DenseIntDoubleVector features = (DenseIntDoubleVector) row.getField(featuresCol); assertNotNull(features); for (double value : features.values) { assertTrue(value >= 0); diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java index f812f8b20..acbff8c50 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java @@ -26,7 +26,7 @@ import org.apache.flink.ml.common.window.ProcessingTimeSessionWindows; import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows; import org.apache.flink.ml.common.window.Windows; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.BooleanParam; import org.apache.flink.ml.param.DoubleArrayArrayParam; @@ -138,7 +138,7 @@ private interface MyParams<T> extends WithParams<T> { "Description", new Double[][] {new Double[] {14.0, 15.0}, new Double[] {16.0, 17.0}}); - Param<Vector> VECTOR_PARAM = + Param<IntDoubleVector> VECTOR_PARAM = new VectorParam("vectorParam", "Description", Vectors.dense(1.0, 2.0, 3.0)); Param<Windows> WINDOWS_PARAM = diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java index e357d80c4..5a28997ce 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java @@ -22,12 +22,12 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.DenseMatrix; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfo; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -114,10 +114,10 @@ public void testGetRowTypeInfo() { dataFields.add(Collections.singletonMap(0.1, 1)); preDefinedDataTypes.add(DataTypes.ROW(DataTypes.INT(), DataTypes.BIGINT())); dataFields.add(Row.of(1, 2L)); - preDefinedDataTypes.add(DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)); - dataFields.add(new DenseVector(new double[] {0.1, 0.2})); - preDefinedDataTypes.add(DataTypes.RAW(SparseVectorTypeInfo.INSTANCE)); - dataFields.add(new SparseVector(2, new int[] {0}, new double[] {0.1})); + preDefinedDataTypes.add(DataTypes.RAW(DenseIntDoubleVectorTypeInfo.INSTANCE)); + dataFields.add(new DenseIntDoubleVector(new double[] {0.1, 0.2})); + preDefinedDataTypes.add(DataTypes.RAW(SparseIntDoubleVectorTypeInfo.INSTANCE)); + dataFields.add(new SparseIntDoubleVector(2, new int[] {0}, new double[] {0.1})); preDefinedDataTypes.add(DataTypes.RAW(DenseMatrixTypeInfo.INSTANCE)); 
dataFields.add(new DenseMatrix(2, 2)); preDefinedDataTypes.add( diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java index ec97b48c6..683c19e46 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -34,9 +34,9 @@ import org.apache.flink.ml.api.Stage; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.api.TransformerServable; import org.apache.flink.streaming.api.datastream.DataStream; @@ -253,8 +253,8 @@ public static Table convertDataTypesToSparseInt(StreamTableEnvironment tEnv, Tab RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(table.getResolvedSchema()); TypeInformation[] fieldTypes = inputTypeInfo.getFieldTypes(); for (int i = 0; i < fieldTypes.length; i++) { - if (fieldTypes[i].getTypeClass().equals(DenseVector.class)) { - fieldTypes[i] = SparseVectorTypeInfo.INSTANCE; + if (fieldTypes[i].getTypeClass().equals(DenseIntDoubleVector.class)) { + fieldTypes[i] = SparseIntDoubleVectorTypeInfo.INSTANCE; } else if (fieldTypes[i].getTypeClass().equals(Double.class)) { fieldTypes[i] = Types.INT; } @@ -273,8 +273,8 @@ public Row map(Row row) { int arity = row.getArity(); for (int i = 0; i < arity; i++) { Object obj = row.getField(i); - if (obj instanceof Vector) { - row.setField(i, ((Vector) obj).toSparse()); + if (obj instanceof IntDoubleVector) { + row.setField(i, ((IntDoubleVector) obj).toSparse()); } else if (obj instanceof Number) { row.setField(i, ((Number) obj).intValue()); } @@ -294,8 +294,8 @@ public static Class[] getColumnDataTypes(Table table) { } /** Note: this comparator imposes orderings that are inconsistent with equals. 
*/ - public static int compare(Vector first, Vector second) { - if (first.size() != second.size()) { + public static int compare(IntDoubleVector first, IntDoubleVector second) { + if (first.size().intValue() != second.size().intValue()) { return Integer.compare(first.size(), second.size()); } else { for (int i = 0; i < first.size(); i++) { diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java index 85c57b774..dc22dac24 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/ArrayToVectorExample.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.examples; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -49,7 +49,7 @@ public static void main(String[] args) { for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); Double[] doubleArray = row.getFieldAs("array"); - Vector vector = row.getFieldAs("vector"); + IntDoubleVector vector = row.getFieldAs("vector"); System.out.printf( "Input double array: %s\tOutput vector: %s\n", Arrays.toString(doubleArray), vector); diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java index 733fe45ce..f67a3b742 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/VectorToArrayExample.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.examples; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -54,7 +55,7 @@ public static void main(String[] args) { // Extracts and displays the results. for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - Vector vector = row.getFieldAs("vector"); + IntDoubleVector vector = row.getFieldAs("vector"); Double[] doubleArray = row.getFieldAs("array"); System.out.printf( "Input vector: %s\tOutput double array: %s\n", diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java index 9941c3065..3f3ca9fe4 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/KnnExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.knn.Knn; import org.apache.flink.ml.classification.knn.KnnModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -78,7 +78,8 @@ public static void main(String[] args) { // Extracts and displays the results. 
for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(knn.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(knn.getFeaturesCol()); double expectedResult = (Double) row.getField(knn.getLabelCol()); double predictionResult = (Double) row.getField(knn.getPredictionCol()); System.out.printf( diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java index f4941fb19..79ff44e74 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LinearSVCExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -62,11 +62,12 @@ public static void main(String[] args) { // Extracts and displays the results. for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); - DenseVector features = (DenseVector) row.getField(linearSVC.getFeaturesCol()); + DenseIntDoubleVector features = + (DenseIntDoubleVector) row.getField(linearSVC.getFeaturesCol()); double expectedResult = (Double) row.getField(linearSVC.getLabelCol()); double predictionResult = (Double) row.getField(linearSVC.getPredictionCol()); - DenseVector rawPredictionResult = - (DenseVector) row.getField(linearSVC.getRawPredictionCol()); + DenseIntDoubleVector rawPredictionResult = + (DenseIntDoubleVector) row.getField(linearSVC.getRawPredictionCol()); System.out.printf( "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n", features, expectedResult, predictionResult, rawPredictionResult); diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java index fc2168d06..6b3206496 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/LogisticRegressionExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -62,10 +62,12 @@ public static void main(String[] args) { // Extracts and displays the results. 
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(lr.getFeaturesCol());
             double expectedResult = (Double) row.getField(lr.getLabelCol());
             double predictionResult = (Double) row.getField(lr.getPredictionCol());
-            DenseVector rawPredictionResult = (DenseVector) row.getField(lr.getRawPredictionCol());
+            DenseIntDoubleVector rawPredictionResult =
+                    (DenseIntDoubleVector) row.getField(lr.getRawPredictionCol());
             System.out.printf(
                     "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                     features, expectedResult, predictionResult, rawPredictionResult);
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java
index 1f369200a..e78f14520 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/NaiveBayesExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
 import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -69,7 +69,8 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features = (DenseVector) row.getField(naiveBayes.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(naiveBayes.getFeaturesCol());
             double predictionResult = (Double) row.getField(naiveBayes.getPredictionCol());
             System.out.printf("Features: %s \tPrediction Result: %s\n", features, predictionResult);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java
index 9523050ee..6fb1a0eff 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/classification/OnlineLogisticRegressionExample.java
@@ -24,9 +24,9 @@
 import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
 import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -81,7 +81,7 @@ public static void main(String[] args) {
         RowTypeInfo typeInfo =
                 new RowTypeInfo(
-                        new TypeInformation[] {DenseVectorTypeInfo.INSTANCE, Types.DOUBLE},
+                        new TypeInformation[] {DenseIntDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE},
                         new String[] {"features", "label"});
 
         SourceFunction<Row> trainSource =
@@ -120,10 +120,12 @@ public static void main(String[] args) {
         // would change over time.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features = (DenseVector) row.getField(olr.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(olr.getFeaturesCol());
             Double expectedResult = (Double) row.getField(olr.getLabelCol());
             Double predictionResult = (Double) row.getField(olr.getPredictionCol());
-            DenseVector rawPredictionResult = (DenseVector) row.getField(olr.getRawPredictionCol());
+            DenseIntDoubleVector rawPredictionResult =
+                    (DenseIntDoubleVector) row.getField(olr.getRawPredictionCol());
             System.out.printf(
                     "Features: %-25s \tExpected Result: %s \tPrediction Result: %s \tRaw Prediction Result: %s\n",
                     features, expectedResult, predictionResult, rawPredictionResult);
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java
index c48448f1d..54fa8851e 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/AgglomerativeClusteringExample.java
@@ -21,7 +21,7 @@
 import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClustering;
 import org.apache.flink.ml.clustering.agglomerativeclustering.AgglomerativeClusteringParams;
 import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -37,7 +37,7 @@ public static void main(String[] args) {
         StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
 
         // Generates input data.
-        DataStream<DenseVector> inputStream =
+        DataStream<DenseIntDoubleVector> inputStream =
                 env.fromElements(
                         Vectors.dense(1, 1),
                         Vectors.dense(1, 4),
@@ -61,8 +61,8 @@ public static void main(String[] args) {
         // Extracts and displays the clustering results.
         for (CloseableIterator<Row> it = outputs[0].execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features =
-                    (DenseVector) row.getField(agglomerativeClustering.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(agglomerativeClustering.getFeaturesCol());
             int clusterId = (Integer) row.getField(agglomerativeClustering.getPredictionCol());
             System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java
index 62edf2e0d..5f5678fe1 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/KMeansExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.clustering.kmeans.KMeans;
 import org.apache.flink.ml.clustering.kmeans.KMeansModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -36,7 +36,7 @@ public static void main(String[] args) {
         StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
 
         // Generates input data.
-        DataStream<DenseVector> inputStream =
+        DataStream<DenseIntDoubleVector> inputStream =
                 env.fromElements(
                         Vectors.dense(0.0, 0.0),
                         Vectors.dense(0.0, 0.3),
@@ -58,7 +58,8 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features = (DenseVector) row.getField(kmeans.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(kmeans.getFeaturesCol());
             int clusterId = (Integer) row.getField(kmeans.getPredictionCol());
             System.out.printf("Features: %s \tCluster ID: %s\n", features, clusterId);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java
index b7b7eb3bb..579f0c113 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/clustering/OnlineKMeansExample.java
@@ -23,9 +23,9 @@
 import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
 import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
 import org.apache.flink.ml.examples.util.PeriodicSourceFunction;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
@@ -73,13 +73,14 @@ public static void main(String[] args) {
         SourceFunction<Row> trainSource =
                 new PeriodicSourceFunction(1000, Arrays.asList(trainData1, trainData2));
         DataStream<Row> trainStream =
-                env.addSource(trainSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
+                env.addSource(trainSource, new RowTypeInfo(DenseIntDoubleVectorTypeInfo.INSTANCE));
         Table trainTable = tEnv.fromDataStream(trainStream).as("features");
 
         SourceFunction<Row> predictSource =
                 new PeriodicSourceFunction(1000, Collections.singletonList(predictData));
         DataStream<Row> predictStream =
-                env.addSource(predictSource, new RowTypeInfo(DenseVectorTypeInfo.INSTANCE));
+                env.addSource(
+                        predictSource, new RowTypeInfo(DenseIntDoubleVectorTypeInfo.INSTANCE));
         Table predictTable = tEnv.fromDataStream(predictStream).as("features");
 
         // Creates an online K-means object and initializes its parameters and initial model data.
@@ -102,10 +103,12 @@ public static void main(String[] args) {
         // would change over time.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row1 = it.next();
-            DenseVector features1 = (DenseVector) row1.getField(onlineKMeans.getFeaturesCol());
+            DenseIntDoubleVector features1 =
+                    (DenseIntDoubleVector) row1.getField(onlineKMeans.getFeaturesCol());
             Integer clusterId1 = (Integer) row1.getField(onlineKMeans.getPredictionCol());
             Row row2 = it.next();
-            DenseVector features2 = (DenseVector) row2.getField(onlineKMeans.getFeaturesCol());
+            DenseIntDoubleVector features2 =
+                    (DenseIntDoubleVector) row2.getField(onlineKMeans.getFeaturesCol());
             Integer clusterId2 = (Integer) row2.getField(onlineKMeans.getPredictionCol());
             if (Objects.equals(clusterId1, clusterId2)) {
                 System.out.printf("%s and %s are now in the same cluster.\n", features1, features2);
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
index fb1287caa..b28a37869 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -62,7 +62,8 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
             String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
-            SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
+            SparseIntDoubleVector outputValue =
+                    (SparseIntDoubleVector) row.getField(countVectorizer.getOutputCol());
             System.out.printf(
                     "Input Value: %-15s \tOutput Value: %s\n",
                     Arrays.toString(inputValue), outputValue.toString());
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java
index 2b3d68316..afb4987cd 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/DCTExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.dct.DCT;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -37,7 +37,7 @@ public static void main(String[] args) {
         StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
 
         // Generates input data.
-        List<Vector> inputData =
+        List<IntDoubleVector> inputData =
                 Arrays.asList(
                         Vectors.dense(1.0, 1.0, 1.0, 1.0), Vectors.dense(1.0, 0.0, -1.0, 0.0));
         Table inputTable = tEnv.fromDataStream(env.fromCollection(inputData)).as("input");
@@ -52,8 +52,8 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = row.getFieldAs(dct.getInputCol());
-            Vector outputValue = row.getFieldAs(dct.getOutputCol());
+            IntDoubleVector inputValue = row.getFieldAs(dct.getInputCol());
+            IntDoubleVector outputValue = row.getFieldAs(dct.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java
index 927133330..386b9da71 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ElementwiseProductExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,8 +56,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            Vector inputValue = (Vector) row.getField(elementwiseProduct.getInputCol());
-            Vector outputValue = (Vector) row.getField(elementwiseProduct.getOutputCol());
+            IntDoubleVector inputValue =
+                    (IntDoubleVector) row.getField(elementwiseProduct.getInputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(elementwiseProduct.getOutputCol());
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
index c0f81c62f..fbbb1827c 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -61,7 +61,8 @@ public static void main(String[] args) {
             for (int i = 0; i < inputValues.length; i++) {
                 inputValues[i] = row.getField(featureHash.getInputCols()[i]);
             }
-            Vector outputValue = (Vector) row.getField(featureHash.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(featureHash.getOutputCol());
 
             System.out.printf(
                     "Input Values: %s \tOutput Value: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
index 213ebd5cd..4927c4a50 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.hashingtf.HashingTF;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -60,7 +60,8 @@ public static void main(String[] args) {
             Row row = it.next();
 
             List inputValue = (List) row.getField(hashingTF.getInputCol());
 
-            SparseVector outputValue = (SparseVector) row.getField(hashingTF.getOutputCol());
+            SparseIntDoubleVector outputValue =
+                    (SparseIntDoubleVector) row.getField(hashingTF.getOutputCol());
 
             System.out.printf(
                     "Input Value: %s \tOutput Value: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
index ffa4d710a..53bd93ec5 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.idf.IDF;
 import org.apache.flink.ml.feature.idf.IDFModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,8 +56,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(idf.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(idf.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java
index 6acb89038..1115baefa 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/InteractionExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.interaction.Interaction;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -58,7 +58,8 @@ public static void main(String[] args) {
             for (int i = 0; i < inputValues.length; i++) {
                 inputValues[i] = row.getField(interaction.getInputCols()[i]);
             }
-            Vector outputValue = (Vector) row.getField(interaction.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(interaction.getOutputCol());
             System.out.printf(
                     "Input Values: %s \tOutput Value: %s\n",
                     Arrays.toString(inputValues), outputValue);
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
index 1478f8c84..bcc05684d 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
@@ -21,7 +21,7 @@
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -63,8 +63,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(kBinsDiscretizer.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(kBinsDiscretizer.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(kBinsDiscretizer.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(kBinsDiscretizer.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
index cd53394f9..33826c989 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -64,8 +64,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(maxAbsScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(maxAbsScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(maxAbsScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(maxAbsScaler.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
String[] {"id", "vec"}, Types.INT, - TypeInformation.of(SparseVector.class)))); + TypeInformation.of(SparseIntDoubleVector.class)))); // Creates a MinHashLSH estimator object and initializes its parameters. MinHashLSH lsh = @@ -124,14 +124,15 @@ public static void main(String[] args) throws Exception { List fieldNames = output.getResolvedSchema().getColumnNames(); for (Row result : (List) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) { - Vector inputValue = result.getFieldAs(fieldNames.indexOf(lsh.getInputCol())); - DenseVector[] outputValue = result.getFieldAs(fieldNames.indexOf(lsh.getOutputCol())); + IntDoubleVector inputValue = result.getFieldAs(fieldNames.indexOf(lsh.getInputCol())); + DenseIntDoubleVector[] outputValue = + result.getFieldAs(fieldNames.indexOf(lsh.getOutputCol())); System.out.printf( "Vector: %s \tHash values: %s\n", inputValue, Arrays.toString(outputValue)); } // Finds approximate nearest neighbors of the key. - Vector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.}); + IntDoubleVector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1., 1.}); output = model.approxNearestNeighbors(dataA, key, 2).select($("id"), $("distCol")); for (Row result : (List) IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect())) { diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java index 98e908ff4..cde8541c1 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -64,8 +64,10 @@ public static void main(String[] args) { // Extracts and displays the results. 
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java
index 98e908ff4..cde8541c1 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MinMaxScalerExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -64,8 +64,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(minMaxScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(minMaxScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(minMaxScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(minMaxScaler.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java
index b0c6332b2..26a813b4e 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/NormalizerExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.normalizer.Normalizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -52,9 +52,9 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
 
-            Vector inputValue = (Vector) row.getField(normalizer.getInputCol());
+            IntDoubleVector inputValue = (IntDoubleVector) row.getField(normalizer.getInputCol());
 
-            Vector outputValue = (Vector) row.getField(normalizer.getOutputCol());
+            IntDoubleVector outputValue = (IntDoubleVector) row.getField(normalizer.getOutputCol());
 
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java
index 1ebc8d95b..976ffa922 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OneHotEncoderExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
@@ -56,8 +56,8 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
             Double inputValue = (Double) row.getField(oneHotEncoder.getInputCols()[0]);
-            SparseVector outputValue =
-                    (SparseVector) row.getField(oneHotEncoder.getOutputCols()[0]);
+            SparseIntDoubleVector outputValue =
+                    (SparseIntDoubleVector) row.getField(oneHotEncoder.getOutputCols()[0]);
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java
index d16ed12e3..8675e1445 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/OnlineStandardScalerExample.java
@@ -24,9 +24,9 @@
 import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.DataTypes;
@@ -73,7 +73,10 @@ public static void main(String[] args) {
                         inputStreamWithEventTime,
                         Schema.newBuilder()
                                 .column("f0", DataTypes.BIGINT())
-                                .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
+                                .column(
+                                        "f1",
+                                        DataTypes.RAW(
+                                                DenseIntDoubleVectorTypeInfo.INSTANCE))
                                 .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)")
                                 .watermark("rowtime", "SOURCE_WATERMARK()")
                                 .build())
@@ -94,9 +97,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(onlineStandardScaler.getInputCol());
-            DenseVector outputValue =
-                    (DenseVector) row.getField(onlineStandardScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(onlineStandardScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(onlineStandardScaler.getOutputCol());
             long modelVersion = row.getFieldAs(onlineStandardScaler.getModelVersionCol());
             System.out.printf(
                     "Input Value: %s\tOutput Value: %-65s\tModel Version: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java
index e4e6e20d1..632981978 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/PolynomialExpansionExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,9 +56,11 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
 
-            Vector inputValue = (Vector) row.getField(polynomialExpansion.getInputCol());
+            IntDoubleVector inputValue =
+                    (IntDoubleVector) row.getField(polynomialExpansion.getInputCol());
 
-            Vector outputValue = (Vector) row.getField(polynomialExpansion.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(polynomialExpansion.getOutputCol());
 
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
         }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
index 04f082375..98b6e0396 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.robustscaler.RobustScaler;
 import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -67,8 +67,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(robustScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(robustScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(robustScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(robustScaler.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java
index 571a58a97..5d38290bc 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/StandardScalerExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.standardscaler.StandardScaler;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -55,8 +55,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue = (DenseVector) row.getField(standardScaler.getInputCol());
-            DenseVector outputValue = (DenseVector) row.getField(standardScaler.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(standardScaler.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(standardScaler.getOutputCol());
             System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
index 4d4c07f6a..a5a82751f 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
 import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -67,10 +67,10 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue =
-                    (DenseVector) row.getField(univariateFeatureSelector.getFeaturesCol());
-            DenseVector outputValue =
-                    (DenseVector) row.getField(univariateFeatureSelector.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(univariateFeatureSelector.getFeaturesCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(univariateFeatureSelector.getOutputCol());
             System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java
index d441a3bcb..2ec96fb9f 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java
@@ -20,7 +20,7 @@
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
 import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -67,10 +67,10 @@ public static void main(String[] args) {
         System.out.printf("Variance Threshold: %s\n", threshold);
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector inputValue =
-                    (DenseVector) row.getField(varianceThresholdSelector.getInputCol());
-            DenseVector outputValue =
-                    (DenseVector) row.getField(varianceThresholdSelector.getOutputCol());
+            DenseIntDoubleVector inputValue =
+                    (DenseIntDoubleVector) row.getField(varianceThresholdSelector.getInputCol());
+            DenseIntDoubleVector outputValue =
+                    (DenseIntDoubleVector) row.getField(varianceThresholdSelector.getOutputCol());
             System.out.printf("Input Values: %-15s\tOutput Values: %s\n", inputValue, outputValue);
         }
     }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
index 0c146251e..5480c58ed 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -71,7 +71,8 @@ public static void main(String[] args) {
                 inputValues[i] = row.getField(vectorAssembler.getInputCols()[i]);
             }
 
-            Vector outputValue = (Vector) row.getField(vectorAssembler.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(vectorAssembler.getOutputCol());
 
             System.out.printf(
                     "Input Values: %s \tOutput Value: %s\n",
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
index 7c138992c..b5e7d91af 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.examples.feature;
 
 import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -52,9 +52,10 @@ public static void main(String[] args) {
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
 
-            Vector inputValue = (Vector) row.getField(vectorSlicer.getInputCol());
+            IntDoubleVector inputValue = (IntDoubleVector) row.getField(vectorSlicer.getInputCol());
 
-            Vector outputValue = (Vector) row.getField(vectorSlicer.getOutputCol());
+            IntDoubleVector outputValue =
+                    (IntDoubleVector) row.getField(vectorSlicer.getOutputCol());
 
             System.out.printf("Input Value: %s \tOutput Value: %s\n", inputValue, outputValue);
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java
index bc39f3adb..73599c54b 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/regression/LinearRegressionExample.java
@@ -18,7 +18,7 @@
 package org.apache.flink.ml.examples.regression;
 
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.regression.linearregression.LinearRegression;
 import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
@@ -60,7 +60,8 @@ public static void main(String[] args) {
         // Extracts and displays the results.
         for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
             Row row = it.next();
-            DenseVector features = (DenseVector) row.getField(lr.getFeaturesCol());
+            DenseIntDoubleVector features =
+                    (DenseIntDoubleVector) row.getField(lr.getFeaturesCol());
             double expectedResult = (Double) row.getField(lr.getLabelCol());
             double predictionResult = (Double) row.getField(lr.getPredictionCol());
             System.out.printf(
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java
index 694b35e4f..a798852a3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/Functions.java
@@ -18,10 +18,11 @@
 package org.apache.flink.ml;
 
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.table.api.ApiExpression;
 import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.catalog.DataTypeFactory;
@@ -37,18 +38,18 @@
 /** Built-in table functions for data transformations. */
 @SuppressWarnings("unused")
 public class Functions {
-    /** Converts a column of {@link Vector}s into a column of double arrays. */
+    /** Converts a column of {@link IntDoubleVector}s into a column of double arrays. */
     public static ApiExpression vectorToArray(Object... arguments) {
         return call(VectorToArrayFunction.class, arguments);
     }
 
     /**
-     * A {@link ScalarFunction} that converts a column of {@link Vector}s into a column of double
-     * arrays.
+     * A {@link ScalarFunction} that converts a column of {@link IntDoubleVector}s into a column of
+     * double arrays.
      */
     public static class VectorToArrayFunction extends ScalarFunction {
         public double[] eval(Vector vector) {
-            return vector.toArray();
+            return (double[]) vector.toArray();
         }
 
         @Override
@@ -66,7 +67,8 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
     }
 
     /**
-     * Converts a column of arrays of numeric type into a column of {@link DenseVector} instances.
+     * Converts a column of arrays of numeric type into a column of {@link DenseIntDoubleVector}
+     * instances.
      */
     public static ApiExpression arrayToVector(Object... arguments) {
         return call(ArrayToVectorFunction.class, arguments);
     }
 
     /**
      * A {@link ScalarFunction} that converts a column of arrays of numeric type into a column of
-     * {@link DenseVector} instances.
+     * {@link DenseIntDoubleVector} instances.
      */
     public static class ArrayToVectorFunction extends ScalarFunction {
-        public DenseVector eval(double[] array) {
+        public DenseIntDoubleVector eval(double[] array) {
             return Vectors.dense(array);
         }
 
-        public DenseVector eval(Double[] array) {
+        public DenseIntDoubleVector eval(Double[] array) {
             return eval(ArrayUtils.toPrimitive(array));
         }
 
-        public DenseVector eval(Number[] array) {
+        public DenseIntDoubleVector eval(Number[] array) {
             double[] doubles = new double[array.length];
             for (int i = 0; i < array.length; i++) {
                 doubles[i] = array[i].doubleValue();
@@ -99,7 +101,7 @@ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
                     .outputTypeStrategy(
                             callContext ->
                                     Optional.of(
-                                            DataTypes.of(DenseVectorTypeInfo.INSTANCE)
+                                            DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)
                                                     .toDataType(typeFactory)))
                     .build();
         }
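[Editorial sketch, not part of the patch.] The two table functions above are meant to round-trip between ARRAY<DOUBLE> columns and the new vector type. A usage sketch based on the signatures in this diff (the $ column expression is Flink's standard org.apache.flink.table.api.Expressions.$; the table and column names are placeholders):

    import static org.apache.flink.ml.Functions.arrayToVector;
    import static org.apache.flink.ml.Functions.vectorToArray;
    import static org.apache.flink.table.api.Expressions.$;

    // Assumes inputTable has a column "array" of type ARRAY<DOUBLE>.
    Table vectorTable = inputTable.select(arrayToVector($("array")).as("vector"));
    // And back again: the "vector" column becomes a double[] column.
    Table arrayTable = vectorTable.select(vectorToArray($("vector")).as("array"));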
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
index 8ad15e276..2c9124ba3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
@@ -24,9 +24,9 @@
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.DenseMatrix;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -63,7 +63,7 @@ public KnnModel fit(Table... inputs) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
         /* Tuple3 : <feature, label, norm square> */
-        DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNorm =
+        DataStream<Tuple3<DenseIntDoubleVector, Double, Double>> inputDataWithNorm =
                 computeNormSquare(tEnv.toDataStream(inputs[0]));
         DataStream<KnnModelData> modelData = genModelData(inputDataWithNorm);
         KnnModel model = new KnnModel().setModelData(tEnv.fromDataStream(modelData));
@@ -96,29 +96,32 @@ public static Knn load(StreamTableEnvironment tEnv, String path) throws IOExcept
      * @return Knn model.
      */
     private static DataStream<KnnModelData> genModelData(
-            DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNormSqare) {
+            DataStream<Tuple3<DenseIntDoubleVector, Double, Double>> inputDataWithNormSqare) {
         DataStream<KnnModelData> modelData =
                 DataStreamUtils.mapPartition(
                         inputDataWithNormSqare,
                         new RichMapPartitionFunction<
-                                Tuple3<DenseVector, Double, Double>, KnnModelData>() {
+                                Tuple3<DenseIntDoubleVector, Double, Double>, KnnModelData>() {
                             @Override
                             public void mapPartition(
-                                    Iterable<Tuple3<DenseVector, Double, Double>> dataPoints,
+                                    Iterable<Tuple3<DenseIntDoubleVector, Double, Double>>
+                                            dataPoints,
                                     Collector<KnnModelData> out) {
-                                List<Tuple3<DenseVector, Double, Double>> bufferedDataPoints =
-                                        new ArrayList<>();
-                                for (Tuple3<DenseVector, Double, Double> dataPoint : dataPoints) {
+                                List<Tuple3<DenseIntDoubleVector, Double, Double>>
+                                        bufferedDataPoints = new ArrayList<>();
+                                for (Tuple3<DenseIntDoubleVector, Double, Double> dataPoint :
+                                        dataPoints) {
                                     bufferedDataPoints.add(dataPoint);
                                 }
                                 int featureDim = bufferedDataPoints.get(0).f0.size();
                                 DenseMatrix packedFeatures =
                                         new DenseMatrix(featureDim, bufferedDataPoints.size());
-                                DenseVector normSquares =
-                                        new DenseVector(bufferedDataPoints.size());
-                                DenseVector labels = new DenseVector(bufferedDataPoints.size());
+                                DenseIntDoubleVector normSquares =
+                                        new DenseIntDoubleVector(bufferedDataPoints.size());
+                                DenseIntDoubleVector labels =
+                                        new DenseIntDoubleVector(bufferedDataPoints.size());
                                 int offset = 0;
-                                for (Tuple3<DenseVector, Double, Double> dataPoint :
+                                for (Tuple3<DenseIntDoubleVector, Double, Double> dataPoint :
                                         bufferedDataPoints) {
                                     System.arraycopy(
                                             dataPoint.f0.values,
@@ -142,14 +145,15 @@ public void mapPartition(
      * @param inputData Input data.
      * @return Input data with norm square.
      */
-    private DataStream<Tuple3<DenseVector, Double, Double>> computeNormSquare(
+    private DataStream<Tuple3<DenseIntDoubleVector, Double, Double>> computeNormSquare(
             DataStream<Row> inputData) {
         return inputData.map(
-                new MapFunction<Row, Tuple3<DenseVector, Double, Double>>() {
+                new MapFunction<Row, Tuple3<DenseIntDoubleVector, Double, Double>>() {
                     @Override
-                    public Tuple3<DenseVector, Double, Double> map(Row value) {
+                    public Tuple3<DenseIntDoubleVector, Double, Double> map(Row value) {
                         Double label = ((Number) value.getField(getLabelCol())).doubleValue();
-                        DenseVector feature = ((Vector) value.getField(getFeaturesCol())).toDense();
+                        DenseIntDoubleVector feature =
+                                ((IntDoubleVector) value.getField(getFeaturesCol())).toDense();
                         return Tuple3.of(feature, label, Math.pow(BLAS.norm2(feature), 2));
                     }
                 });
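[Editorial sketch, not part of the patch.] The norm squares precomputed in computeNormSquare above, together with the BLAS.gemv(-2.0, ...) call in KnnModel below, evaluate squared Euclidean distances through the identity ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, so only the dot products need computing at prediction time. A self-contained Java check of that identity with toy values:

    // Squared Euclidean distance computed two ways; both print 13.0.
    double[] x = {1.0, 2.0};
    double[] y = {3.0, 5.0};
    double direct = 0;
    for (int i = 0; i < x.length; i++) {
        direct += (x[i] - y[i]) * (x[i] - y[i]);
    }
    double dot = x[0] * y[0] + x[1] * y[1];           // 13.0
    double normSqX = x[0] * x[0] + x[1] * x[1];       //  5.0
    double normSqY = y[0] * y[0] + y[1] * y[1];       // 34.0
    double viaIdentity = normSqX - 2 * dot + normSqY; // 13.0
    System.out.println(direct + " == " + viaIdentity);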
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
index 3194fbb29..eb32e20fc 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
@@ -26,8 +26,8 @@
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -130,7 +130,7 @@ private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
         private KnnModelData knnModelData;
         private final int k;
         private final String broadcastKey;
-        private DenseVector distanceVector;
+        private DenseIntDoubleVector distanceVector;
 
         public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
             this.k = k;
@@ -144,14 +144,14 @@ public Row map(Row row) {
                 knnModelData =
                         (KnnModelData)
                                 getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
-                distanceVector = new DenseVector(knnModelData.labels.size());
+                distanceVector = new DenseIntDoubleVector(knnModelData.labels.size());
             }
-            DenseVector feature = ((Vector) row.getField(featureCol)).toDense();
+            DenseIntDoubleVector feature = ((IntDoubleVector) row.getField(featureCol)).toDense();
             double prediction = predictLabel(feature);
             return Row.join(row, Row.of(prediction));
         }
 
-        private double predictLabel(DenseVector feature) {
+        private double predictLabel(DenseIntDoubleVector feature) {
             double normSquare = Math.pow(BLAS.norm2(feature), 2);
             BLAS.gemv(-2.0, knnModelData.packedFeatures, true, feature, 0.0, distanceVector);
             for (int i = 0; i < distanceVector.size(); i++) {
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
index 8e4a5c524..d146e8a3d 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
@@ -27,12 +27,12 @@
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.DenseMatrix;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Matrix;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -51,13 +51,15 @@ public class KnnModelData {
 
     public DenseMatrix packedFeatures;
-    public DenseVector featureNormSquares;
-    public DenseVector labels;
+    public DenseIntDoubleVector featureNormSquares;
+    public DenseIntDoubleVector labels;
 
     public KnnModelData() {}
 
     public KnnModelData(
-            DenseMatrix packedFeatures, DenseVector featureNormSquares, DenseVector labels) {
+            DenseMatrix packedFeatures,
+            DenseIntDoubleVector featureNormSquares,
+            DenseIntDoubleVector labels) {
         this.packedFeatures = packedFeatures;
         this.featureNormSquares = featureNormSquares;
         this.labels = labels;
@@ -77,13 +79,14 @@ public static DataStream<KnnModelData> getModelDataStream(Table modelDataTable)
                         x ->
                                 new KnnModelData(
                                         ((Matrix) x.getField(0)).toDense(),
-                                        ((Vector) x.getField(1)).toDense(),
-                                        ((Vector) x.getField(2)).toDense()));
+                                        ((IntDoubleVector) x.getField(1)).toDense(),
+                                        ((IntDoubleVector) x.getField(2)).toDense()));
     }
 
     /** Encoder for {@link KnnModelData}. */
     public static class ModelDataEncoder implements Encoder<KnnModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();
 
         @Override
         public void encode(KnnModelData modelData, OutputStream outputStream) throws IOException {
@@ -102,14 +105,15 @@ public Reader<KnnModelData> createReader(Configuration config, FSDataInputStream
 
                 private final DataInputView source = new DataInputViewStreamWrapper(stream);
 
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();
 
                 @Override
                 public KnnModelData read() throws IOException {
                     try {
                         DenseMatrix matrix = DenseMatrixSerializer.INSTANCE.deserialize(source);
-                        DenseVector normSquares = serializer.deserialize(source);
-                        DenseVector labels = serializer.deserialize(source);
+                        DenseIntDoubleVector normSquares = serializer.deserialize(source);
+                        DenseIntDoubleVector labels = serializer.deserialize(source);
                         return new KnnModelData(matrix, normSquares, labels);
                     } catch (EOFException e) {
                         return null;
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
index 30c166b43..7fce6fe20 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
@@ -25,8 +25,8 @@
 import org.apache.flink.ml.common.lossfunc.HingeLoss;
 import org.apache.flink.ml.common.optimizer.Optimizer;
 import org.apache.flink.ml.common.optimizer.SGD;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -77,13 +77,13 @@ public LinearSVCModel fit(Table... inputs) {
                                             || Double.compare(1.0, label) == 0,
                                     "LinearSVC only supports binary classification. But detected label: %s.",
                                     label);
-                            DenseVector features =
-                                    ((Vector) dataPoint.getField(getFeaturesCol()))
+                            DenseIntDoubleVector features =
+                                    ((IntDoubleVector) dataPoint.getField(getFeaturesCol()))
                                             .toDense();
                             return new LabeledPointWithWeight(features, label, weight);
                         });
 
-        DataStream<DenseVector> initModelData =
+        DataStream<DenseIntDoubleVector> initModelData =
                 DataStreamUtils.reduce(
                                 trainData.map(x -> x.getFeatures().size()),
                                 (ReduceFunction<Integer>)
                                         (t0, t1) -> {
                                             Preconditions.checkState(
                                                     t0.equals(t1),
                                                     "The training data should all have same dimensions.");
                                             return t0;
                                         })
-                        .map(DenseVector::new);
+                        .map(DenseIntDoubleVector::new);
 
         Optimizer optimizer =
                 new SGD(
inputs) { getTol(), getReg(), getElasticNet()); - DataStream rawModelData = + DataStream rawModelData = optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE); DataStream modelData = rawModelData.map(LinearSVCModelData::new); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java index a7da2e31c..ce9073578 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java @@ -25,10 +25,10 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -75,7 +75,7 @@ public Table[] transform(Table... inputs) { ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO, - DenseVectorTypeInfo.INSTANCE), + DenseIntDoubleVectorTypeInfo.INSTANCE), ArrayUtils.addAll( inputTypeInfo.getFieldNames(), getPredictionCol(), @@ -136,7 +136,7 @@ private static class PredictLabelFunction extends RichMapFunction { private final double threshold; - private DenseVector coefficient; + private DenseIntDoubleVector coefficient; public PredictLabelFunction( String broadcastModelKey, String featuresCol, double threshold) { @@ -153,7 +153,8 @@ public Row map(Row dataPoint) { getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); coefficient = modelData.coefficient; } - DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense(); + DenseIntDoubleVector features = + ((IntDoubleVector) dataPoint.getField(featuresCol)).toDense(); Row predictionResult = predictOneDataPoint(features, coefficient, threshold); return Row.join(dataPoint, predictionResult); } @@ -168,7 +169,7 @@ public Row map(Row dataPoint) { * @return The prediction label and the raw predictions. */ private static Row predictOneDataPoint( - DenseVector feature, DenseVector coefficient, double threshold) { + DenseIntDoubleVector feature, DenseIntDoubleVector coefficient, double threshold) { double dotValue = BLAS.dot(feature, coefficient); return Row.of(dotValue >= threshold ? 
1.0 : 0.0, Vectors.dense(dotValue, -dotValue)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java index 771c10473..3f07e0f6e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java @@ -25,9 +25,9 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -45,9 +45,9 @@ */ public class LinearSVCModelData { - public DenseVector coefficient; + public DenseIntDoubleVector coefficient; - public LinearSVCModelData(DenseVector coefficient) { + public LinearSVCModelData(DenseIntDoubleVector coefficient) { this.coefficient = coefficient; } @@ -63,12 +63,13 @@ public static DataStream getModelDataStream(Table modelData) StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); return tEnv.toDataStream(modelData) - .map(x -> new LinearSVCModelData(((Vector) x.getField(0)).toDense())); + .map(x -> new LinearSVCModelData(((IntDoubleVector) x.getField(0)).toDense())); } /** Data encoder for {@link LinearSVCModel}. 
*/ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(LinearSVCModelData modelData, OutputStream outputStream) @@ -85,12 +86,13 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration configuration, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public LinearSVCModelData read() throws IOException { try { - DenseVector coefficient = + DenseIntDoubleVector coefficient = serializer.deserialize(new DataInputViewStreamWrapper(inputStream)); return new LinearSVCModelData(coefficient); } catch (EOFException e) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index e7a896059..41e5ec8f9 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -25,8 +25,8 @@ import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; import org.apache.flink.ml.common.optimizer.Optimizer; import org.apache.flink.ml.common.optimizer.SGD; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -85,12 +85,13 @@ public LogisticRegressionModel fit(Table... inputs) { throw new RuntimeException( "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); } - Vector features = - ((Vector) dataPoint.getField(getFeaturesCol())); + IntDoubleVector features = + ((IntDoubleVector) + dataPoint.getField(getFeaturesCol())); return new LabeledPointWithWeight(features, label, weight); }); - DataStream initModelData = + DataStream initModelData = DataStreamUtils.reduce( trainData.map(x -> x.getFeatures().size()), (ReduceFunction) @@ -100,7 +101,7 @@ public LogisticRegressionModel fit(Table... inputs) { "The training data should all have same dimensions."); return t0; }) - .map(DenseVector::new); + .map(DenseIntDoubleVector::new); Optimizer optimizer = new SGD( @@ -110,7 +111,7 @@ public LogisticRegressionModel fit(Table... 
inputs) { getTol(), getReg(), getElasticNet()); - DataStream rawModelData = + DataStream rawModelData = optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE); DataStream modelData = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index 1f7176d48..2ab502299 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -75,7 +75,7 @@ public Table[] transform(Table... inputs) { ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO, - TypeInformation.of(DenseVector.class)), + TypeInformation.of(DenseIntDoubleVector.class)), ArrayUtils.addAll( inputTypeInfo.getFieldNames(), getPredictionCol(), @@ -160,9 +160,10 @@ public Row map(Row dataPoint) { } ParamUtils.updateExistingParams(servable, params); } - Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); + IntDoubleVector features = + (IntDoubleVector) dataPoint.getField(servable.getFeaturesCol()); - Tuple2 predictionResult = servable.transform(features); + Tuple2 predictionResult = servable.transform(features); return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java index 5b4a4f4a5..562842536 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java @@ -25,7 +25,7 @@ import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; @@ -70,7 +70,7 @@ public RandomModelDataGenerator(int dim, int seed) { @Override public LogisticRegressionModelData map(Integer integer) throws Exception { - DenseVector vector = new DenseVector(dim); + DenseIntDoubleVector vector = new DenseIntDoubleVector(dim); Random random = new Random(seed); for (int j = 0; j < dim; j++) { vector.values[j] = random.nextDouble(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java index 1bc19938f..75a7ce722 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java @@ -35,9 +35,9 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -114,9 +114,9 @@ public OnlineLogisticRegressionModel fit(Table... inputs) { getFeaturesCol(), getLabelCol(), getWeightCol()), pointTypeInfo); - DataStream initModelData = + DataStream initModelData = modelDataStream.map( - (MapFunction) + (MapFunction) value -> value.coefficient); initModelData.getTransformation().setParallelism(1); @@ -189,7 +189,7 @@ public FtrlIterationBody( @Override public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { - DataStream modelData = variableStreams.get(0); + DataStream modelData = variableStreams.get(0); DataStream points = dataStreams.get(0); int parallelism = points.getParallelism(); @@ -198,17 +198,17 @@ public IterationBodyResult process( "There are more subtasks in the training process than the number " + "of elements in each batch. Some subtasks might be idling forever."); - DataStream newGradient = + DataStream newGradient = DataStreamUtils.generateBatchData(points, parallelism, batchSize) .connect(modelData.broadcast()) .transform( "LocalGradientCalculator", - TypeInformation.of(DenseVector[].class), + TypeInformation.of(DenseIntDoubleVector[].class), new CalculateLocalGradient()) .setParallelism(parallelism) .countWindowAll(parallelism) .reduce( - (ReduceFunction) + (ReduceFunction) (gradientInfo, newGradientInfo) -> { BLAS.axpy(1.0, gradientInfo[0], newGradientInfo[0]); BLAS.axpy(1.0, gradientInfo[1], newGradientInfo[1]); @@ -217,11 +217,11 @@ public IterationBodyResult process( } return newGradientInfo; }); - DataStream feedbackModelData = + DataStream feedbackModelData = newGradient .transform( "ModelDataUpdater", - TypeInformation.of(DenseVector.class), + TypeInformation.of(DenseIntDoubleVector.class), new UpdateModel(alpha, beta, l1, l2)) .setParallelism(1); @@ -233,12 +233,13 @@ public IterationBodyResult process( } private static class CreateLrModelData - implements MapFunction, CheckpointedFunction { + implements MapFunction, + CheckpointedFunction { private Long modelVersion = 1L; private transient ListState modelVersionState; @Override - public LogisticRegressionModelData map(DenseVector denseVector) throws Exception { + public LogisticRegressionModelData map(DenseIntDoubleVector denseVector) throws Exception { return new LogisticRegressionModelData(denseVector, modelVersion++); } @@ -258,8 +259,8 @@ public void initializeState(FunctionInitializationContext context) throws Except } /** Updates model. 
*/ - private static class UpdateModel extends AbstractStreamOperator - implements OneInputStreamOperator { + private static class UpdateModel extends AbstractStreamOperator + implements OneInputStreamOperator { private ListState nParamState; private ListState zParamState; private final double alpha; @@ -288,8 +289,9 @@ public void initializeState(StateInitializationContext context) throws Exception } @Override - public void processElement(StreamRecord streamRecord) throws Exception { - DenseVector[] gradientInfo = streamRecord.getValue(); + public void processElement(StreamRecord streamRecord) + throws Exception { + DenseIntDoubleVector[] gradientInfo = streamRecord.getValue(); double[] coefficient = gradientInfo[2].values; double[] g = gradientInfo[0].values; for (int i = 0; i < g.length; ++i) { @@ -317,13 +319,14 @@ public void processElement(StreamRecord streamRecord) throws Exce / ((beta + Math.sqrt(nParam[i])) / alpha + l2); } } - output.collect(new StreamRecord<>(new DenseVector(coefficient))); + output.collect(new StreamRecord<>(new DenseIntDoubleVector(coefficient))); } } - private static class CalculateLocalGradient extends AbstractStreamOperator - implements TwoInputStreamOperator { - private ListState modelDataState; + private static class CalculateLocalGradient + extends AbstractStreamOperator + implements TwoInputStreamOperator { + private ListState modelDataState; private ListState localBatchDataState; private double[] gradient; private double[] weightSum; @@ -334,7 +337,8 @@ public void initializeState(StateInitializationContext context) throws Exception modelDataState = context.getOperatorStateStore() .getListState( - new ListStateDescriptor<>("modelData", DenseVector.class)); + new ListStateDescriptor<>( + "modelData", DenseIntDoubleVector.class)); TypeInformation type = ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class)); localBatchDataState = @@ -353,7 +357,7 @@ private void calculateGradient() throws Exception { || !localBatchDataState.get().iterator().hasNext()) { return; } - DenseVector modelData = + DenseIntDoubleVector modelData = OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get(); modelDataState.clear(); @@ -362,7 +366,7 @@ private void calculateGradient() throws Exception { localBatchDataState.update(pointsList); for (Row point : points) { - Vector vec = point.getFieldAs(0); + IntDoubleVector vec = point.getFieldAs(0); double label = point.getFieldAs(1); double weight = point.getArity() == 2 ? 
1.0 : point.getFieldAs(2); if (gradient == null) { @@ -371,14 +375,14 @@ private void calculateGradient() throws Exception { } double p = BLAS.dot(modelData, vec); p = 1 / (1 + Math.exp(-p)); - if (vec instanceof DenseVector) { - DenseVector dvec = (DenseVector) vec; + if (vec instanceof DenseIntDoubleVector) { + DenseIntDoubleVector dvec = (DenseIntDoubleVector) vec; for (int i = 0; i < modelData.size(); ++i) { gradient[i] += (p - label) * dvec.values[i]; weightSum[i] += 1.0; } } else { - SparseVector svec = (SparseVector) vec; + SparseIntDoubleVector svec = (SparseIntDoubleVector) vec; for (int i = 0; i < svec.indices.length; ++i) { int idx = svec.indices[i]; gradient[idx] += (p - label) * svec.values[i]; @@ -390,9 +394,9 @@ private void calculateGradient() throws Exception { if (points.length > 0) { output.collect( new StreamRecord<>( - new DenseVector[] { - new DenseVector(gradient), - new DenseVector(weightSum), + new DenseIntDoubleVector[] { + new DenseIntDoubleVector(gradient), + new DenseIntDoubleVector(weightSum), (getRuntimeContext().getIndexOfThisSubtask() == 0) ? modelData : null @@ -403,7 +407,8 @@ private void calculateGradient() throws Exception { } @Override - public void processElement2(StreamRecord modelDataRecord) throws Exception { + public void processElement2(StreamRecord modelDataRecord) + throws Exception { modelDataState.add(modelDataRecord.getValue()); calculateGradient(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java index f06086132..81f0c9979 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java @@ -27,8 +27,8 @@ import org.apache.flink.metrics.Gauge; import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -76,7 +76,7 @@ public Table[] transform(Table... 
inputs) { ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), Types.DOUBLE, - TypeInformation.of(DenseVector.class), + TypeInformation.of(DenseIntDoubleVector.class), Types.LONG), ArrayUtils.addAll( inputTypeInfo.getFieldNames(), @@ -104,7 +104,7 @@ private static class PredictLabelOperator extends AbstractStreamOperator private final Map, Object> params; private ListState bufferedPointsState; - private DenseVector coefficient; + private DenseIntDoubleVector coefficient; private long modelDataVersion = 0L; private LogisticRegressionModelServable servable; @@ -162,8 +162,9 @@ public void processElement(StreamRecord streamRecord) throws Exception { new LogisticRegressionModelData(coefficient, 0L)); ParamUtils.updateExistingParams(servable, params); } - Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); - Tuple2 predictionResult = servable.transform(features); + IntDoubleVector features = + (IntDoubleVector) dataPoint.getField(servable.getFeaturesCol()); + Tuple2 predictionResult = servable.transform(features); output.collect( new StreamRecord<>( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java index 404c5d19a..8c3935e9b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java @@ -27,9 +27,9 @@ import org.apache.flink.api.java.tuple.Tuple4; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -74,10 +74,10 @@ public NaiveBayesModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream> input = + DataStream> input = tEnv.toDataStream(inputs[0]) .map( - (MapFunction>) + (MapFunction>) row -> { Number number = (Number) row.getField(labelCol); Preconditions.checkNotNull( @@ -87,7 +87,7 @@ public NaiveBayesModel fit(Table... inputs) { number.intValue() == number.doubleValue(), "Label value should be indexed number."); return new Tuple2<>( - (Vector) row.getField(featuresCol), + (IntDoubleVector) row.getField(featuresCol), number.doubleValue()); }, Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE)); @@ -129,8 +129,8 @@ public NaiveBayesModel fit(Table... 
inputs) { DataTypes.ARRAY( DataTypes.MAP( DataTypes.DOUBLE(), DataTypes.DOUBLE())))) - .column("piArray", DataTypes.of(DenseVectorTypeInfo.INSTANCE)) - .column("labels", DataTypes.of(DenseVectorTypeInfo.INSTANCE)) + .column("piArray", DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)) + .column("labels", DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)) .build(); NaiveBayesModel model = @@ -165,10 +165,11 @@ public Map, Object> getParamMap() { * */ private static class ExtractFeatureFunction - implements FlatMapFunction, Tuple3> { + implements FlatMapFunction< + Tuple2, Tuple3> { @Override public void flatMap( - Tuple2 value, + Tuple2 value, Collector> collector) { Preconditions.checkNotNull(value.f1); for (int i = 0; i < value.f0.size(); i++) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java index 1ac9c00dd..05c7b5564 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -146,13 +146,13 @@ public Row map(Row row) { (NaiveBayesModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); } - Vector vector = (Vector) row.getField(featuresCol); + IntDoubleVector vector = (IntDoubleVector) row.getField(featuresCol); double label = findMaxProbLabel(calculateProb(modelData, vector), modelData.labels); return Row.join(row, Row.of(label)); } } - private static double findMaxProbLabel(DenseVector prob, Vector label) { + private static double findMaxProbLabel(DenseIntDoubleVector prob, IntDoubleVector label) { double result = 0.; int probSize = prob.size(); double maxVal = Double.NEGATIVE_INFINITY; @@ -167,9 +167,10 @@ private static double findMaxProbLabel(DenseVector prob, Vector label) { } /** Calculates the probability of the input data.
*/ - private static DenseVector calculateProb(NaiveBayesModelData modelData, Vector data) { + private static DenseIntDoubleVector calculateProb( + NaiveBayesModelData modelData, IntDoubleVector data) { int labelSize = modelData.labels.size(); - DenseVector probs = new DenseVector(new double[labelSize]); + DenseIntDoubleVector probs = new DenseIntDoubleVector(new double[labelSize]); for (int i = 0; i < labelSize; i++) { Map[] labelData = modelData.theta[i]; for (int j = 0; j < data.size(); j++) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java index fac5dab32..03fc9f392 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java @@ -29,10 +29,10 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -60,8 +60,8 @@ public class NaiveBayesModelData { fields.put( "theta", Types.OBJECT_ARRAY(Types.OBJECT_ARRAY(Types.MAP(Types.DOUBLE, Types.DOUBLE)))); - fields.put("piArray", DenseVectorTypeInfo.INSTANCE); - fields.put("labels", DenseVectorTypeInfo.INSTANCE); + fields.put("piArray", DenseIntDoubleVectorTypeInfo.INSTANCE); + fields.put("labels", DenseIntDoubleVectorTypeInfo.INSTANCE); } public static final TypeInformation TYPE_INFO = @@ -74,13 +74,15 @@ public class NaiveBayesModelData { public Map[][] theta; /** Log of class priors, whose dimension is C (number of classes). */ - public DenseVector piArray; + public DenseIntDoubleVector piArray; /** Value of labels. */ - public DenseVector labels; + public DenseIntDoubleVector labels; public NaiveBayesModelData( - Map[][] theta, DenseVector piArray, DenseVector labels) { + Map[][] theta, + DenseIntDoubleVector piArray, + DenseIntDoubleVector labels) { this.theta = theta; this.piArray = piArray; this.labels = labels; @@ -103,14 +105,15 @@ public static DataStream getModelDataStream(Table modelData row -> new NaiveBayesModelData( (Map[][]) row.getField(0), - ((Vector) row.getField(1)).toDense(), - ((Vector) row.getField(2)).toDense()), + ((IntDoubleVector) row.getField(1)).toDense(), + ((IntDoubleVector) row.getField(2)).toDense()), TYPE_INFO); } /** Data encoder for the {@link NaiveBayesModelData}. 
*/ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(NaiveBayesModelData modelData, OutputStream outputStream) @@ -141,7 +144,8 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration config, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public NaiveBayesModelData read() throws IOException { @@ -152,9 +156,11 @@ public NaiveBayesModelData read() throws IOException { new MapSerializer<>( DoubleSerializer.INSTANCE, DoubleSerializer.INSTANCE); - DenseVector labels = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector labels = + serializer.deserialize(inputViewStreamWrapper); - DenseVector piArray = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector piArray = + serializer.deserialize(inputViewStreamWrapper); int featureSize = inputViewStreamWrapper.readInt(); int numLabels = inputViewStreamWrapper.readInt(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java index 6fbf39d57..12901edb1 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java @@ -43,10 +43,10 @@ import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound; import org.apache.flink.ml.common.iteration.TerminateOnMaxIter; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -89,11 +89,12 @@ public KMeansModel fit(Table... 
inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream points = + DataStream points = tEnv.toDataStream(inputs[0]) - .map(row -> ((Vector) row.getField(getFeaturesCol())).toDense()); + .map(row -> ((IntDoubleVector) row.getField(getFeaturesCol())).toDense()); - DataStream initCentroids = selectRandomCentroids(points, getK(), getSeed()); + DataStream initCentroids = + selectRandomCentroids(points, getK(), getSeed()); IterationConfig config = IterationConfig.newBuilder() @@ -144,20 +145,20 @@ public KMeansIterationBody(int maxIterationNum, DistanceMeasure distanceMeasure) @Override public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { - DataStream centroids = variableStreams.get(0); - DataStream points = dataStreams.get(0); + DataStream centroids = variableStreams.get(0); + DataStream points = dataStreams.get(0); DataStream terminationCriteria = centroids.flatMap(new TerminateOnMaxIter(maxIterationNum)); - DataStream> centroidIdAndPoints = + DataStream> centroidIdAndPoints = points.connect(centroids.broadcast()) .transform( "CentroidsUpdateAccumulator", new TupleTypeInfo<>( BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, ObjectArrayTypeInfo.getInfoFor( - DenseVectorTypeInfo.INSTANCE)), + DenseIntDoubleVectorTypeInfo.INSTANCE)), new CentroidsUpdateAccumulator(distanceMeasure)); DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints, 100); @@ -169,7 +170,7 @@ public IterationBodyResult process( .reduce(new CentroidsUpdateReducer()) .map(new ModelDataGenerator()); - DataStream newCentroids = + DataStream newCentroids = newModelData.map(x -> x.centroids).setParallelism(1); DataStream finalModelData = @@ -183,10 +184,11 @@ public IterationBodyResult process( } private static class CentroidsUpdateReducer - implements ReduceFunction> { + implements ReduceFunction> { @Override - public Tuple2 reduce( - Tuple2 tuple2, Tuple2 t1) + public Tuple2 reduce( + Tuple2 tuple2, + Tuple2 t1) throws Exception { for (int i = 0; i < tuple2.f0.length; i++) { tuple2.f0[i] += t1.f0[i]; @@ -198,28 +200,31 @@ public Tuple2 reduce( } private static class ModelDataGenerator - implements MapFunction, KMeansModelData> { + implements MapFunction, KMeansModelData> { @Override - public KMeansModelData map(Tuple2 tuple2) throws Exception { + public KMeansModelData map(Tuple2 tuple2) + throws Exception { double[] weights = new double[tuple2.f0.length]; for (int i = 0; i < tuple2.f0.length; i++) { BLAS.scal(1.0 / tuple2.f0[i], tuple2.f1[i]); weights[i] = tuple2.f0[i]; } - return new KMeansModelData(tuple2.f1, new DenseVector(weights)); + return new KMeansModelData(tuple2.f1, new DenseIntDoubleVector(weights)); } } private static class CentroidsUpdateAccumulator - extends AbstractStreamOperator> + extends AbstractStreamOperator> implements TwoInputStreamOperator< - DenseVector, DenseVector[], Tuple2>, - IterationListener> { + DenseIntDoubleVector, + DenseIntDoubleVector[], + Tuple2>, + IterationListener> { private final DistanceMeasure distanceMeasure; - private ListState centroids; + private ListState centroids; private ListStateWithCache points; @@ -232,8 +237,8 @@ public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) { public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - TypeInformation type = - ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + TypeInformation type = + 
ObjectArrayTypeInfo.getInfoFor(DenseIntDoubleVectorTypeInfo.INSTANCE); centroids = context.getOperatorStateStore() @@ -255,12 +260,14 @@ public void snapshotState(StateSnapshotContext context) throws Exception { } @Override - public void processElement1(StreamRecord streamRecord) throws Exception { + public void processElement1(StreamRecord streamRecord) + throws Exception { points.add(new VectorWithNorm(streamRecord.getValue())); } @Override - public void processElement2(StreamRecord streamRecord) throws Exception { + public void processElement2(StreamRecord streamRecord) + throws Exception { Preconditions.checkState(!centroids.get().iterator().hasNext()); centroids.add(streamRecord.getValue()); } @@ -269,9 +276,9 @@ public void processElement2(StreamRecord streamRecord) throws Exc public void onEpochWatermarkIncremented( int epochWatermark, Context context, - Collector> out) + Collector> out) throws Exception { - DenseVector[] centroidValues = + DenseIntDoubleVector[] centroidValues = Objects.requireNonNull( OperatorStateUtils.getUniqueElement(centroids, "centroids") .orElse(null)); @@ -281,11 +288,11 @@ public void onEpochWatermarkIncremented( centroidsWithNorm[i] = new VectorWithNorm(centroidValues[i]); } - DenseVector[] newCentroids = new DenseVector[centroidValues.length]; + DenseIntDoubleVector[] newCentroids = new DenseIntDoubleVector[centroidValues.length]; Integer[] counts = new Integer[centroidValues.length]; Arrays.fill(counts, 0); for (int i = 0; i < centroidValues.length; i++) { - newCentroids[i] = new DenseVector(centroidValues[i].size()); + newCentroids[i] = new DenseIntDoubleVector(centroidValues[i].size()); } for (VectorWithNorm point : points.get()) { @@ -301,25 +308,25 @@ public void onEpochWatermarkIncremented( @Override public void onIterationTerminated( - Context context, Collector> collector) { + Context context, Collector> collector) { centroids.clear(); points.clear(); } } - public static DataStream selectRandomCentroids( - DataStream data, int k, long seed) { - DataStream resultStream = + public static DataStream selectRandomCentroids( + DataStream data, int k, long seed) { + DataStream resultStream = DataStreamUtils.mapPartition( DataStreamUtils.sample(data, k, seed), - new MapPartitionFunction() { + new MapPartitionFunction() { @Override public void mapPartition( - Iterable iterable, - Collector collector) { - List list = new ArrayList<>(); + Iterable iterable, + Collector collector) { + List list = new ArrayList<>(); iterable.iterator().forEachRemaining(list::add); - collector.collect(list.toArray(new DenseVector[0])); + collector.collect(list.toArray(new DenseIntDoubleVector[0])); } }); resultStream.getTransformation().setParallelism(1); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java index 5aa57aec5..cac0846ad 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; 
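The accumulate-then-average pattern above — CentroidsUpdateAccumulator collects per-centroid sums and counts, ModelDataGenerator divides the sums by the counts — is one step of Lloyd's algorithm. A minimal single-JVM sketch of that step, using plain double[] arrays and Euclidean distance in place of the DenseIntDoubleVector and DistanceMeasure types; class and method names here are illustrative, not part of the Flink ML API:

import java.util.List;

final class LloydStepSketch {

    /** Runs one assign-and-average step and returns the updated centroids. */
    static double[][] step(List<double[]> points, double[][] centroids) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];

        // Assignment: each point contributes to its nearest centroid's sum.
        for (double[] p : points) {
            int closest = 0;
            double best = Double.POSITIVE_INFINITY;
            for (int i = 0; i < k; i++) {
                double d = squaredDistance(p, centroids[i]);
                if (d < best) {
                    best = d;
                    closest = i;
                }
            }
            counts[closest]++;
            for (int j = 0; j < dim; j++) {
                sums[closest][j] += p[j];
            }
        }

        // Averaging: divide each sum by its count, as ModelDataGenerator does
        // with BLAS.scal(1.0 / count, sum).
        for (int i = 0; i < k; i++) {
            if (counts[i] == 0) {
                // Assumption: keep an empty cluster's centroid unchanged; the
                // patch itself divides by the count directly.
                sums[i] = centroids[i].clone();
            } else {
                for (int j = 0; j < dim; j++) {
                    sums[i][j] /= counts[i];
                }
            }
        }
        return sums;
    }

    private static double squaredDistance(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }
}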
import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -137,7 +137,8 @@ public Row map(Row dataPoint) { centroids[i] = new VectorWithNorm(modelData.centroids[i]); } } - DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense(); + DenseIntDoubleVector point = + ((IntDoubleVector) dataPoint.getField(featuresCol)).toDense(); int closestCentroidId = distanceMeasure.findClosest(centroids, new VectorWithNorm(point)); return Row.join(dataPoint, Row.of(closestCentroidId)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java index eb00b1ea3..b10bd471b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java @@ -28,9 +28,9 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; @@ -52,7 +52,7 @@ */ public class KMeansModelData { - public DenseVector[] centroids; + public DenseIntDoubleVector[] centroids; /** * The weight of the centroids. It is used when updating the model data in online training @@ -61,9 +61,9 @@ public class KMeansModelData { *

KMeansModelData objects generated during {@link KMeans#fit(Table...)} also contains this * field, so that it can be used as the initial model data of the online training process. */ - public DenseVector weights; + public DenseIntDoubleVector weights; - public KMeansModelData(DenseVector[] centroids, DenseVector weights) { + public KMeansModelData(DenseIntDoubleVector[] centroids, DenseIntDoubleVector weights) { Preconditions.checkArgument(centroids.length == weights.size()); this.centroids = centroids; this.weights = weights; @@ -103,15 +103,15 @@ private RandomCentroidsCreator(int k, int dim, double weight, long seed) { @Override public KMeansModelData map(Integer integer) { - DenseVector[] centroids = new DenseVector[k]; + DenseIntDoubleVector[] centroids = new DenseIntDoubleVector[k]; Random random = new Random(seed); for (int i = 0; i < k; i++) { - centroids[i] = new DenseVector(dim); + centroids[i] = new DenseIntDoubleVector(dim); for (int j = 0; j < dim; j++) { centroids[i].values[j] = random.nextDouble(); } } - DenseVector weights = new DenseVector(k); + DenseIntDoubleVector weights = new DenseIntDoubleVector(k); Arrays.fill(weights.values, weight); return new KMeansModelData(centroids, weights); } @@ -130,15 +130,16 @@ public static DataStream getModelDataStream(Table modelData) { .map( x -> new KMeansModelData( - Arrays.stream(((Vector[]) x.getField(0))) - .map(Vector::toDense) - .toArray(DenseVector[]::new), - ((Vector) x.getField(1)).toDense())); + Arrays.stream(((IntDoubleVector[]) x.getField(0))) + .map(IntDoubleVector::toDense) + .toArray(DenseIntDoubleVector[]::new), + ((IntDoubleVector) x.getField(1)).toDense())); } /** Data encoder for {@link KMeansModelData}. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(KMeansModelData modelData, OutputStream outputStream) @@ -146,7 +147,7 @@ public void encode(KMeansModelData modelData, OutputStream outputStream) DataOutputViewStreamWrapper outputViewStreamWrapper = new DataOutputViewStreamWrapper(outputStream); IntSerializer.INSTANCE.serialize(modelData.centroids.length, outputViewStreamWrapper); - for (DenseVector denseVector : modelData.centroids) { + for (DenseIntDoubleVector denseVector : modelData.centroids) { serializer.serialize(denseVector, new DataOutputViewStreamWrapper(outputStream)); } serializer.serialize(modelData.weights, new DataOutputViewStreamWrapper(outputStream)); @@ -159,7 +160,8 @@ public static class ModelDataDecoder extends SimpleStreamFormat public Reader createReader( Configuration config, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public KMeansModelData read() throws IOException { @@ -168,11 +170,13 @@ public KMeansModelData read() throws IOException { new DataInputViewStreamWrapper(inputStream); int numDenseVectors = IntSerializer.INSTANCE.deserialize(inputViewStreamWrapper); - DenseVector[] centroids = new DenseVector[numDenseVectors]; + DenseIntDoubleVector[] centroids = + new DenseIntDoubleVector[numDenseVectors]; for (int i = 0; i < numDenseVectors; i++) { centroids[i] = serializer.deserialize(inputViewStreamWrapper); } - DenseVector weights = 
serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector weights = + serializer.deserialize(inputViewStreamWrapper); return new KMeansModelData(centroids, weights); } catch (EOFException e) { return null; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java index d1783a792..0d407b8df 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java @@ -33,10 +33,10 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -89,7 +89,7 @@ public OnlineKMeansModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream points = + DataStream points = tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); DataStream initModelData = @@ -156,7 +156,7 @@ public OnlineKMeansIterationBody( public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { DataStream modelData = variableStreams.get(0); - DataStream points = dataStreams.get(0); + DataStream points = dataStreams.get(0); int parallelism = points.getParallelism(); @@ -188,10 +188,10 @@ public IterationBodyResult process( private static class ModelDataGlobalReducer implements ReduceFunction { @Override public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { - DenseVector weights = modelData.weights; - DenseVector[] centroids = modelData.centroids; - DenseVector newWeights = newModelData.weights; - DenseVector[] newCentroids = newModelData.centroids; + DenseIntDoubleVector weights = modelData.weights; + DenseIntDoubleVector[] centroids = modelData.centroids; + DenseIntDoubleVector newWeights = newModelData.weights; + DenseIntDoubleVector[] newCentroids = newModelData.centroids; int k = newCentroids.length; int dim = newCentroids[0].size(); @@ -225,11 +225,12 @@ public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newMode * */ private static class ModelDataLocalUpdater extends AbstractStreamOperator - implements TwoInputStreamOperator { + implements TwoInputStreamOperator< + DenseIntDoubleVector[], KMeansModelData, KMeansModelData> { private final DistanceMeasure distanceMeasure; private final int k; private final double decayFactor; - private ListState localBatchDataState; + private ListState localBatchDataState; private ListState modelDataState; private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) { @@ -242,8 +243,8 @@ private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double dec public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - TypeInformation type = - 
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + TypeInformation type = + ObjectArrayTypeInfo.getInfoFor(DenseIntDoubleVectorTypeInfo.INSTANCE); localBatchDataState = context.getOperatorStateStore() .getListState(new ListStateDescriptor<>("localBatch", type)); @@ -255,7 +256,8 @@ public void initializeState(StateInitializationContext context) throws Exception } @Override - public void processElement1(StreamRecord pointsRecord) throws Exception { + public void processElement1(StreamRecord pointsRecord) + throws Exception { localBatchDataState.add(pointsRecord.getValue()); alignAndComputeModelData(); } @@ -276,30 +278,30 @@ private void alignAndComputeModelData() throws Exception { KMeansModelData modelData = OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get(); - DenseVector[] centroids = modelData.centroids; + DenseIntDoubleVector[] centroids = modelData.centroids; VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[modelData.centroids.length]; for (int i = 0; i < centroidsWithNorm.length; i++) { centroidsWithNorm[i] = new VectorWithNorm(modelData.centroids[i]); } - DenseVector weights = modelData.weights; + DenseIntDoubleVector weights = modelData.weights; modelDataState.clear(); - List pointsList = + List pointsList = IteratorUtils.toList(localBatchDataState.get().iterator()); - DenseVector[] points = pointsList.remove(0); + DenseIntDoubleVector[] points = pointsList.remove(0); localBatchDataState.update(pointsList); int dim = centroids[0].size(); int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); // Computes new centroids. - DenseVector[] sums = new DenseVector[k]; + DenseIntDoubleVector[] sums = new DenseIntDoubleVector[k]; int[] counts = new int[k]; for (int i = 0; i < k; i++) { - sums[i] = new DenseVector(dim); + sums[i] = new DenseIntDoubleVector(dim); } - for (DenseVector point : points) { + for (DenseIntDoubleVector point : points) { int closestCentroidId = distanceMeasure.findClosest(centroidsWithNorm, new VectorWithNorm(point)); counts[closestCentroidId]++; @@ -313,7 +315,7 @@ private void alignAndComputeModelData() throws Exception { continue; } - DenseVector centroid = centroids[i]; + DenseIntDoubleVector centroid = centroids[i]; weights.values[i] = weights.values[i] + counts[i]; double lambda = counts[i] / weights.values[i]; @@ -325,7 +327,7 @@ private void alignAndComputeModelData() throws Exception { } } - private static class FeaturesExtractor implements MapFunction { + private static class FeaturesExtractor implements MapFunction { private final String featuresCol; private FeaturesExtractor(String featuresCol) { @@ -333,8 +335,8 @@ private FeaturesExtractor(String featuresCol) { } @Override - public DenseVector map(Row row) { - return ((Vector) row.getField(featuresCol)).toDense(); + public DenseIntDoubleVector map(Row row) { + return ((IntDoubleVector) row.getField(featuresCol)).toDense(); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java index 43742f186..91cc7633f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java @@ -26,8 +26,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; 
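For reference, the per-centroid arithmetic in ModelDataLocalUpdater above folds each mini-batch into the model as a convex combination of the old centroid and the batch mean, weighted by how many points the centroid has already absorbed. A minimal sketch on plain arrays (the decay factor and Flink state handling are omitted; names are illustrative, not the Flink ML API):

final class OnlineCentroidUpdateSketch {

    /**
     * Folds one mini-batch into a single centroid. batchSum is the element-wise
     * sum of the batch points assigned to this centroid and batchCount their
     * number (assumed > 0). Returns the centroid's new weight.
     */
    static double update(double[] centroid, double weight, double[] batchSum, int batchCount) {
        double newWeight = weight + batchCount;
        // Fraction of the updated centroid contributed by this batch; matches
        // lambda = counts[i] / weights.values[i] after the weight update above.
        double lambda = batchCount / newWeight;
        for (int j = 0; j < centroid.length; j++) {
            double batchMean = batchSum[j] / batchCount;
            // BLAS.scal(1 - lambda, centroid) followed by
            // BLAS.axpy(lambda / count, batchSum, centroid), written out:
            centroid[j] = (1 - lambda) * centroid[j] + lambda * batchMean;
        }
        return newWeight;
    }
}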
-import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -173,7 +173,9 @@ public void processElement1(StreamRecord streamRecord) throws Exception { bufferedPointsState.add(dataPoint); return; } - DenseVector point = ((Vector) dataPoint.getField(featuresCol)).toDense(); + DenseIntDoubleVector point = + (DenseIntDoubleVector) + (((IntDoubleVector) dataPoint.getField(featuresCol)).toDense()); int closestCentroidId = distanceMeasure.findClosest(centroids, new VectorWithNorm(point)); output.collect(new StreamRecord<>(Row.join(dataPoint, Row.of(closestCentroidId)))); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java index d4d43cbdd..ce7482626 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java @@ -22,7 +22,7 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; /** The loss function for binary logistic loss. See {@link LogisticRegression} for example. */ @Internal @@ -32,7 +32,7 @@ public class BinaryLogisticLoss implements LossFunc { private BinaryLogisticLoss() {} @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); double labelScaled = 2 * dataPoint.getLabel() - 1; return dataPoint.getWeight() * Math.log(1 + Math.exp(-dot * labelScaled)); @@ -40,7 +40,9 @@ public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coeffici @Override public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient) { double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); double labelScaled = 2 * dataPoint.getLabel() - 1; double multiplier = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java index eb0f3bf58..06a104aa7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java @@ -22,7 +22,7 @@ import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; /** * The loss function for hinge loss. See {@link LinearSVC} for example. 
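BinaryLogisticLoss above and HingeLoss, whose diff continues below, share one convention: the {0, 1} label is rescaled to y = 2 * label - 1 in {-1, +1}, so that both loss and gradient are functions of the margin y * (w . x). A self-contained sketch of both losses and their (sub)gradients on plain arrays, with the per-sample weight omitted; this is illustrative code, not the Flink ML implementation:

final class MarginLossSketch {

    // Binary logistic loss: log(1 + exp(-margin)).
    static double logisticLoss(double[] w, double[] x, double label) {
        return Math.log(1 + Math.exp(-margin(w, x, label)));
    }

    // Hinge loss: max(0, 1 - margin).
    static double hingeLoss(double[] w, double[] x, double label) {
        return Math.max(0, 1 - margin(w, x, label));
    }

    // Adds this point's logistic-loss gradient to cumGradient, mirroring the
    // computeGradient(dataPoint, coefficient, cumGradient) contract.
    static void logisticGradient(double[] w, double[] x, double label, double[] cumGradient) {
        double y = 2 * label - 1;
        double multiplier = -y / (1 + Math.exp(y * dot(w, x)));
        for (int i = 0; i < x.length; i++) {
            cumGradient[i] += multiplier * x[i];
        }
    }

    // Hinge subgradient: -y * x when the margin is violated, zero otherwise.
    static void hingeGradient(double[] w, double[] x, double label, double[] cumGradient) {
        double y = 2 * label - 1;
        if (1 - y * dot(w, x) > 0) {
            for (int i = 0; i < x.length; i++) {
                cumGradient[i] += -y * x[i];
            }
        }
    }

    private static double margin(double[] w, double[] x, double label) {
        return (2 * label - 1) * dot(w, x);
    }

    private static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }
}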
@@ -36,7 +36,7 @@ public class HingeLoss implements LossFunc { private HingeLoss() {} @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); double labelScaled = 2 * dataPoint.getLabel() - 1; return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot); @@ -44,7 +44,9 @@ public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coeffici @Override public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient) { double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); double labelScaled = 2 * dataPoint.getLabel() - 1; if (1 - labelScaled * dot > 0) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java index ea64649b1..7b943491f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.regression.linearregression.LinearRegression; /** The loss function for least square loss. See {@link LinearRegression} for example. */ @@ -32,14 +32,16 @@ public class LeastSquareLoss implements LossFunc { private LeastSquareLoss() {} @Override - public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); return dataPoint.getWeight() * 0.5 * Math.pow(dot - dataPoint.getLabel(), 2); } @Override public void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient) { double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); BLAS.axpy( (dot - dataPoint.getLabel()) * dataPoint.getWeight(), diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java index 0fead4aab..777130142 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java @@ -20,7 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import java.io.Serializable; @@ -37,7 +37,7 @@ public interface LossFunc extends Serializable { * @param coefficient The model parameters. * @return The loss of the input data. 
*/ - double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient); + double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient); /** * Computes the gradient on the given data point and adds the computed gradient to cumGradient. @@ -47,7 +47,9 @@ public interface LossFunc extends Serializable { * @param cumGradient The accumulated gradient. */ void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient); + LabeledPointWithWeight dataPoint, + DenseIntDoubleVector coefficient, + DenseIntDoubleVector cumGradient); /** Computes loss using the label and the prediction. */ default double computeLoss(double label, double prediction) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java index 647741d13..3cabc8ff6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java @@ -21,7 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.common.lossfunc.LossFunc; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.streaming.api.datastream.DataStream; /** @@ -39,8 +39,8 @@ public interface Optimizer { * @param lossFunc The loss function to optimize. * @return The fitted model data. */ - DataStream optimize( - DataStream initModelData, + DataStream optimize( + DataStream initModelData, DataStream trainData, LossFunc lossFunc); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java index 3d36d9aba..69f1a83f2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java @@ -20,7 +20,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; /** * A utility class for algorithms that need to handle regularization. The regularization term is @@ -45,7 +45,7 @@ class RegularizationUtils { * @return The loss introduced by regularization. 
*/ public static double regularize( - DenseVector coefficient, + DenseIntDoubleVector coefficient, final double reg, final double elasticNet, final double learningRate) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java index 2f7800482..b94b55f9d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java @@ -37,8 +37,8 @@ import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; import org.apache.flink.ml.common.lossfunc.LossFunc; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.regression.linearregression.LinearRegression; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -79,8 +79,8 @@ public SGD( } @Override - public DataStream optimize( - DataStream initModelData, + public DataStream optimize( + DataStream initModelData, DataStream trainData, LossFunc lossFunc) { DataStreamList resultList = @@ -111,8 +111,8 @@ public IterationBodyResult process( // totalLoss]. DataStream variableStream = variableStreams.get(0); DataStream trainData = dataStreams.get(0); - final OutputTag modelDataOutputTag = - new OutputTag("MODEL_OUTPUT") {}; + final OutputTag modelDataOutputTag = + new OutputTag("MODEL_OUTPUT") {}; SingleOutputStreamOperator modelUpdateAndWeightAndLoss = trainData @@ -164,7 +164,7 @@ private static class CacheDataAndDoTrain extends AbstractStreamOperator modelDataOutputTag; + private final OutputTag modelDataOutputTag; /** The cached training data. */ private List trainData; @@ -177,9 +177,9 @@ private static class CacheDataAndDoTrain extends AbstractStreamOperator nextBatchOffsetState; /** The model coefficient. */ - private DenseVector coefficient; + private DenseIntDoubleVector coefficient; - private ListState coefficientState; + private ListState coefficientState; /** The dimension of the coefficient. 
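The regularize helper above shrinks the coefficient in place and returns the penalty term's contribution to the loss. Its body is not part of this hunk, so the following is only a generic elastic-net sketch under the usual formulation reg * ((1 - elasticNet) * 0.5 * ||c||^2 + elasticNet * ||c||_1); the update rule shown (L2 shrinkage plus L1 soft-thresholding) is an assumption, not necessarily the exact Flink ML implementation:

class ElasticNetSketch {
    static double regularize(double[] c, double reg, double elasticNet, double learningRate) {
        double l1 = 0;
        double l2 = 0;
        for (int i = 0; i < c.length; i++) {
            l1 += Math.abs(c[i]);
            l2 += c[i] * c[i];
            // L2 part as a plain gradient step, L1 part as a proximal (soft-threshold) step.
            double shrunk = c[i] * (1 - learningRate * reg * (1 - elasticNet));
            double threshold = learningRate * reg * elasticNet;
            c[i] = Math.signum(shrunk) * Math.max(0, Math.abs(shrunk) - threshold);
        }
        return reg * ((1 - elasticNet) * 0.5 * l2 + elasticNet * l1);
    }
}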
*/ private int coefficientDim; @@ -196,7 +196,9 @@ private static class CacheDataAndDoTrain extends AbstractStreamOperator modelDataOutputTag) { + LossFunc lossFunc, + SGDParams params, + OutputTag modelDataOutputTag) { this.lossFunc = lossFunc; this.params = params; this.modelDataOutputTag = modelDataOutputTag; @@ -232,7 +234,7 @@ private void updateModel() { if (getTotalWeight() > 0) { BLAS.axpy( -params.learningRate / getTotalWeight(), - new DenseVector(feedbackArray), + new DenseIntDoubleVector(feedbackArray), coefficient, coefficientDim); double regLoss = @@ -247,7 +249,7 @@ public void onEpochWatermarkIncremented( int epochWatermark, Context context, Collector collector) throws Exception { if (epochWatermark == 0) { - coefficient = new DenseVector(feedbackArray); + coefficient = new DenseIntDoubleVector(feedbackArray); coefficientDim = coefficient.size(); feedbackArray = new double[coefficient.size() + 2]; } else { @@ -271,7 +273,7 @@ public void onEpochWatermarkIncremented( Arrays.fill(feedbackArray, 0); double totalLoss = 0; double totalWeight = 0; - DenseVector cumGradientsWrapper = new DenseVector(feedbackArray); + DenseIntDoubleVector cumGradientsWrapper = new DenseIntDoubleVector(feedbackArray); for (LabeledPointWithWeight dataPoint : miniBatchData) { totalLoss += lossFunc.computeLoss(dataPoint, coefficient); lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper); @@ -311,7 +313,8 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "coefficientState", DenseVectorTypeInfo.INSTANCE)); + "coefficientState", + DenseIntDoubleVectorTypeInfo.INSTANCE)); OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState") .ifPresent(x -> coefficient = x); if (coefficient != null) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java index 3d37c1fed..e0da19729 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java @@ -18,22 +18,22 @@ package org.apache.flink.ml.common.util; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.util.ArrayList; import java.util.List; -/** Provides utility functions for {@link Vector}. */ +/** Provides utility functions for {@link IntDoubleVector}. */ public class VectorUtils { /** * Selects a subset of the vector based on the indices. Note that the input indices must be * sorted in ascending order.
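As a behavioral reference for the selectByIndices rewrite below: selected entries are re-based onto their positions within the selected index set. A self-contained sketch of the same two-pointer merge on plain arrays (class and method names are illustrative only):

import java.util.ArrayList;
import java.util.List;

class SelectByIndicesSketch {
    // Keeps entries whose index appears in sortedIndices; a kept entry is re-indexed to
    // its position j within sortedIndices, exactly like the sparse branch in the diff.
    static void select(int[] indices, double[] values, int[] sortedIndices) {
        List<Integer> outIndices = new ArrayList<>();
        List<Double> outValues = new ArrayList<>();
        for (int i = 0, j = 0; i < indices.length && j < sortedIndices.length; ) {
            if (indices[i] == sortedIndices[j]) {
                outIndices.add(j++);
                outValues.add(values[i++]);
            } else if (indices[i] > sortedIndices[j]) {
                j++;
            } else {
                i++;
            }
        }
        System.out.println(outIndices + " -> " + outValues);
    }

    public static void main(String[] args) {
        // Sparse entries at logical indices {0, 2, 5}; select logical indices {2, 3, 5}.
        select(new int[] {0, 2, 5}, new double[] {1.0, 2.0, 3.0}, new int[] {2, 3, 5});
        // Prints [0, 2] -> [2.0, 3.0]: index 2 lands at position 0, index 5 at position 2.
    }
}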
*/ - public static Vector selectByIndices(Vector vector, int[] sortedIndices) { - if (vector instanceof DenseVector) { - DenseVector resultVec = new DenseVector(sortedIndices.length); + public static IntDoubleVector selectByIndices(IntDoubleVector vector, int[] sortedIndices) { + if (vector instanceof DenseIntDoubleVector) { + DenseIntDoubleVector resultVec = new DenseIntDoubleVector(sortedIndices.length); for (int i = 0; i < sortedIndices.length; i++) { resultVec.set(i, vector.get(sortedIndices[i])); } @@ -42,18 +42,18 @@ public static Vector selectByIndices(Vector vector, int[] sortedIndices) { List resultIndices = new ArrayList<>(); List resultValues = new ArrayList<>(); - int[] indices = ((SparseVector) vector).indices; + int[] indices = ((SparseIntDoubleVector) vector).indices; for (int i = 0, j = 0; i < indices.length && j < sortedIndices.length; ) { if (indices[i] == sortedIndices[j]) { resultIndices.add(j++); - resultValues.add(((SparseVector) vector).values[i++]); + resultValues.add(((SparseIntDoubleVector) vector).values[i++]); } else if (indices[i] > sortedIndices[j]) { j++; } else { i++; } } - return new SparseVector( + return new SparseIntDoubleVector( sortedIndices.length, resultIndices.stream().mapToInt(Integer::intValue).toArray(), resultValues.stream().mapToDouble(Double::doubleValue).toArray()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java index d74e40b24..97d648f86 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/evaluation/binaryclassification/BinaryClassificationEvaluator.java @@ -35,7 +35,7 @@ import org.apache.flink.ml.api.AlgoOperator; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -659,8 +659,8 @@ public Tuple3 map(Row value) throws Exception { double label = ((Number) value.getFieldAs(labelCol)).doubleValue(); Object probOrigin = value.getField(rawPredictionCol); double prob = - probOrigin instanceof Vector - ? ((Vector) probOrigin).get(1) + probOrigin instanceof IntDoubleVector + ? ((IntDoubleVector) probOrigin).get(1) : ((Number) probOrigin).doubleValue(); double weight = weightCol == null ? 
1.0 : ((Number) value.getField(weightCol)).doubleValue(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java index aafbf6e7b..5384fdbc7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java @@ -24,11 +24,11 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -72,11 +72,11 @@ public Table[] transform(Table... inputs) { for (int i = 0; i < inputCols.length; ++i) { int idx = inputTypeInfo.getFieldIndex(inputCols[i]); Class typeClass = inputTypeInfo.getTypeAt(idx).getTypeClass(); - if (typeClass.equals(SparseVector.class)) { - outputTypes[i] = SparseVectorTypeInfo.INSTANCE; - } else if (typeClass.equals(DenseVector.class)) { - outputTypes[i] = DenseVectorTypeInfo.INSTANCE; - } else if (typeClass.equals(Vector.class)) { + if (typeClass.equals(SparseIntDoubleVector.class)) { + outputTypes[i] = SparseIntDoubleVectorTypeInfo.INSTANCE; + } else if (typeClass.equals(DenseIntDoubleVector.class)) { + outputTypes[i] = DenseIntDoubleVectorTypeInfo.INSTANCE; + } else if (typeClass.equals(IntDoubleVector.class)) { outputTypes[i] = VectorTypeInfo.INSTANCE; } else { outputTypes[i] = Types.DOUBLE; @@ -119,15 +119,15 @@ public Row map(Row input) { } private Object binarizerFunc(Object obj, double threshold) { - if (obj instanceof DenseVector) { - DenseVector inputVec = (DenseVector) obj; - DenseVector vec = inputVec.clone(); + if (obj instanceof DenseIntDoubleVector) { + DenseIntDoubleVector inputVec = (DenseIntDoubleVector) obj; + DenseIntDoubleVector vec = inputVec.clone(); for (int i = 0; i < vec.size(); ++i) { vec.values[i] = inputVec.get(i) > threshold ? 1.0 : 0.0; } return vec; - } else if (obj instanceof SparseVector) { - SparseVector inputVec = (SparseVector) obj; + } else if (obj instanceof SparseIntDoubleVector) { + SparseIntDoubleVector inputVec = (SparseIntDoubleVector) obj; int[] newIndices = new int[inputVec.indices.length]; int pos = 0; @@ -139,7 +139,8 @@ private Object binarizerFunc(Object obj, double threshold) { double[] newValues = new double[pos]; Arrays.fill(newValues, 1.0); - return new SparseVector(inputVec.size(), Arrays.copyOf(newIndices, pos), newValues); + return new SparseIntDoubleVector( + inputVec.size(), Arrays.copyOf(newIndices, pos), newValues); } else { return Double.parseDouble(obj.toString()) > threshold ? 
1.0 : 0.0; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java index 390d99762..d3fa93f35 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java @@ -23,9 +23,9 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -107,7 +107,8 @@ public Table[] transform(Table... inputs) { RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll( - inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE), + inputTypeInfo.getFieldTypes(), + SparseIntDoubleVectorTypeInfo.INSTANCE), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream output = @@ -177,7 +178,7 @@ public Row map(Row row) throws Exception { } } - SparseVector outputVec = + SparseIntDoubleVector outputVec = Vectors.sparse( termCounts.length, indices.stream().mapToInt(i -> i).toArray(), diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java index c343cd763..d2c0d2716 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/dct/DCT.java @@ -22,10 +22,10 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -70,7 +70,8 @@ public Table[] transform(Table... 
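The Binarizer hunk above preserves the threshold semantics through the rename: dense vectors compare every component, while sparse vectors keep only stored entries above the threshold and set them to 1.0 (a shortcut that is only equivalent for a non-negative threshold, since absent entries are implicitly 0). A compact dense-side sketch:

import java.util.Arrays;

class BinarizerSketch {
    static double[] binarize(double[] values, double threshold) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = values[i] > threshold ? 1.0 : 0.0;
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(binarize(new double[] {0.5, 2.0, -1.0}, 1.0)));
        // [0.0, 1.0, 0.0]
    }
}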
inputs) { RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll( - inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE), + inputTypeInfo.getFieldTypes(), + DenseIntDoubleVectorTypeInfo.INSTANCE), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream stream = @@ -101,7 +102,7 @@ private DCTFunction(String inputCol, boolean isInverse) { @Override public Row map(Row row) throws Exception { - Vector vector = row.getFieldAs(inputCol); + IntDoubleVector vector = row.getFieldAs(inputCol); if (previousVectorSize != vector.size()) { if (isInverse) { @@ -113,7 +114,7 @@ public Row map(Row row) throws Exception { } double[] array = vector.toArray(); - if (vector instanceof DenseVector) { + if (vector instanceof DenseIntDoubleVector) { array = array.clone(); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java index 5a49fb64c..4e9c00fe0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProduct.java @@ -23,7 +23,7 @@ import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -77,16 +77,16 @@ public Table[] transform(Table... inputs) { private static class ElementwiseProductFunction implements MapFunction { private final String inputCol; - private final Vector scalingVec; + private final IntDoubleVector scalingVec; - public ElementwiseProductFunction(String inputCol, Vector scalingVec) { + public ElementwiseProductFunction(String inputCol, IntDoubleVector scalingVec) { this.inputCol = inputCol; this.scalingVec = scalingVec; } @Override public Row map(Row value) { - Vector inputVec = value.getFieldAs(inputCol); + IntDoubleVector inputVec = value.getFieldAs(inputCol); if (inputVec != null) { if (scalingVec.size() != inputVec.size()) { throw new IllegalArgumentException( @@ -96,7 +96,7 @@ public Row map(Row value) { + inputVec.size() + ")."); } - Vector retVec = inputVec.clone(); + IntDoubleVector retVec = inputVec.clone(); BLAS.hDot(scalingVec, retVec); return Row.join(value, Row.of(retVec)); } else { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java index 4bf612cd1..36aa046d6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/elementwiseproduct/ElementwiseProductParams.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.common.param.HasInputCol; import org.apache.flink.ml.common.param.HasOutputCol; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; import org.apache.flink.ml.param.VectorParam; @@ -32,18 +32,18 @@ */ public interface ElementwiseProductParams extends HasInputCol, HasOutputCol { - Param SCALING_VEC = + 
Param SCALING_VEC = new VectorParam( "scalingVec", "The scaling vector to multiply with input vectors using hadamard product.", null, ParamValidators.notNull()); - default Vector getScalingVec() { + default IntDoubleVector getScalingVec() { return get(SCALING_VEC); } - default T setScalingVec(Vector value) { + default T setScalingVec(IntDoubleVector value) { set(SCALING_VEC, value); return (T) this; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java index 3ad50b758..abff9434d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java @@ -22,7 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -137,7 +137,7 @@ public Row map(Row row) { values[pos] = entry.getValue(); pos++; } - return Row.join(row, Row.of(new SparseVector(numFeatures, indices, values))); + return Row.join(row, Row.of(new SparseIntDoubleVector(numFeatures, indices, values))); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java index 392019804..95660e694 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java @@ -23,7 +23,7 @@ import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -76,7 +76,8 @@ public Table[] transform(Table... 
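ElementwiseProduct's core operation, BLAS.hDot(scalingVec, retVec), is a Hadamard (element-wise) product applied in place. A plain-array sketch of what that call computes:

import java.util.Arrays;

class HadamardSketch {
    // In-place element-wise product, mirroring BLAS.hDot(scaling, target).
    static void hDot(double[] scaling, double[] target) {
        for (int i = 0; i < target.length; i++) {
            target[i] *= scaling[i];
        }
    }

    public static void main(String[] args) {
        double[] v = {1.0, 2.0, 3.0};
        hDot(new double[] {2.0, 0.5, -1.0}, v);
        System.out.println(Arrays.toString(v)); // [2.0, 1.0, -3.0]
    }
}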
inputs) { RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll( - inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE), + inputTypeInfo.getFieldTypes(), + SparseIntDoubleVectorTypeInfo.INSTANCE), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream output = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java index 7c2d051c7..0c3e61c5b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java @@ -24,8 +24,9 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -98,7 +99,7 @@ public static IDF load(StreamTableEnvironment tEnv, String path) throws IOExcept /** The main logic to compute the model data of IDF. */ private static class IDFAggregator - implements AggregateFunction, IDFModelData> { + implements AggregateFunction, IDFModelData> { private final int minDocFreq; public IDFAggregator(int minDocFreq) { @@ -106,37 +107,38 @@ public IDFAggregator(int minDocFreq) { } @Override - public Tuple2 createAccumulator() { - return Tuple2.of(0L, new DenseVector(new double[0])); + public Tuple2 createAccumulator() { + return Tuple2.of(0L, new DenseIntDoubleVector(new double[0])); } @Override - public Tuple2 add( - Vector vector, Tuple2 numDocsAndDocFreq) { + public Tuple2 add( + Vector vector, Tuple2 numDocsAndDocFreq) { + IntDoubleVector intDoubleVector = (IntDoubleVector) vector; if (numDocsAndDocFreq.f0 == 0) { - numDocsAndDocFreq.f1 = new DenseVector(vector.size()); + numDocsAndDocFreq.f1 = new DenseIntDoubleVector(intDoubleVector.size()); } numDocsAndDocFreq.f0 += 1L; double[] values; - if (vector instanceof SparseVector) { - values = ((SparseVector) vector).values; + if (vector instanceof SparseIntDoubleVector) { + values = ((SparseIntDoubleVector) vector).values; } else { - values = ((DenseVector) vector).values; + values = ((DenseIntDoubleVector) vector).values; } for (int i = 0; i < values.length; i++) { values[i] = values[i] > 0 ? 
1 : 0; } - BLAS.axpy(1, vector, numDocsAndDocFreq.f1); + BLAS.axpy(1, intDoubleVector, numDocsAndDocFreq.f1); return numDocsAndDocFreq; } @Override - public IDFModelData getResult(Tuple2 numDocsAndDocFreq) { + public IDFModelData getResult(Tuple2 numDocsAndDocFreq) { long numDocs = numDocsAndDocFreq.f0; - DenseVector docFreq = numDocsAndDocFreq.f1; + DenseIntDoubleVector docFreq = numDocsAndDocFreq.f1; Preconditions.checkState(numDocs > 0, "The training set is empty."); long[] filteredDocFreq = new long[docFreq.size()]; @@ -152,9 +154,9 @@ public IDFModelData getResult(Tuple2 numDocsAndDocFreq) { } @Override - public Tuple2 merge( - Tuple2 numDocsAndDocFreq1, - Tuple2 numDocsAndDocFreq2) { + public Tuple2 merge( + Tuple2 numDocsAndDocFreq1, + Tuple2 numDocsAndDocFreq2) { if (numDocsAndDocFreq1.f0 == 0) { return numDocsAndDocFreq2; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java index 87a2f254a..42fee1e84 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -125,7 +125,7 @@ public static IDFModel load(StreamTableEnvironment tEnv, String path) throws IOE private static class ComputeTfIdfFunction extends RichMapFunction { private final String inputCol; private final String broadcastKey; - private DenseVector idf; + private DenseIntDoubleVector idf; public ComputeTfIdfFunction(String broadcastKey, String inputCol) { this.broadcastKey = broadcastKey; @@ -141,7 +141,8 @@ public Row map(Row row) { idf = idfModelDataData.idf; } - Vector outputVec = ((Vector) Objects.requireNonNull(row.getField(inputCol))).clone(); + IntDoubleVector outputVec = + ((IntDoubleVector) Objects.requireNonNull(row.getField(inputCol))).clone(); BLAS.hDot(idf, outputVec); return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java index a808454e8..3c495968d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java @@ -29,8 +29,8 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -48,7 +48,7 @@ */ public class IDFModelData { /** Inverse document frequency for all terms. 
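The aggregator above counts, per term, how many documents contain it (values are clamped to 0/1 before the axpy), and getResult turns those document frequencies into IDF weights while zeroing terms below minDocFreq. The exact smoothing is outside this hunk; the sketch below assumes the conventional log((numDocs + 1) / (docFreq + 1)) formula used by Spark-style implementations:

class IdfSketch {
    static double[] idf(long[] docFreq, long numDocs, int minDocFreq) {
        double[] out = new double[docFreq.length];
        for (int i = 0; i < docFreq.length; i++) {
            out[i] = docFreq[i] >= minDocFreq
                    ? Math.log((numDocs + 1.0) / (docFreq[i] + 1.0))
                    : 0.0; // infrequent terms are filtered out with a zero weight
        }
        return out;
    }
}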
*/ - public DenseVector idf; + public DenseIntDoubleVector idf; /** Document frequency for all terms after filtering out infrequent terms. */ public long[] docFreq; /** Number of docs in the training set. */ @@ -56,7 +56,7 @@ public class IDFModelData { public IDFModelData() {} - public IDFModelData(DenseVector idf, long[] docFreq, long numDocs) { + public IDFModelData(DenseIntDoubleVector idf, long[] docFreq, long numDocs) { this.idf = idf; this.docFreq = docFreq; this.numDocs = numDocs; @@ -77,7 +77,8 @@ public static DataStream getModelDataStream(Table modelDataTable) /** Encoder for {@link IDFModelData}. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer denseVectorSerializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(IDFModelData modelData, OutputStream outputStream) throws IOException { @@ -93,14 +94,14 @@ public static class ModelDataDecoder extends SimpleStreamFormat { @Override public Reader createReader(Configuration config, FSDataInputStream stream) { return new Reader() { - private final DenseVectorSerializer denseVectorSerializer = - new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer denseVectorSerializer = + new DenseIntDoubleVectorSerializer(); @Override public IDFModelData read() throws IOException { DataInputView source = new DataInputViewStreamWrapper(stream); try { - DenseVector idf = denseVectorSerializer.deserialize(source); + DenseIntDoubleVector idf = denseVectorSerializer.deserialize(source); long[] docFreq = LongPrimitiveArraySerializer.INSTANCE.deserialize(source); long numDocs = LongSerializer.INSTANCE.deserialize(source); return new IDFModelData(idf, docFreq, numDocs); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java index 76e878981..b01910c50 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/interaction/Interaction.java @@ -22,9 +22,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; @@ -104,8 +104,8 @@ public Row map(Row value) { return Row.join(value, Row.of((Object) null)); } - if (obj instanceof DenseVector) { - featureSize[i] = ((Vector) obj).size(); + if (obj instanceof DenseIntDoubleVector) { + featureSize[i] = ((IntDoubleVector) obj).size(); if (featureIndices[i] == null || featureIndices[i].length != featureSize[i]) { featureIndices[i] = new int[featureSize[i]]; for (int j = 0; j < featureSize[i]; ++j) { @@ -113,13 +113,13 @@ public Row map(Row value) { } } - featureValues[i] = ((DenseVector) obj).values; + featureValues[i] = ((DenseIntDoubleVector) obj).values; nnz *= featureSize[i]; - } else if (obj instanceof SparseVector) { - featureSize[i] = ((Vector) obj).size(); 
- featureIndices[i] = ((SparseVector) obj).indices; - featureValues[i] = ((SparseVector) obj).values; - nnz *= ((SparseVector) obj).values.length; + } else if (obj instanceof SparseIntDoubleVector) { + featureSize[i] = ((IntDoubleVector) obj).size(); + featureIndices[i] = ((SparseIntDoubleVector) obj).indices; + featureValues[i] = ((SparseIntDoubleVector) obj).values; + nnz *= ((SparseIntDoubleVector) obj).values.length; hasSparse = true; } else { featureSize[i] = 1; @@ -128,7 +128,7 @@ public Row map(Row value) { } } - Vector ret; + IntDoubleVector ret; int featureIter = inputCols.length - 1; if (hasSparse) { int[] indices = new int[nnz]; @@ -170,7 +170,7 @@ public Row map(Row value) { } idxOffset *= prevValues.length; } - ret = new DenseVector(values); + ret = new DenseIntDoubleVector(values); } return Row.join(value, Row.of(ret)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java index 763a0df22..64b868072 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java @@ -23,8 +23,8 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler.MinMaxReduceFunctionOperator; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -93,13 +93,15 @@ public KBinsDiscretizerModel fit(Table... inputs) { String strategy = getStrategy(); int numBins = getNumBins(); - DataStream inputData = + DataStream inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction) - value -> ((Vector) value.getField(inputCol)).toDense()); + (MapFunction) + value -> + ((IntDoubleVector) value.getField(inputCol)) + .toDense()); - DataStream preprocessedData; + DataStream preprocessedData; if (strategy.equals(UNIFORM)) { preprocessedData = inputData @@ -121,12 +123,13 @@ public KBinsDiscretizerModel fit(Table... 
inputs) { DataStream modelData = DataStreamUtils.mapPartition( preprocessedData, - new MapPartitionFunction() { + new MapPartitionFunction< + DenseIntDoubleVector, KBinsDiscretizerModelData>() { @Override public void mapPartition( - Iterable iterable, + Iterable iterable, Collector collector) { - List list = new ArrayList<>(); + List list = new ArrayList<>(); iterable.iterator().forEachRemaining(list::add); if (list.size() == 0) { @@ -180,9 +183,9 @@ public static KBinsDiscretizer load(StreamTableEnvironment tEnv, String path) } private static double[][] findBinEdgesWithUniformStrategy( - List input, int numBins) { - DenseVector minVector = input.get(0); - DenseVector maxVector = input.get(1); + List input, int numBins) { + DenseIntDoubleVector minVector = input.get(0); + DenseIntDoubleVector maxVector = input.get(1); int numColumns = minVector.size(); double[][] binEdges = new double[numColumns][]; @@ -206,7 +209,7 @@ private static double[][] findBinEdgesWithUniformStrategy( } private static double[][] findBinEdgesWithQuantileStrategy( - List input, int numBins) { + List input, int numBins) { int numColumns = input.get(0).size(); int numData = input.size(); double[][] binEdges = new double[numColumns][]; @@ -271,7 +274,8 @@ private static double[][] findBinEdgesWithQuantileStrategy( return binEdges; } - private static double[][] findBinEdgesWithKMeansStrategy(List input, int numBins) { + private static double[][] findBinEdgesWithKMeansStrategy( + List input, int numBins) { int numColumns = input.get(0).size(); int numData = input.size(); double[][] binEdges = new double[numColumns][numBins + 1]; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java index 7053e313c..e42b57540 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -80,7 +80,7 @@ public Table[] transform(Table... 
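findBinEdgesWithUniformStrategy only needs the per-column min and max computed upstream; each column gets numBins + 1 equally spaced edges. A sketch for a single column, with illustrative names:

import java.util.Arrays;

class UniformBinsSketch {
    static double[] uniformEdges(double min, double max, int numBins) {
        double[] edges = new double[numBins + 1];
        double width = (max - min) / numBins;
        for (int i = 0; i <= numBins; i++) {
            edges[i] = min + i * width;
        }
        return edges;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(uniformEdges(0.0, 10.0, 5)));
        // [0.0, 2.0, 4.0, 6.0, 8.0, 10.0]
    }
}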
inputs) { new RowTypeInfo( ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), - TypeInformation.of(DenseVector.class)), + TypeInformation.of(DenseIntDoubleVector.class)), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream output = @@ -149,8 +149,8 @@ public Row map(Row row) { getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); binEdges = modelData.binEdges; } - DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense(); - DenseVector outputVec = inputVec.clone(); + DenseIntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol)).toDense(); + DenseIntDoubleVector outputVec = inputVec.clone(); for (int i = 0; i < inputVec.size(); i++) { double targetFeature = inputVec.get(i); int index = Arrays.binarySearch(binEdges[i], targetFeature); @@ -164,7 +164,7 @@ public Row map(Row row) { index = Math.min(index, (binEdges[i].length - 2)); index = Math.max(index, 0); - outputVec.set(i, index); + outputVec.set(i, (double) index); } return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java index 80df8c4d6..60493bc91 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSH.java @@ -20,7 +20,7 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -80,7 +80,7 @@ private static DataStream getVectorSize(DataStream input, String v DataStream vectorSizes = input.map( d -> { - Vector vec = d.getFieldAs(vectorCol); + IntDoubleVector vec = d.getFieldAs(vectorCol); return vec.size(); }); return DataStreamUtils.reduce( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java index 66c88ea40..e72141d21 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java @@ -33,9 +33,9 @@ import org.apache.flink.ml.common.datastream.EndOfStreamWindows; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.typeinfo.PriorityQueueTypeInfo; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -108,7 +108,7 @@ public Table[] transform(Table... inputs) { tEnv.toDataStream(modelDataTable, modelDataClass); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); - TypeInformation outputType = TypeInformation.of(DenseVector[].class); + TypeInformation outputType = TypeInformation.of(DenseIntDoubleVector[].class); RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType), @@ -138,7 +138,7 @@ public Table[] transform(Table... 
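The model-side lookup above maps each feature value to a bin with Arrays.binarySearch over the sorted edge array, then clamps the result. The negative-return conversion is elided from the hunk; the standard insertion-point arithmetic, consistent with the clamping shown, would be:

import java.util.Arrays;

class BinLookupSketch {
    static int binIndex(double[] edges, double v) {
        int index = Arrays.binarySearch(edges, v);
        if (index < 0) {
            index = -index - 2; // insertion point minus one: the bin whose left edge lies below v
        }
        index = Math.min(index, edges.length - 2); // an exact match on the last edge joins the last bin
        return Math.max(index, 0); // values at or below the first edge map to bin 0
    }

    public static void main(String[] args) {
        double[] edges = {0.0, 2.0, 4.0, 6.0};
        System.out.println(binIndex(edges, 3.0)); // 1, i.e. the bin [2.0, 4.0)
        System.out.println(binIndex(edges, 6.0)); // 2, clamped to the last bin
    }
}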
inputs) { * @return A dataset containing at most k items closest to the key with a column named `distCol` * appended. */ - public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) { + public Table approxNearestNeighbors(Table dataset, IntDoubleVector key, int k, String distCol) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment(); Table transformedTable = @@ -189,7 +189,7 @@ public Table approxNearestNeighbors(Table dataset, Vector key, int k, String dis * An overloaded version of `approxNearestNeighbors` with "distCol" as default value of * `distCol`. */ - public Table approxNearestNeighbors(Table dataset, Vector key, int k) { + public Table approxNearestNeighbors(Table dataset, IntDoubleVector key, int k) { return approxNearestNeighbors(dataset, key, k, "distCol"); } @@ -308,7 +308,7 @@ private RowTypeInfo getOutputType(Table dataTable, String idCol) { idColType, VectorTypeInfo.INSTANCE, Types.INT, - DenseVectorTypeInfo.INSTANCE + DenseIntDoubleVectorTypeInfo.INSTANCE }, new String[] {idCol, getInputCol(), indexCol, hashValueCol}); return outputTypeInfo; @@ -330,7 +330,7 @@ public Row map(Row value) throws Exception { (LSHModelData) getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0); } - Vector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol)); + IntDoubleVector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol)); return Row.join(value, Row.of((Object) hashValues)); } } @@ -338,11 +338,11 @@ public Row map(Row value) throws Exception { private static class FilterByBucketFunction extends RichFlatMapFunction { private final String inputCol; private final String outputCol; - private final Vector key; + private final IntDoubleVector key; private LSHModelData modelData; - private DenseVector[] keyHashes; + private DenseIntDoubleVector[] keyHashes; - public FilterByBucketFunction(String inputCol, String outputCol, Vector key) { + public FilterByBucketFunction(String inputCol, String outputCol, IntDoubleVector key) { this.inputCol = inputCol; this.outputCol = outputCol; this.key = key; @@ -356,7 +356,7 @@ public void flatMap(Row value, Collector out) throws Exception { getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0); keyHashes = modelData.hashFunction(key); } - DenseVector[] hashes = value.getFieldAs(outputCol); + DenseIntDoubleVector[] hashes = value.getFieldAs(outputCol); boolean sameBucket = false; for (int i = 0; i < keyHashes.length; i += 1) { if (keyHashes[i].equals(hashes[i])) { @@ -367,7 +367,7 @@ public void flatMap(Row value, Collector out) throws Exception { if (!sameBucket) { return; } - Vector vec = value.getFieldAs(inputCol); + IntDoubleVector vec = value.getFieldAs(inputCol); double dist = modelData.keyDistance(key, vec); out.collect(Row.join(value, Row.of(dist))); } @@ -447,7 +447,7 @@ public ExplodeHashValuesFunction(String idCol, String inputCol, String outputCol @Override public void flatMap(Row value, Collector out) throws Exception { Row kept = Row.of(value.getField(idCol), value.getField(inputCol)); - DenseVector[] hashValues = value.getFieldAs(outputCol); + DenseIntDoubleVector[] hashValues = value.getFieldAs(outputCol); for (int i = 0; i < hashValues.length; i += 1) { out.collect(Row.join(kept, Row.of(i, hashValues[i]))); } @@ -455,10 +455,10 @@ public void flatMap(Row value, Collector out) throws Exception { } private static class IndexHashValueKeySelector - implements KeySelector> { + implements KeySelector> { @Override - public 
Tuple2 getKey(Row value) throws Exception { + public Tuple2 getKey(Row value) throws Exception { return Tuple2.of(value.getFieldAs(2), value.getFieldAs(3)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java index 1b95e1906..a69e5087e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModelData.java @@ -18,8 +18,8 @@ package org.apache.flink.ml.feature.lsh; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** * Base class for LSH model data. A concrete class extending this base class should implement how to @@ -33,7 +33,7 @@ abstract class LSHModelData { * @param vec input vector. * @return the mapping of LSH functions. */ - public abstract DenseVector[] hashFunction(Vector vec); + public abstract DenseIntDoubleVector[] hashFunction(IntDoubleVector vec); /** * Calculates the distance between two different feature vectors using the corresponding @@ -43,5 +43,5 @@ abstract class LSHModelData { * @param y One input vector in the metric space. * @return The distance between x and y. */ - public abstract double keyDistance(Vector x, Vector y); + public abstract double keyDistance(IntDoubleVector x, IntDoubleVector y); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java index 7e4e7e012..6217de24c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSHModelData.java @@ -28,8 +28,8 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.util.Preconditions; import java.io.EOFException; @@ -122,7 +122,7 @@ public TypeInformation getProducedType() { } @Override - public DenseVector[] hashFunction(Vector vec) { + public DenseIntDoubleVector[] hashFunction(IntDoubleVector vec) { int[] indices = vec.toSparse().indices; Preconditions.checkArgument(indices.length > 0, "Must have at least 1 non zero entry."); double[][] hashValues = new double[numHashTables][numHashFunctionsPerTable]; @@ -139,11 +139,13 @@ public DenseVector[] hashFunction(Vector vec) { hashValues[i][j] = minv; } } - return Arrays.stream(hashValues).map(DenseVector::new).toArray(DenseVector[]::new); + return Arrays.stream(hashValues) + .map(DenseIntDoubleVector::new) + .toArray(DenseIntDoubleVector[]::new); } @Override - public double keyDistance(Vector x, Vector y) { + public double keyDistance(IntDoubleVector x, IntDoubleVector y) { int[] xIndices = x.toSparse().indices; int[] yIndices = y.toSparse().indices; Preconditions.checkArgument( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java index 26aa0caa9..d754ea89c 100644 --- 
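hashFunction above evaluates numHashTables x numHashFunctionsPerTable MinHash functions over the non-zero indices of the input. The random coefficients are initialized elsewhere in the file; this sketch assumes the usual ((1 + index) * a + b) mod prime construction, and the prime constant here is chosen for illustration only:

class MinHashSketch {
    static final long PRIME = 2038074743L; // illustrative large prime, an assumption of this sketch

    // One MinHash value: the minimum of a universal hash taken over the non-zero indices.
    static double minHash(int[] nonZeroIndices, long a, long b) {
        double min = Double.MAX_VALUE;
        for (int idx : nonZeroIndices) {
            min = Math.min(min, ((1L + idx) * a + b) % PRIME);
        }
        return min;
    }

    public static void main(String[] args) {
        // Two hash functions applied to a vector with non-zero entries at {0, 3, 7}.
        System.out.println(minHash(new int[] {0, 3, 7}, 12345L, 678L));
        System.out.println(minHash(new int[] {0, 3, 7}, 54321L, 876L));
    }
}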
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java @@ -23,10 +23,11 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -90,7 +91,7 @@ public MaxAbsScalerModel fit(Table... inputs) { DataStream modelData = maxAbsValues.map( (MapFunction) - vector -> new MaxAbsScalerModelData((DenseVector) vector)); + vector -> new MaxAbsScalerModelData((DenseIntDoubleVector) vector)); MaxAbsScalerModel model = new MaxAbsScalerModel().setModelData(tEnv.fromDataStream(modelData)); @@ -104,8 +105,8 @@ public MaxAbsScalerModel fit(Table... inputs) { */ private static class MaxAbsReduceFunctionOperator extends AbstractStreamOperator implements OneInputStreamOperator, BoundedOneInput { - private ListState maxAbsState; - private DenseVector maxAbsVector; + private ListState maxAbsState; + private DenseIntDoubleVector maxAbsVector; @Override public void endInput() { @@ -116,22 +117,24 @@ public void endInput() { @Override public void processElement(StreamRecord streamRecord) { - Vector currentValue = streamRecord.getValue(); + IntDoubleVector currentValue = (IntDoubleVector) streamRecord.getValue(); maxAbsVector = - maxAbsVector == null ? new DenseVector(currentValue.size()) : maxAbsVector; + maxAbsVector == null + ? 
new DenseIntDoubleVector(currentValue.size()) + : maxAbsVector; Preconditions.checkArgument( currentValue.size() == maxAbsVector.size(), "The training data should all have same dimensions."); - if (currentValue instanceof DenseVector) { - double[] values = ((DenseVector) currentValue).values; + if (currentValue instanceof DenseIntDoubleVector) { + double[] values = ((DenseIntDoubleVector) currentValue).values; for (int i = 0; i < currentValue.size(); ++i) { maxAbsVector.values[i] = Math.max(maxAbsVector.values[i], Math.abs(values[i])); } - } else if (currentValue instanceof SparseVector) { - int[] indices = ((SparseVector) currentValue).indices; - double[] values = ((SparseVector) currentValue).values; + } else if (currentValue instanceof SparseIntDoubleVector) { + int[] indices = ((SparseIntDoubleVector) currentValue).indices; + double[] values = ((SparseIntDoubleVector) currentValue).values; for (int i = 0; i < indices.length; ++i) { maxAbsVector.values[indices[i]] = @@ -147,7 +150,7 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "maxAbsState", DenseVectorTypeInfo.INSTANCE)); + "maxAbsState", DenseIntDoubleVectorTypeInfo.INSTANCE)); OperatorStateUtils.getUniqueElement(maxAbsState, "maxAbsState") .ifPresent(x -> maxAbsVector = x); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java index 5f5d7e4c8..036f2423d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -132,7 +132,7 @@ public static MaxAbsScalerModel load(StreamTableEnvironment tEnv, String path) private static class PredictOutputFunction extends RichMapFunction { private final String inputCol; private final String broadcastKey; - private DenseVector scaleVector; + private DenseIntDoubleVector scaleVector; public PredictOutputFunction(String broadcastKey, String inputCol) { this.broadcastKey = broadcastKey; @@ -156,8 +156,8 @@ public Row map(Row row) { } } - Vector inputVec = row.getFieldAs(inputCol); - Vector outputVec = inputVec.clone(); + IntDoubleVector inputVec = row.getFieldAs(inputCol); + IntDoubleVector outputVec = inputVec.clone(); BLAS.hDot(scaleVector, outputVec); return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java index 4c1e76db5..bac6b3b9f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java @@ -27,8 +27,8 @@ import 
org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -45,11 +45,11 @@ * classes to save/load model data. */ public class MaxAbsScalerModelData { - public DenseVector maxVector; + public DenseIntDoubleVector maxVector; public MaxAbsScalerModelData() {} - public MaxAbsScalerModelData(DenseVector maxVector) { + public MaxAbsScalerModelData(DenseIntDoubleVector maxVector) { this.maxVector = maxVector; } @@ -63,12 +63,13 @@ public static DataStream getModelDataStream(Table modelDa StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); return tEnv.toDataStream(modelDataTable) - .map(x -> new MaxAbsScalerModelData((DenseVector) x.getField(0))); + .map(x -> new MaxAbsScalerModelData((DenseIntDoubleVector) x.getField(0))); } /** Encoder for {@link MaxAbsScalerModelData}. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(MaxAbsScalerModelData modelData, OutputStream outputStream) @@ -84,13 +85,14 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration config, FSDataInputStream stream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public MaxAbsScalerModelData read() throws IOException { DataInputView source = new DataInputViewStreamWrapper(stream); try { - DenseVector maxVector = serializer.deserialize(source); + DenseIntDoubleVector maxVector = serializer.deserialize(source); return new MaxAbsScalerModelData(maxVector); } catch (EOFException e) { return null; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java index d21fd9a79..dec5f44ec 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java @@ -26,8 +26,8 @@ import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -78,12 +78,14 @@ public MinMaxScalerModel fit(Table... 
inputs) { final String inputCol = getInputCol(); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream inputData = + DataStream inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction) - value -> ((Vector) value.getField(inputCol)).toDense()); - DataStream minMaxValues = + (MapFunction) + value -> + ((IntDoubleVector) value.getField(inputCol)) + .toDense()); + DataStream minMaxValues = inputData .transform( "reduceInEachPartition", @@ -97,14 +99,15 @@ public MinMaxScalerModel fit(Table... inputs) { DataStream modelData = DataStreamUtils.mapPartition( minMaxValues, - new RichMapPartitionFunction() { + new RichMapPartitionFunction< + DenseIntDoubleVector, MinMaxScalerModelData>() { @Override public void mapPartition( - Iterable values, + Iterable values, Collector out) { - Iterator iter = values.iterator(); - DenseVector minVector = iter.next(); - DenseVector maxVector = iter.next(); + Iterator iter = values.iterator(); + DenseIntDoubleVector minVector = iter.next(); + DenseIntDoubleVector maxVector = iter.next(); out.collect(new MinMaxScalerModelData(minVector, maxVector)); } }); @@ -119,13 +122,15 @@ public void mapPartition( * A stream operator to compute the min and max values in each partition of the input bounded * data stream. */ - public static class MinMaxReduceFunctionOperator extends AbstractStreamOperator - implements OneInputStreamOperator, BoundedOneInput { - private ListState minState; - private ListState maxState; + public static class MinMaxReduceFunctionOperator + extends AbstractStreamOperator + implements OneInputStreamOperator, + BoundedOneInput { + private ListState minState; + private ListState maxState; - private DenseVector minVector; - private DenseVector maxVector; + private DenseIntDoubleVector minVector; + private DenseIntDoubleVector maxVector; @Override public void endInput() { @@ -136,12 +141,12 @@ public void endInput() { } @Override - public void processElement(StreamRecord streamRecord) { - DenseVector currentValue = streamRecord.getValue(); + public void processElement(StreamRecord streamRecord) { + DenseIntDoubleVector currentValue = streamRecord.getValue(); if (minVector == null) { int vecSize = currentValue.size(); - minVector = new DenseVector(vecSize); - maxVector = new DenseVector(vecSize); + minVector = new DenseIntDoubleVector(vecSize); + maxVector = new DenseIntDoubleVector(vecSize); System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize); System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize); } else { @@ -163,12 +168,14 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "minState", TypeInformation.of(DenseVector.class))); + "minState", + TypeInformation.of(DenseIntDoubleVector.class))); maxState = context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "maxState", TypeInformation.of(DenseVector.class))); + "maxState", + TypeInformation.of(DenseIntDoubleVector.class))); OperatorStateUtils.getUniqueElement(minState, "minState").ifPresent(x -> minVector = x); OperatorStateUtils.getUniqueElement(maxState, "maxState").ifPresent(x -> maxVector = x); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java index 858c0f426..78d818db0 100644 --- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -79,7 +79,7 @@ public Table[] transform(Table... inputs) { new RowTypeInfo( ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), - TypeInformation.of(DenseVector.class)), + TypeInformation.of(DenseIntDoubleVector.class)), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); DataStream output = BroadcastUtils.withBroadcastStream( @@ -131,8 +131,8 @@ private static class PredictOutputFunction extends RichMapFunction { private final String broadcastKey; private final double upperBound; private final double lowerBound; - private DenseVector scaleVector; - private DenseVector offsetVector; + private DenseIntDoubleVector scaleVector; + private DenseIntDoubleVector offsetVector; public PredictOutputFunction( String broadcastKey, double upperBound, double lowerBound, String inputCol) { @@ -148,10 +148,10 @@ public Row map(Row row) { MinMaxScalerModelData minMaxScalerModelData = (MinMaxScalerModelData) getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); - DenseVector minVector = minMaxScalerModelData.minVector; - DenseVector maxVector = minMaxScalerModelData.maxVector; - scaleVector = new DenseVector(minVector.size()); - offsetVector = new DenseVector(minVector.size()); + DenseIntDoubleVector minVector = minMaxScalerModelData.minVector; + DenseIntDoubleVector maxVector = minMaxScalerModelData.maxVector; + scaleVector = new DenseIntDoubleVector(minVector.size()); + offsetVector = new DenseIntDoubleVector(minVector.size()); for (int i = 0; i < maxVector.size(); ++i) { if (Math.abs(minVector.values[i] - maxVector.values[i]) < 1.0e-5) { scaleVector.values[i] = 0.0; @@ -165,8 +165,8 @@ public Row map(Row row) { } } } - DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense(); - DenseVector outputVec = new DenseVector(scaleVector.size()); + DenseIntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol)).toDense(); + DenseIntDoubleVector outputVec = new DenseIntDoubleVector(scaleVector.size()); for (int i = 0; i < scaleVector.size(); ++i) { outputVec.values[i] = inputVec.values[i] * scaleVector.values[i] + offsetVector.values[i]; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java index 451406e46..3e18b76f5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java @@ -27,8 +27,8 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import 
org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
@@ -45,13 +45,13 @@
  * classes to save/load model data.
  */
 public class MinMaxScalerModelData {
-    public DenseVector minVector;
+    public DenseIntDoubleVector minVector;
 
-    public DenseVector maxVector;
+    public DenseIntDoubleVector maxVector;
 
     public MinMaxScalerModelData() {}
 
-    public MinMaxScalerModelData(DenseVector minVector, DenseVector maxVector) {
+    public MinMaxScalerModelData(DenseIntDoubleVector minVector, DenseIntDoubleVector maxVector) {
         this.minVector = minVector;
         this.maxVector = maxVector;
     }
@@ -69,12 +69,14 @@ public static DataStream getModelDataStream(Table modelDa
                 .map(
                         x ->
                                 new MinMaxScalerModelData(
-                                        (DenseVector) x.getField(0), (DenseVector) x.getField(1)));
+                                        (DenseIntDoubleVector) x.getField(0),
+                                        (DenseIntDoubleVector) x.getField(1)));
     }
 
     /** Encoder for {@link MinMaxScalerModelData}. */
     public static class ModelDataEncoder implements Encoder<MinMaxScalerModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+        private final DenseIntDoubleVectorSerializer serializer =
+                new DenseIntDoubleVectorSerializer();
 
         @Override
         public void encode(MinMaxScalerModelData modelData, OutputStream outputStream)
@@ -91,14 +93,15 @@ public static class ModelDataDecoder extends SimpleStreamFormat<MinMaxScalerModelData> {
         public Reader<MinMaxScalerModelData> createReader(
                 Configuration config, FSDataInputStream stream) {
             return new Reader<MinMaxScalerModelData>() {
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                private final DenseIntDoubleVectorSerializer serializer =
+                        new DenseIntDoubleVectorSerializer();
 
                 @Override
                 public MinMaxScalerModelData read() throws IOException {
                     DataInputView source = new DataInputViewStreamWrapper(stream);
                     try {
-                        DenseVector minVector = serializer.deserialize(source);
-                        DenseVector maxVector = serializer.deserialize(source);
+                        DenseIntDoubleVector minVector = serializer.deserialize(source);
+                        DenseIntDoubleVector maxVector = serializer.deserialize(source);
                         return new MinMaxScalerModelData(minVector, maxVector);
                     } catch (EOFException e) {
                         return null;
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java
index a9034ae3a..c3ed94fe4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/normalizer/Normalizer.java
@@ -23,7 +23,7 @@ import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
@@ -95,8 +95,8 @@ public NormalizationFunction(double p, String inputCol) {
 
     @Override
     public Row map(Row row) throws Exception {
-        Vector inputVec = row.getFieldAs(inputCol);
-        Vector outputVec = inputVec.clone();
+        IntDoubleVector inputVec = row.getFieldAs(inputCol);
+        IntDoubleVector outputVec = inputVec.clone();
         double norm = BLAS.norm(inputVec, p);
         BLAS.scal(1.0 / norm, outputVec);
         return Row.join(row, Row.of(outputVec));
diff --git 
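NormalizationFunction above leaves the heavy lifting to BLAS.norm and BLAS.scal; the arithmetic for the common p = 2 case, sketched in plain Java (names illustrative):

    public class PNorm {
        // Rescales values to unit p-norm, as BLAS.norm followed by BLAS.scal does.
        static void normalize(double[] values, double p) {
            double sum = 0.0;
            for (double v : values) {
                sum += Math.pow(Math.abs(v), p);
            }
            double norm = Math.pow(sum, 1.0 / p);
            for (int i = 0; i < values.length; i++) {
                values[i] /= norm;
            }
        }

        public static void main(String[] args) {
            double[] v = {3.0, 4.0};
            normalize(v, 2.0); // v becomes [0.6, 0.8]
        }
    }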
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java index aab391bc1..d2aefb2a2 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java @@ -27,7 +27,7 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -80,7 +80,8 @@ public Table[] transform(Table... inputs) { ArrayUtils.addAll( inputTypeInfo.getFieldTypes(), Collections.nCopies( - outputCols.length, SparseVectorTypeInfo.INSTANCE) + outputCols.length, + SparseIntDoubleVectorTypeInfo.INSTANCE) .toArray(new TypeInformation[0])), ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java index 86fd2eaad..68080787f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/polynomialexpansion/PolynomialExpansion.java @@ -23,9 +23,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -126,18 +126,19 @@ public PolynomialExpansionFunction(int degree, String inputCol) { @Override public Row map(Row row) throws Exception { - Vector vec = row.getFieldAs(inputCol); + IntDoubleVector vec = row.getFieldAs(inputCol); if (vec == null) { throw new IllegalArgumentException("The vector must not be null."); } - Vector outputVec; - if (vec instanceof DenseVector) { + IntDoubleVector outputVec; + if (vec instanceof DenseIntDoubleVector) { int size = vec.size(); double[] retVals = new double[getResultVectorSize(size, degree) - 1]; - expandDenseVector(((DenseVector) vec).values, size - 1, degree, 1.0, retVals, -1); - outputVec = new DenseVector(retVals); - } else if (vec instanceof SparseVector) { - SparseVector sparseVec = (SparseVector) vec; + expandDenseVector( + ((DenseIntDoubleVector) vec).values, size - 1, degree, 1.0, retVals, -1); + outputVec = new DenseIntDoubleVector(retVals); + } else if (vec instanceof SparseIntDoubleVector) { + SparseIntDoubleVector sparseVec = (SparseIntDoubleVector) vec; int[] indices = sparseVec.indices; double[] values = sparseVec.values; int size = sparseVec.size(); @@ -158,7 +159,7 @@ public Row map(Row row) throws Exception { -1); outputVec = - new SparseVector( + new 
SparseIntDoubleVector( getResultVectorSize(size, degree) - 1, polyIndices.f1, polyValues.f1); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java index cb6890811..8fc28e796 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java @@ -23,8 +23,8 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.util.QuantileSummary; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -73,11 +73,13 @@ public RobustScalerModel fit(Table... inputs) { final String inputCol = getInputCol(); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream inputData = + DataStream inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction) - value -> ((Vector) value.getField(inputCol)).toDense()); + (MapFunction) + value -> + ((IntDoubleVector) value.getField(inputCol)) + .toDense()); DataStream modelData = DataStreamUtils.aggregate( inputData, @@ -93,7 +95,8 @@ public RobustScalerModel fit(Table... inputs) { * RobustScalerModelData}. */ private static class QuantileAggregator - implements AggregateFunction { + implements AggregateFunction< + DenseIntDoubleVector, QuantileSummary[], RobustScalerModelData> { private final double relativeError; private final double lower; @@ -111,7 +114,8 @@ public QuantileSummary[] createAccumulator() { } @Override - public QuantileSummary[] add(DenseVector denseVector, QuantileSummary[] quantileSummaries) { + public QuantileSummary[] add( + DenseIntDoubleVector denseVector, QuantileSummary[] quantileSummaries) { if (quantileSummaries.length == 0) { quantileSummaries = new QuantileSummary[denseVector.size()]; for (int i = 0; i < denseVector.size(); i++) { @@ -136,8 +140,8 @@ public QuantileSummary[] add(DenseVector denseVector, QuantileSummary[] quantile @Override public RobustScalerModelData getResult(QuantileSummary[] quantileSummaries) { Preconditions.checkState(quantileSummaries.length != 0, "The training set is empty."); - DenseVector medianVector = new DenseVector(quantileSummaries.length); - DenseVector rangeVector = new DenseVector(quantileSummaries.length); + DenseIntDoubleVector medianVector = new DenseIntDoubleVector(quantileSummaries.length); + DenseIntDoubleVector rangeVector = new DenseIntDoubleVector(quantileSummaries.length); for (int i = 0; i < quantileSummaries.length; i++) { QuantileSummary compressed = quantileSummaries[i].compress(); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java index deda6e339..b6239527d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import 
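QuantileAggregator above maintains one QuantileSummary per feature and finally queries the lower, median, and upper quantiles to build the medians and ranges of the model. An exact nearest-rank sketch of what those queries approximate for one feature; the 0.25/0.75 bounds are the usual defaults, assumed here:

    import java.util.Arrays;

    public class RobustStatsSketch {
        // Exact nearest-rank quantile; QuantileSummary computes an approximation
        // of this within the configured relative error.
        static double quantile(double[] sorted, double q) {
            return sorted[Math.max((int) Math.ceil(q * sorted.length) - 1, 0)];
        }

        public static void main(String[] args) {
            double[] feature = {1.0, 2.0, 3.0, 4.0, 100.0};
            Arrays.sort(feature);
            double median = quantile(feature, 0.5);                           // 3.0
            double range = quantile(feature, 0.75) - quantile(feature, 0.25); // 4.0 - 2.0 = 2.0
            System.out.println(median + " / " + range);
        }
    }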
org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -98,8 +98,8 @@ private static class PredictOutputFunction extends RichMapFunction { private final boolean withCentering; private final boolean withScaling; - private DenseVector medians; - private DenseVector scales; + private DenseIntDoubleVector medians; + private DenseIntDoubleVector scales; public PredictOutputFunction( String broadcastModelKey, @@ -120,12 +120,13 @@ public Row map(Row row) throws Exception { getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); medians = modelData.medians; scales = - new DenseVector( + new DenseIntDoubleVector( Arrays.stream(modelData.ranges.values) .map(range -> range == 0 ? 0 : 1 / range) .toArray()); } - DenseVector outputVec = ((Vector) row.getField(inputCol)).clone().toDense(); + DenseIntDoubleVector outputVec = + ((IntDoubleVector) row.getField(inputCol)).clone().toDense(); Preconditions.checkState( medians.size() == outputVec.size(), "Number of features must be %s but got %s.", diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java index 807fe2497..253dd8f1b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java @@ -25,8 +25,8 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -43,13 +43,13 @@ * classes to save/load model data. */ public class RobustScalerModelData { - public DenseVector medians; + public DenseIntDoubleVector medians; - public DenseVector ranges; + public DenseIntDoubleVector ranges; public RobustScalerModelData() {} - public RobustScalerModelData(DenseVector medians, DenseVector ranges) { + public RobustScalerModelData(DenseIntDoubleVector medians, DenseIntDoubleVector ranges) { this.medians = medians; this.ranges = ranges; } @@ -67,13 +67,14 @@ public static DataStream getModelDataStream(Table modelDa .map( x -> new RobustScalerModelData( - (DenseVector) x.getField("medians"), - (DenseVector) x.getField("ranges"))); + (DenseIntDoubleVector) x.getField("medians"), + (DenseIntDoubleVector) x.getField("ranges"))); } /** Data encoder for the {@link RobustScalerModel} model data. 
*/ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(RobustScalerModelData modelData, OutputStream outputStream) @@ -92,15 +93,18 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration configuration, FSDataInputStream inputStream) throws IOException { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public RobustScalerModelData read() throws IOException { DataInputViewStreamWrapper inputViewStreamWrapper = new DataInputViewStreamWrapper(inputStream); try { - DenseVector medians = serializer.deserialize(inputViewStreamWrapper); - DenseVector ranges = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector medians = + serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector ranges = + serializer.deserialize(inputViewStreamWrapper); return new RobustScalerModelData(medians, ranges); } catch (EOFException e) { return null; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java index 5c7d44748..c4ff95f0f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScaler.java @@ -28,10 +28,10 @@ import org.apache.flink.ml.common.window.EventTimeTumblingWindows; import org.apache.flink.ml.common.window.Windows; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -125,16 +125,17 @@ public void process( Iterable iterable, Collector collector) throws Exception { - ListState sumState = + ListState sumState = context.globalState() .getListState( new ListStateDescriptor<>( - "sumState", DenseVectorTypeInfo.INSTANCE)); - ListState squaredSumState = + "sumState", DenseIntDoubleVectorTypeInfo.INSTANCE)); + ListState squaredSumState = context.globalState() .getListState( new ListStateDescriptor<>( - "squaredSumState", DenseVectorTypeInfo.INSTANCE)); + "squaredSumState", + DenseIntDoubleVectorTypeInfo.INSTANCE)); ListState numElementsState = context.globalState() .getListState( @@ -143,9 +144,9 @@ public void process( context.globalState() .getListState( new ListStateDescriptor<>("modelVersionState", Types.LONG)); - DenseVector sum = + DenseIntDoubleVector sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null); - DenseVector squaredSum = + DenseIntDoubleVector squaredSum = OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState") .orElse(null); long numElements = @@ -157,11 +158,12 @@ public void process( long numElementsBefore = numElements; for (Row 
element : iterable) { - Vector inputVec = - ((Vector) Objects.requireNonNull(element.getField(inputCol))).clone(); + IntDoubleVector inputVec = + ((IntDoubleVector) Objects.requireNonNull(element.getField(inputCol))) + .clone(); if (numElements == 0) { - sum = new DenseVector(inputVec.size()); - squaredSum = new DenseVector(inputVec.size()); + sum = new DenseIntDoubleVector(inputVec.size()); + squaredSum = new DenseIntDoubleVector(inputVec.size()); } BLAS.axpy(1, inputVec, sum); BLAS.hDot(inputVec, inputVec); @@ -190,8 +192,8 @@ public void process( private static StandardScalerModelData buildModelData( long numElements, - DenseVector sum, - DenseVector squaredSum, + DenseIntDoubleVector sum, + DenseIntDoubleVector squaredSum, long modelVersion, long currentTimeStamp) { BLAS.scal(1.0 / numElements, sum); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java index a491e6866..0fae000c6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/OnlineStandardScalerModel.java @@ -30,8 +30,8 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.metrics.MLMetrics; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -135,10 +135,10 @@ private static class PredictionOperator extends AbstractStreamOperator /** Model data for inference. */ private StandardScalerModelData modelData; - private DenseVector mean; + private DenseIntDoubleVector mean; /** Inverse of standard deviation. 
*/ - private DenseVector scale; + private DenseIntDoubleVector scale; private long modelVersion; @@ -249,7 +249,7 @@ private void initializeModelData(StandardScalerModelData modelData) { modelTimeStamp = modelData.timestamp; modelVersion = modelData.version; mean = modelData.mean; - DenseVector std = modelData.std; + DenseIntDoubleVector std = modelData.std; if (withStd) { scale = std; @@ -263,11 +263,12 @@ private void initializeModelData(StandardScalerModelData modelData) { private void doPrediction(StreamRecord streamRecord) { Row dataPoint = streamRecord.getValue(); - Vector outputVec = - ((Vector) (Objects.requireNonNull(dataPoint.getField(inputCol)))).clone(); + IntDoubleVector outputVec = + ((IntDoubleVector) (Objects.requireNonNull(dataPoint.getField(inputCol)))) + .clone(); if (withMean) { outputVec = outputVec.toDense(); - BLAS.axpy(-1, mean, (DenseVector) outputVec); + BLAS.axpy(-1, mean, (DenseIntDoubleVector) outputVec); } if (withStd) { BLAS.hDot(scale, outputVec); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java index 59f519f28..e4ddc338f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java @@ -27,8 +27,8 @@ import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -76,15 +76,16 @@ public StandardScalerModel fit(Table... inputs) { Preconditions.checkArgument(inputs.length == 1); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream> sumAndSquaredSumAndWeight = - tEnv.toDataStream(inputs[0]) - .transform( - "computeMeta", - new TupleTypeInfo<>( - TypeInformation.of(DenseVector.class), - TypeInformation.of(DenseVector.class), - BasicTypeInfo.LONG_TYPE_INFO), - new ComputeMetaOperator(getInputCol())); + DataStream> + sumAndSquaredSumAndWeight = + tEnv.toDataStream(inputs[0]) + .transform( + "computeMeta", + new TupleTypeInfo<>( + TypeInformation.of(DenseIntDoubleVector.class), + TypeInformation.of(DenseIntDoubleVector.class), + BasicTypeInfo.LONG_TYPE_INFO), + new ComputeMetaOperator(getInputCol())); DataStream modelData = sumAndSquaredSumAndWeight @@ -105,13 +106,14 @@ public StandardScalerModel fit(Table... 
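buildModelData above needs only the running sum, squared sum, and count to produce the mean and standard deviation. A plain-Java sketch of that derivation; the sample-variance denominator n - 1 is an assumption, since the exact expression is elided in the hunks above:

    public class MeanStdFromSums {
        public static void main(String[] args) {
            // Accumulated as in the process function: sum[i] += x[i] via BLAS.axpy,
            // squaredSum[i] += x[i] * x[i] after BLAS.hDot squares the vector in place.
            long n = 4;
            double[] sum = {8.0, 0.0};
            double[] squaredSum = {24.0, 4.0};

            double[] mean = new double[sum.length];
            double[] std = new double[sum.length];
            for (int i = 0; i < sum.length; i++) {
                mean[i] = sum[i] / n;
                std[i] = Math.sqrt((squaredSum[i] - n * mean[i] * mean[i]) / (n - 1));
            }
            // mean = [2.0, 0.0], std = [sqrt(8/3), sqrt(4/3)]
        }
    }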
inputs) { */ private static class BuildModelOperator extends AbstractStreamOperator implements OneInputStreamOperator< - Tuple3, StandardScalerModelData>, + Tuple3, + StandardScalerModelData>, BoundedOneInput { - private ListState sumState; - private ListState squaredSumState; + private ListState sumState; + private ListState squaredSumState; private ListState numElementsState; - private DenseVector sum; - private DenseVector squaredSum; + private DenseIntDoubleVector sum; + private DenseIntDoubleVector squaredSum; private long numElements; @Override @@ -141,8 +143,9 @@ public void endInput() { } @Override - public void processElement(StreamRecord> element) { - Tuple3 value = element.getValue(); + public void processElement( + StreamRecord> element) { + Tuple3 value = element.getValue(); if (numElements == 0) { sum = value.f0; squaredSum = value.f1; @@ -161,13 +164,14 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "sumState", TypeInformation.of(DenseVector.class))); + "sumState", + TypeInformation.of(DenseIntDoubleVector.class))); squaredSumState = context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( "squaredSumState", - TypeInformation.of(DenseVector.class))); + TypeInformation.of(DenseIntDoubleVector.class))); numElementsState = context.getOperatorStateStore() .getListState( @@ -196,14 +200,15 @@ public void snapshotState(StateSnapshotContext context) throws Exception { /** Computes sum, squared sum and number of elements in each partition. */ private static class ComputeMetaOperator - extends AbstractStreamOperator> - implements OneInputStreamOperator>, + extends AbstractStreamOperator> + implements OneInputStreamOperator< + Row, Tuple3>, BoundedOneInput { - private ListState sumState; - private ListState squaredSumState; + private ListState sumState; + private ListState squaredSumState; private ListState numElementsState; - private DenseVector sum; - private DenseVector squaredSum; + private DenseIntDoubleVector sum; + private DenseIntDoubleVector squaredSum; private long numElements; private final String inputCol; @@ -221,10 +226,10 @@ public void endInput() { @Override public void processElement(StreamRecord element) { - Vector inputVec = (Vector) element.getValue().getField(inputCol); + IntDoubleVector inputVec = (IntDoubleVector) element.getValue().getField(inputCol); if (numElements == 0) { - sum = new DenseVector(inputVec.size()); - squaredSum = new DenseVector(inputVec.size()); + sum = new DenseIntDoubleVector(inputVec.size()); + squaredSum = new DenseIntDoubleVector(inputVec.size()); } BLAS.axpy(1, inputVec, sum); BLAS.hDot(inputVec, inputVec); @@ -239,13 +244,14 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "sumState", TypeInformation.of(DenseVector.class))); + "sumState", + TypeInformation.of(DenseIntDoubleVector.class))); squaredSumState = context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( "squaredSumState", - TypeInformation.of(DenseVector.class))); + TypeInformation.of(DenseIntDoubleVector.class))); numElementsState = context.getOperatorStateStore() .getListState( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java index c3d31c83f..69dd42790 100644 --- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java @@ -24,8 +24,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -96,8 +96,8 @@ private static class PredictOutputFunction extends RichMapFunction { private final String inputCol; private final boolean withMean; private final boolean withStd; - private DenseVector mean; - private DenseVector scale; + private DenseIntDoubleVector mean; + private DenseIntDoubleVector scale; public PredictOutputFunction( String broadcastModelKey, String inputCol, boolean withMean, boolean withStd) { @@ -114,7 +114,7 @@ public Row map(Row dataPoint) { (StandardScalerModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); mean = modelData.mean; - DenseVector std = modelData.std; + DenseIntDoubleVector std = modelData.std; if (withStd) { scale = std; @@ -125,10 +125,10 @@ public Row map(Row dataPoint) { } } - Vector outputVec = ((Vector) (dataPoint.getField(inputCol))).clone(); + IntDoubleVector outputVec = ((IntDoubleVector) (dataPoint.getField(inputCol))).clone(); if (withMean) { outputVec = outputVec.toDense(); - BLAS.axpy(-1, mean, (DenseVector) outputVec); + BLAS.axpy(-1, mean, (DenseIntDoubleVector) outputVec); } if (withStd) { BLAS.hDot(scale, outputVec); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java index fcaa2986d..01085b2dd 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModelData.java @@ -27,8 +27,8 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -47,9 +47,9 @@ */ public class StandardScalerModelData { /** Mean of each dimension. */ - public DenseVector mean; + public DenseIntDoubleVector mean; /** Standard deviation of each dimension. */ - public DenseVector std; + public DenseIntDoubleVector std; /** Model version. */ public long version; /** Model timestamp. 
*/ @@ -57,12 +57,12 @@ public class StandardScalerModelData { public StandardScalerModelData() {} - public StandardScalerModelData(DenseVector mean, DenseVector std) { + public StandardScalerModelData(DenseIntDoubleVector mean, DenseIntDoubleVector std) { this(mean, std, 0, Long.MAX_VALUE); } public StandardScalerModelData( - DenseVector mean, DenseVector std, long version, long timestamp) { + DenseIntDoubleVector mean, DenseIntDoubleVector std, long version, long timestamp) { this.mean = mean; this.std = std; this.version = version; @@ -93,7 +93,8 @@ public static DataStream getModelDataStream(Table model /** Data encoder for the {@link StandardScalerModel} model data. */ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(StandardScalerModelData modelData, OutputStream outputStream) @@ -114,7 +115,8 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration configuration, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public StandardScalerModelData read() throws IOException { @@ -122,8 +124,8 @@ public StandardScalerModelData read() throws IOException { new DataInputViewStreamWrapper(inputStream); try { - DenseVector mean = serializer.deserialize(inputViewStreamWrapper); - DenseVector std = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector mean = serializer.deserialize(inputViewStreamWrapper); + DenseIntDoubleVector std = serializer.deserialize(inputViewStreamWrapper); long version = LongSerializer.INSTANCE.deserialize(inputViewStreamWrapper); long timestamp = LongSerializer.INSTANCE.deserialize(inputViewStreamWrapper); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java index f5acf7a05..a0637a157 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java @@ -24,7 +24,7 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.util.VectorUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; @@ -127,14 +127,14 @@ public Row map(Row row) { if (indices.length == 0) { return Row.join(row, Row.of(Vectors.dense())); } else { - Vector inputVec = ((Vector) row.getField(inputCol)); + IntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol)); Preconditions.checkArgument( inputVec.size() > indices[indices.length - 1], "Input %s features, but UnivariateFeatureSelector is " + "expecting at least %s features as input.", inputVec.size(), indices[indices.length - 1] + 1); - Vector outputVec = VectorUtils.selectByIndices(inputVec, indices); + IntDoubleVector outputVec 
= VectorUtils.selectByIndices(inputVec, indices); return Row.join(row, Row.of(outputVec)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java index 621415f07..e7b83a550 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java @@ -26,9 +26,10 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -80,8 +81,8 @@ public VarianceThresholdSelectorModel fit(Table... inputs) { new VarianceThresholdSelectorAggregator(getVarianceThreshold()), Types.TUPLE( Types.LONG, - DenseVectorTypeInfo.INSTANCE, - DenseVectorTypeInfo.INSTANCE), + DenseIntDoubleVectorTypeInfo.INSTANCE, + DenseIntDoubleVectorTypeInfo.INSTANCE), TypeInformation.of(VarianceThresholdSelectorModelData.class)); VarianceThresholdSelectorModel model = @@ -97,7 +98,7 @@ public VarianceThresholdSelectorModel fit(Table... inputs) { private static class VarianceThresholdSelectorAggregator implements AggregateFunction< Vector, - Tuple3, + Tuple3, VarianceThresholdSelectorModelData> { private final double varianceThreshold; @@ -107,31 +108,36 @@ public VarianceThresholdSelectorAggregator(double varianceThreshold) { } @Override - public Tuple3 createAccumulator() { - return Tuple3.of(0L, new DenseVector(new double[0]), new DenseVector(new double[0])); + public Tuple3 createAccumulator() { + return Tuple3.of( + 0L, + new DenseIntDoubleVector(new double[0]), + new DenseIntDoubleVector(new double[0])); } @Override - public Tuple3 add( - Vector vector, Tuple3 numAndSums) { + public Tuple3 add( + Vector vector, + Tuple3 numAndSums) { + IntDoubleVector intDoubleVector = (IntDoubleVector) vector; if (numAndSums.f0 == 0) { - numAndSums.f1 = new DenseVector(vector.size()); - numAndSums.f2 = new DenseVector(vector.size()); + numAndSums.f1 = new DenseIntDoubleVector(intDoubleVector.size()); + numAndSums.f2 = new DenseIntDoubleVector(intDoubleVector.size()); } numAndSums.f0 += 1L; - BLAS.axpy(1.0, vector, numAndSums.f1); - for (int i = 0; i < vector.size(); i++) { - numAndSums.f2.values[i] += vector.get(i) * vector.get(i); + BLAS.axpy(1.0, intDoubleVector, numAndSums.f1); + for (int i = 0; i < intDoubleVector.size(); i++) { + numAndSums.f2.values[i] += intDoubleVector.get(i) * intDoubleVector.get(i); } return numAndSums; } @Override public VarianceThresholdSelectorModelData getResult( - Tuple3 numAndSums) { + Tuple3 numAndSums) { long numRows = numAndSums.f0; - DenseVector sumVector = numAndSums.f1; - DenseVector squareSumVector = numAndSums.f2; + DenseIntDoubleVector sumVector = numAndSums.f1; + DenseIntDoubleVector squareSumVector = numAndSums.f2; Preconditions.checkState(numRows > 0, "The training set is empty."); 
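getResult above converts the (count, sum, squared-sum) accumulator into per-feature variances and keeps only the indices whose variance exceeds the threshold. In plain Java; the variance denominator is an assumption, as the exact expression is elided in the hunk:

    import java.util.ArrayList;
    import java.util.List;

    public class VarianceSelectSketch {
        public static void main(String[] args) {
            long numRows = 4;
            double[] sum = {8.0, 4.0};        // per-feature sums
            double[] squareSum = {24.0, 4.0}; // per-feature sums of squares
            double varianceThreshold = 1.0;

            List<Integer> kept = new ArrayList<>();
            for (int i = 0; i < sum.length; i++) {
                double mean = sum[i] / numRows;
                double variance = (squareSum[i] - numRows * mean * mean) / (numRows - 1);
                if (variance > varianceThreshold) {
                    kept.add(i);
                }
            }
            System.out.println(kept); // [0]: feature 1 is constant and gets dropped
        }
    }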
int[] indices = @@ -149,9 +155,9 @@ public VarianceThresholdSelectorModelData getResult( } @Override - public Tuple3 merge( - Tuple3 numAndSums1, - Tuple3 acc) { + public Tuple3 merge( + Tuple3 numAndSums1, + Tuple3 acc) { if (numAndSums1.f0 == 0) { return acc; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java index f042c9f7d..6654547ce 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java @@ -24,7 +24,7 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.util.VectorUtils; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; @@ -152,7 +152,7 @@ public Row map(Row row) { .toArray(); } - Vector inputVec = ((Vector) row.getField(inputCol)); + IntDoubleVector inputVec = ((IntDoubleVector) row.getField(inputCol)); Preconditions.checkArgument( inputVec.size() == expectedNumOfFeatures, "%s has %s features, but VarianceThresholdSelector is expecting %s features as input.", @@ -162,7 +162,7 @@ public Row map(Row row) { if (indices.length == 0) { return Row.join(row, Row.of(Vectors.dense())); } else { - Vector outputVec = VectorUtils.selectByIndices(inputVec, indices); + IntDoubleVector outputVec = VectorUtils.selectByIndices(inputVec, indices); return Row.join(row, Row.of(outputVec)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java index e951f80e9..159cf8433 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java @@ -24,9 +24,9 @@ import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.param.HasHandleInvalid; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; @@ -112,7 +112,7 @@ public void flatMap(Row value, Collector out) { Tuple2 vectorSizeAndNnz = computeVectorSizeAndNnz(value); int vectorSize = vectorSizeAndNnz.f0; int nnz = vectorSizeAndNnz.f1; - Vector assembledVec = + IntDoubleVector assembledVec = nnz * RATIO > vectorSize ? 
assembleDense(inputCols, value, vectorSize) : assembleSparse(inputCols, value, vectorSize, nnz); @@ -139,16 +139,16 @@ private Tuple2 computeVectorSizeAndNnz(Row value) { } vectorSize += 1; nnz += 1; - } else if (object instanceof SparseVector) { - int localSize = ((SparseVector) object).size(); + } else if (object instanceof SparseIntDoubleVector) { + int localSize = ((SparseIntDoubleVector) object).size(); checkSize(inputSizes[i], localSize); - nnz += ((SparseVector) object).indices.length; + nnz += ((SparseIntDoubleVector) object).indices.length; vectorSize += localSize; - } else if (object instanceof DenseVector) { - int localSize = ((DenseVector) object).size(); + } else if (object instanceof DenseIntDoubleVector) { + int localSize = ((DenseIntDoubleVector) object).size(); checkSize(inputSizes[i], localSize); vectorSize += localSize; - nnz += ((DenseVector) object).size(); + nnz += ((DenseIntDoubleVector) object).size(); } else { throw new IllegalArgumentException( String.format( @@ -160,7 +160,7 @@ private Tuple2 computeVectorSizeAndNnz(Row value) { nnz += inputSizes[i]; if (keepInvalid) { if (inputSizes[i] > 1) { - DenseVector tmpVec = new DenseVector(inputSizes[i]); + DenseIntDoubleVector tmpVec = new DenseIntDoubleVector(inputSizes[i]); for (int j = 0; j < inputSizes[i]; ++j) { tmpVec.values[j] = Double.NaN; } @@ -205,7 +205,7 @@ public Map, Object> getParamMap() { } /** Assembles the input columns into a dense vector. */ - private static Vector assembleDense(String[] inputCols, Row inputRow, int vectorSize) { + private static IntDoubleVector assembleDense(String[] inputCols, Row inputRow, int vectorSize) { double[] values = new double[vectorSize]; int currentOffset = 0; @@ -213,15 +213,15 @@ private static Vector assembleDense(String[] inputCols, Row inputRow, int vector Object object = inputRow.getField(inputCol); if (object instanceof Number) { values[currentOffset++] = ((Number) object).doubleValue(); - } else if (object instanceof SparseVector) { - SparseVector sparseVector = (SparseVector) object; + } else if (object instanceof SparseIntDoubleVector) { + SparseIntDoubleVector sparseVector = (SparseIntDoubleVector) object; for (int i = 0; i < sparseVector.indices.length; i++) { values[currentOffset + sparseVector.indices[i]] = sparseVector.values[i]; } currentOffset += sparseVector.size(); } else { - DenseVector denseVector = (DenseVector) object; + DenseIntDoubleVector denseVector = (DenseIntDoubleVector) object; System.arraycopy(denseVector.values, 0, values, currentOffset, denseVector.size()); currentOffset += denseVector.size(); @@ -231,7 +231,7 @@ private static Vector assembleDense(String[] inputCols, Row inputRow, int vector } /** Assembles the input columns into a sparse vector. 
*/ - private static Vector assembleSparse( + private static IntDoubleVector assembleSparse( String[] inputCols, Row inputRow, int vectorSize, int nnz) { int[] indices = new int[nnz]; double[] values = new double[nnz]; @@ -246,8 +246,8 @@ private static Vector assembleSparse( values[currentOffset] = ((Number) object).doubleValue(); currentOffset++; currentIndex++; - } else if (object instanceof SparseVector) { - SparseVector sparseVector = (SparseVector) object; + } else if (object instanceof SparseIntDoubleVector) { + SparseIntDoubleVector sparseVector = (SparseIntDoubleVector) object; for (int i = 0; i < sparseVector.indices.length; i++) { indices[currentOffset + i] = sparseVector.indices[i] + currentIndex; } @@ -256,7 +256,7 @@ private static Vector assembleSparse( currentIndex += sparseVector.size(); currentOffset += sparseVector.indices.length; } else { - DenseVector denseVector = (DenseVector) object; + DenseIntDoubleVector denseVector = (DenseIntDoubleVector) object; for (int i = 0; i < denseVector.size(); ++i) { indices[currentOffset + i] = i + currentIndex; } @@ -266,6 +266,6 @@ private static Vector assembleSparse( currentOffset += denseVector.size(); } } - return new SparseVector(vectorSize, indices, values); + return new SparseIntDoubleVector(vectorSize, indices, values); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java index 0bf4c7e57..3486adc26 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java @@ -27,7 +27,7 @@ import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.param.HasHandleInvalid; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -185,14 +185,14 @@ public void endInput() { public void processElement(StreamRecord element) { if (doublesByColumn == null) { // First record. 
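assembleDense and assembleSparse above build the same logical vector; the nnz * RATIO > vectorSize test merely picks the cheaper representation. A sketch of the sparse path for one numeric column followed by one sparse column, using the SparseIntDoubleVector(size, indices, values) constructor shown in the hunk above:

    import org.apache.flink.ml.linalg.SparseIntDoubleVector;

    public class AssembleSparseSketch {
        public static void main(String[] args) {
            // Inputs: the scalar 7.0, then a sparse vector of size 4 holding {1: 3.0}.
            int vectorSize = 1 + 4;
            int nnz = 1 + 1;

            int[] indices = new int[nnz];
            double[] values = new double[nnz];
            indices[0] = 0;     // the scalar lands at offset 0
            values[0] = 7.0;
            indices[1] = 1 + 1; // sparse index 1, shifted past the scalar column
            values[1] = 3.0;

            SparseIntDoubleVector assembled =
                    new SparseIntDoubleVector(vectorSize, indices, values);
            // Logically [7.0, 0.0, 3.0, 0.0, 0.0].
        }
    }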
- Vector vector = (Vector) element.getValue().getField(inputCol); + IntDoubleVector vector = (IntDoubleVector) element.getValue().getField(inputCol); doublesByColumn = new HashSet[vector.size()]; for (int i = 0; i < doublesByColumn.length; i++) { doublesByColumn[i] = new HashSet<>(); } } - Vector vector = (Vector) element.getValue().getField(inputCol); + IntDoubleVector vector = (IntDoubleVector) element.getValue().getField(inputCol); Preconditions.checkState( vector.size() == doublesByColumn.length, "The size of the all input vectors should be the same."); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java index 3becd230d..f5ed91819 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java @@ -24,7 +24,7 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.common.param.HasHandleInvalid; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -149,7 +149,7 @@ public void flatMap(Row input, Collector out) { categoryMaps = modelData.categoryMaps; } - Vector outputVector = ((Vector) input.getField(inputCol)).clone(); + IntDoubleVector outputVector = ((IntDoubleVector) input.getField(inputCol)).clone(); for (Map.Entry> entry : categoryMaps.entrySet()) { int columnId = entry.getKey(); Map mapping = entry.getValue(); @@ -158,7 +158,7 @@ public void flatMap(Row input, Collector out) { if (categoricalFeature == null) { return; } else { - outputVector.set(columnId, categoricalFeature); + outputVector.set(columnId, (double) categoricalFeature); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java index 2abca8910..417e0af1f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java @@ -22,9 +22,9 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -108,8 +108,8 @@ public VectorSliceFunction(Integer[] indices, String inputCol) { @Override public Row map(Row row) throws Exception { - Vector inputVec = row.getFieldAs(inputCol); - Vector outputVec; + IntDoubleVector inputVec = row.getFieldAs(inputCol); + IntDoubleVector outputVec; if (maxIndex >= inputVec.size()) { throw new IllegalArgumentException( "Index value " @@ -117,15 +117,15 @@ public Row map(Row row) throws Exception { + " is greater than vector size:" + 
inputVec.size()); } - if (inputVec instanceof DenseVector) { + if (inputVec instanceof DenseIntDoubleVector) { double[] values = new double[indices.length]; for (int i = 0; i < indices.length; ++i) { - values[i] = ((DenseVector) inputVec).values[indices[i]]; + values[i] = ((DenseIntDoubleVector) inputVec).values[indices[i]]; } - outputVec = new DenseVector(values); + outputVec = new DenseIntDoubleVector(values); } else { int nnz = 0; - SparseVector vec = (SparseVector) inputVec; + SparseIntDoubleVector vec = (SparseIntDoubleVector) inputVec; int[] outputIndices = new int[indices.length]; double[] outputValues = new double[indices.length]; for (int i = 0; i < indices.length; i++) { @@ -140,7 +140,7 @@ public Row map(Row row) throws Exception { outputIndices = Arrays.copyOf(outputIndices, nnz); outputValues = Arrays.copyOf(outputValues, nnz); } - outputVec = new SparseVector(indices.length, outputIndices, outputValues); + outputVec = new SparseIntDoubleVector(indices.length, outputIndices, outputValues); } return Row.join(row, Row.of(outputVec)); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java index d977eb0eb..50422088e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java @@ -25,8 +25,8 @@ import org.apache.flink.ml.common.lossfunc.LeastSquareLoss; import org.apache.flink.ml.common.optimizer.Optimizer; import org.apache.flink.ml.common.optimizer.SGD; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -73,13 +73,13 @@ public LinearRegressionModel fit(Table... inputs) { double label = ((Number) dataPoint.getField(getLabelCol())) .doubleValue(); - DenseVector features = - ((Vector) dataPoint.getField(getFeaturesCol())) + DenseIntDoubleVector features = + ((IntDoubleVector) dataPoint.getField(getFeaturesCol())) .toDense(); return new LabeledPointWithWeight(features, label, weight); }); - DataStream initModelData = + DataStream initModelData = DataStreamUtils.reduce( trainData.map(x -> x.getFeatures().size()), (ReduceFunction) @@ -89,7 +89,7 @@ public LinearRegressionModel fit(Table... inputs) { "The training data should all have same dimensions."); return t0; }) - .map(DenseVector::new); + .map(DenseIntDoubleVector::new); Optimizer optimizer = new SGD( @@ -99,7 +99,7 @@ public LinearRegressionModel fit(Table... 
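VectorSliceFunction above copies the selected positions out of a dense vector (and filters and remaps indices for a sparse one) after validating that the largest requested index fits. The dense path in plain Java:

    public class SliceDenseSketch {
        static double[] slice(double[] input, int[] indices) {
            double[] out = new double[indices.length];
            for (int i = 0; i < indices.length; i++) {
                out[i] = input[indices[i]]; // assumes indices were validated against input.length
            }
            return out;
        }

        public static void main(String[] args) {
            double[] sliced = slice(new double[] {10.0, 20.0, 30.0}, new int[] {2, 0});
            // sliced = [30.0, 10.0]
        }
    }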
inputs) { getTol(), getReg(), getElasticNet()); - DataStream rawModelData = + DataStream rawModelData = optimizer.optimize(initModelData, trainData, LeastSquareLoss.INSTANCE); DataStream modelData = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java index dec2f4365..2e5290b9b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java @@ -25,8 +25,8 @@ import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -127,7 +127,7 @@ private static class PredictLabelFunction extends RichMapFunction { private final String featuresCol; - private DenseVector coefficient; + private DenseIntDoubleVector coefficient; public PredictLabelFunction(String broadcastModelKey, String featuresCol) { this.broadcastModelKey = broadcastModelKey; @@ -142,7 +142,8 @@ public Row map(Row dataPoint) { getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); coefficient = modelData.coefficient; } - DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense(); + DenseIntDoubleVector features = + ((IntDoubleVector) dataPoint.getField(featuresCol)).toDense(); Row predictionResult = predictOneDataPoint(features, coefficient); return Row.join(dataPoint, predictionResult); } @@ -155,7 +156,8 @@ public Row map(Row dataPoint) { * @param coefficient The model parameters. * @return The prediction label and the raw probabilities. 
*/ - private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) { + private static Row predictOneDataPoint( + DenseIntDoubleVector feature, DenseIntDoubleVector coefficient) { return Row.of(BLAS.dot(feature, coefficient)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java index efab0f6d7..927bdfb35 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java @@ -25,8 +25,8 @@ import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; @@ -44,9 +44,9 @@ */ public class LinearRegressionModelData { - public DenseVector coefficient; + public DenseIntDoubleVector coefficient; - public LinearRegressionModelData(DenseVector coefficient) { + public LinearRegressionModelData(DenseIntDoubleVector coefficient) { this.coefficient = coefficient; } @@ -62,12 +62,13 @@ public static DataStream getModelDataStream(Table mod StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); return tEnv.toDataStream(modelData) - .map(x -> new LinearRegressionModelData((DenseVector) x.getField(0))); + .map(x -> new LinearRegressionModelData((DenseIntDoubleVector) x.getField(0))); } /** Data encoder for {@link LinearRegressionModel}. 
*/ public static class ModelDataEncoder implements Encoder { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public void encode(LinearRegressionModelData modelData, OutputStream outputStream) @@ -84,12 +85,13 @@ public static class ModelDataDecoder extends SimpleStreamFormat createReader( Configuration configuration, FSDataInputStream inputStream) { return new Reader() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer serializer = + new DenseIntDoubleVectorSerializer(); @Override public LinearRegressionModelData read() throws IOException { try { - DenseVector coefficient = + DenseIntDoubleVector coefficient = serializer.deserialize(new DataInputViewStreamWrapper(inputStream)); return new LinearRegressionModelData(coefficient); } catch (EOFException e) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java index d928774ac..64b3b37b0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java @@ -27,8 +27,8 @@ import org.apache.flink.ml.api.AlgoOperator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.param.HasFlatten; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -93,16 +93,16 @@ public Table[] transform(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream> inputData = + DataStream> inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction>) + (MapFunction>) row -> { Number number = (Number) row.getField(labelCol); Preconditions.checkNotNull( number, "Input data must contain label value."); return new Tuple2<>( - ((Vector) row.getField(featuresCol)), + ((IntDoubleVector) row.getField(featuresCol)), number.doubleValue()); }, Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE)); @@ -125,7 +125,7 @@ public Table[] transform(Table... 
inputs) { @SuppressWarnings("unchecked") private static class ANOVAAggregator implements AggregateFunction< - Tuple2, + Tuple2, Tuple3>>[], List> { @Override @@ -135,9 +135,9 @@ public Tuple3>>[] createAcc @Override public Tuple3>>[] add( - Tuple2 featuresAndLabel, + Tuple2 featuresAndLabel, Tuple3>>[] acc) { - Vector features = featuresAndLabel.f0; + IntDoubleVector features = featuresAndLabel.f0; double label = featuresAndLabel.f1; int numOfFeatures = features.size(); if (acc.length == 0) { @@ -257,15 +257,19 @@ private Table convertToTable( return tEnv.fromDataStream(output) .as("featureIndex", "pValue", "degreeOfFreedom", "fValue"); } else { - DataStream> output = + DataStream> output = datastream.map( - new MapFunction, Tuple3>() { + new MapFunction< + List, + Tuple3>() { @Override - public Tuple3 map( - List rows) { + public Tuple3 + map(List rows) { int numOfFeatures = rows.size(); - DenseVector pValues = new DenseVector(numOfFeatures); - DenseVector fValues = new DenseVector(numOfFeatures); + DenseIntDoubleVector pValues = + new DenseIntDoubleVector(numOfFeatures); + DenseIntDoubleVector fValues = + new DenseIntDoubleVector(numOfFeatures); long[] degrees = new long[numOfFeatures]; for (int i = 0; i < numOfFeatures; i++) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java index 01cac587d..c2f6dad97 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java @@ -32,9 +32,9 @@ import org.apache.flink.ml.api.AlgoOperator; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.param.HasFlatten; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -191,9 +191,9 @@ public Table[] transform(Table... 
inputs) { outputTypeInfo = new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, + DenseIntDoubleVectorTypeInfo.INSTANCE, Types.PRIMITIVE_ARRAY(Types.INT), - DenseVectorTypeInfo.INSTANCE + DenseIntDoubleVectorTypeInfo.INSTANCE }, new String[] {"pValues", "degreesOfFreedom", "statistics"}); } @@ -236,7 +236,7 @@ public void flatMap(Row row, Collector> collecto Double label = ((Number) row.getFieldAs(labelCol)).doubleValue(); - Vector features = row.getFieldAs(featuresCol); + IntDoubleVector features = row.getFieldAs(featuresCol); for (int i = 0; i < features.size(); i++) { collector.collect(Tuple3.of(i, features.get(i), label)); } @@ -650,8 +650,8 @@ private void endInputWithFlatten() { private void endInputWithoutFlatten() { int size = index2Statistic.size(); - Vector pValueScaledVector = new DenseVector(size); - Vector statisticScaledVector = new DenseVector(size); + IntDoubleVector pValueScaledVector = new DenseIntDoubleVector(size); + IntDoubleVector statisticScaledVector = new DenseIntDoubleVector(size); int[] dofArray = new int[size]; for (Map.Entry> entry : index2Statistic.entrySet()) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java index aaba3d46b..1924be69b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java @@ -33,8 +33,8 @@ import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.param.HasFlatten; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -99,24 +99,24 @@ public Table[] transform(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream> inputData = + DataStream> inputData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction>) + (MapFunction>) row -> { Number number = (Number) row.getField(labelCol); Preconditions.checkNotNull( number, "Input data must contain label value."); return new Tuple2<>( - ((Vector) row.getField(featuresCol)), + ((IntDoubleVector) row.getField(featuresCol)), number.doubleValue()); }) .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE)); - DataStream> summaries = - DataStreamUtils.aggregate(inputData, new SummaryAggregator()); + DataStream> + summaries = DataStreamUtils.aggregate(inputData, new SummaryAggregator()); - DataStream covarianceInEachPartition = + DataStream covarianceInEachPartition = BroadcastUtils.withBroadcastStream( Collections.singletonList(inputData), Collections.singletonMap(broadcastSummaryKey, summaries), @@ -126,10 +126,10 @@ public Table[] transform(Table... 
inputs) { input, new CalCovarianceOperator(broadcastSummaryKey)); }); - DataStream reducedCovariance = + DataStream reducedCovariance = DataStreamUtils.reduce( covarianceInEachPartition, - (ReduceFunction) + (ReduceFunction) (sums1, sums2) -> { BLAS.axpy(1.0, sums1, sums2); return sums2; @@ -156,24 +156,30 @@ private Table convertToTable( return tEnv.fromDataStream(dataStream) .as("featureIndex", "pValue", "degreeOfFreedom", "fValue"); } else { - DataStream> output = + DataStream> output = DataStreamUtils.mapPartition( dataStream, new MapPartitionFunction< Tuple4, - Tuple3>() { + Tuple3>() { @Override public void mapPartition( Iterable> iterable, - Collector> + Collector< + Tuple3< + DenseIntDoubleVector, + long[], + DenseIntDoubleVector>> collector) { List> rows = IteratorUtils.toList(iterable.iterator()); int numOfFeatures = rows.size(); - DenseVector pValues = new DenseVector(numOfFeatures); + DenseIntDoubleVector pValues = + new DenseIntDoubleVector(numOfFeatures); long[] degrees = new long[numOfFeatures]; - DenseVector fValues = new DenseVector(numOfFeatures); + DenseIntDoubleVector fValues = + new DenseIntDoubleVector(numOfFeatures); for (int i = 0; i < numOfFeatures; i++) { Tuple4 tuple = rows.get(i); @@ -190,7 +196,8 @@ public void mapPartition( /** Computes the covariance of each feature on each partition. */ private static class CalCovarianceOperator - extends RichMapPartitionFunction, DenseVector> { + extends RichMapPartitionFunction< + Tuple2, DenseIntDoubleVector> { private final String broadcastKey; @@ -200,14 +207,15 @@ private CalCovarianceOperator(String broadcastKey) { @Override public void mapPartition( - Iterable> iterable, Collector collector) { - Tuple5 summaries = - (Tuple5) + Iterable> iterable, + Collector collector) { + Tuple5 summaries = + (Tuple5) getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); int expectedNumOfFeatures = summaries.f3.size(); - DenseVector sumVector = new DenseVector(expectedNumOfFeatures); - for (Tuple2 featuresAndLabel : iterable) { + DenseIntDoubleVector sumVector = new DenseIntDoubleVector(expectedNumOfFeatures); + for (Tuple2 featuresAndLabel : iterable) { Preconditions.checkArgument( featuresAndLabel.f0.size() == expectedNumOfFeatures, "Input %s features, but FValueTest is expecting %s features.", @@ -229,10 +237,11 @@ public void mapPartition( /** Computes the p-value, fValues and the number of degrees of freedom of input features. */ private static class CalFValueOperator - extends RichMapPartitionFunction> { + extends RichMapPartitionFunction< + DenseIntDoubleVector, Tuple4> { private final String broadcastKey; - private DenseVector sumVector; + private DenseIntDoubleVector sumVector; private CalFValueOperator(String broadcastKey) { this.broadcastKey = broadcastKey; @@ -240,10 +249,10 @@ private CalFValueOperator(String broadcastKey) { @Override public void mapPartition( - Iterable iterable, + Iterable iterable, Collector> collector) { - Tuple5 summaries = - (Tuple5) + Tuple5 summaries = + (Tuple5) getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); int expectedNumOfFeatures = summaries.f4.size(); @@ -273,26 +282,31 @@ public void mapPartition( /** Computes the num, mean, and standard deviation of the input label and features. 
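 *
 * <p>[Reviewer aside — illustration, not part of the patch] The aggregator below keeps a
 * running (count, sum, sum of squares) per label/feature and only derives the mean and the
 * unbiased sample variance in getResult(). A plain-Java sketch of that bookkeeping
 * (variable names are hypothetical):
 *
 * <pre>{@code
 * double count = 0, sum = 0, sumSq = 0;
 * for (double x : new double[] {1.0, 2.0, 3.0, 4.0}) {
 *     count++;
 *     sum += x;
 *     sumSq += x * x;
 * }
 * double mean = sum / count; // 2.5
 * // Same formula as getResult(): E[x^2] - mean^2, rescaled to the unbiased sample variance.
 * double sampleVariance = (sumSq / count - mean * mean) * count / (count - 1); // ~1.6667
 * }</pre>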
*/ private static class SummaryAggregator implements AggregateFunction< - Tuple2, - Tuple5, - Tuple5> { + Tuple2, + Tuple5, + Tuple5> { @Override - public Tuple5 createAccumulator() { + public Tuple5 + createAccumulator() { return Tuple5.of( - 0L, 0.0, 0.0, new DenseVector(new double[0]), new DenseVector(new double[0])); + 0L, + 0.0, + 0.0, + new DenseIntDoubleVector(new double[0]), + new DenseIntDoubleVector(new double[0])); } @Override - public Tuple5 add( - Tuple2 featuresAndLabel, - Tuple5 summary) { - Vector features = featuresAndLabel.f0; + public Tuple5 add( + Tuple2 featuresAndLabel, + Tuple5 summary) { + IntDoubleVector features = featuresAndLabel.f0; double label = featuresAndLabel.f1; if (summary.f0 == 0) { - summary.f3 = new DenseVector(features.size()); - summary.f4 = new DenseVector(features.size()); + summary.f3 = new DenseIntDoubleVector(features.size()); + summary.f4 = new DenseIntDoubleVector(features.size()); } summary.f0 += 1L; summary.f1 += label; @@ -306,14 +320,14 @@ public Tuple5 add( } @Override - public Tuple5 getResult( - Tuple5 summary) { + public Tuple5 getResult( + Tuple5 summary) { final long numRows = summary.f0; Preconditions.checkState(numRows > 0, "The training set is empty."); int numOfFeatures = summary.f3.size(); double labelMean = summary.f1 / numRows; - Tuple5 result = + Tuple5 result = Tuple5.of( numRows, labelMean, @@ -321,8 +335,8 @@ public Tuple5 getResult( (summary.f2 / numRows - labelMean * labelMean) * numRows / (numRows - 1)), - new DenseVector(numOfFeatures), - new DenseVector(numOfFeatures)); + new DenseIntDoubleVector(numOfFeatures), + new DenseIntDoubleVector(numOfFeatures)); for (int i = 0; i < summary.f3.size(); i++) { double mean = summary.f3.get(i) / numRows; result.f3.values[i] = mean; @@ -336,9 +350,9 @@ public Tuple5 getResult( } @Override - public Tuple5 merge( - Tuple5 summary1, - Tuple5 summary2) { + public Tuple5 merge( + Tuple5 summary1, + Tuple5 summary2) { if (summary1.f0 == 0) { return summary2; } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java index c5ddfc4dc..d0472a00d 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/FunctionsTest.java @@ -19,8 +19,8 @@ package org.apache.flink.ml; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -60,10 +60,10 @@ public class FunctionsTest extends AbstractTestBase { private static final List longArrays = Arrays.asList(new long[] {0, 0}, new long[] {0, 1}); - private static final List denseVectors = + private static final List denseVectors = Arrays.asList(Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 1.0)); - private static final List sparseVectors = + private static final List sparseVectors = Arrays.asList( Vectors.sparse(2, new int[0], new double[0]), Vectors.sparse(2, new int[] {1}, new double[] {1.0})); @@ -126,7 +126,7 @@ private void testArrayToVector(List array) { assertEquals(outputValues.size(), denseVectors.size()); for (int i = 0; i < denseVectors.size(); i++) { - DenseVector vector = outputValues.get(i).getFieldAs("vector"); + 
DenseIntDoubleVector vector = outputValues.get(i).getFieldAs("vector"); assertEquals(denseVectors.get(i), vector); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java index 483dc34d8..ea25e0e2e 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/KnnTest.java @@ -23,9 +23,9 @@ import org.apache.flink.ml.classification.knn.Knn; import org.apache.flink.ml.classification.knn.KnnModel; import org.apache.flink.ml.classification.knn.KnnModelData; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.DenseMatrix; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; @@ -96,7 +96,7 @@ public void before() { tEnv = StreamTableEnvironment.create(env); Schema schema = Schema.newBuilder() - .column("f0", DataTypes.of(DenseVector.class)) + .column("f0", DataTypes.of(DenseIntDoubleVector.class)) .column("f1", DataTypes.DOUBLE()) .build(); DataStream dataStream = env.fromCollection(trainRows); @@ -179,10 +179,10 @@ public void testInputTypeConversion() throws Exception { trainData = TestUtils.convertDataTypesToSparseInt(tEnv, trainData); predictData = TestUtils.convertDataTypesToSparseInt(tEnv, predictData); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class}, TestUtils.getColumnDataTypes(trainData)); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class}, TestUtils.getColumnDataTypes(predictData)); Knn knn = new Knn(); @@ -232,12 +232,12 @@ public void testGetModelData() throws Exception { KnnModelData data = new KnnModelData( (DenseMatrix) modelRows.get(0).getField(0), - (DenseVector) modelRows.get(0).getField(1), - (DenseVector) modelRows.get(0).getField(2)); + (DenseIntDoubleVector) modelRows.get(0).getField(1), + (DenseIntDoubleVector) modelRows.get(0).getField(2)); Assert.assertNotNull(data); assertEquals(2, data.packedFeatures.numRows()); - assertEquals(data.packedFeatures.numCols(), data.labels.size()); - assertEquals(data.featureNormSquares.size(), data.labels.size()); + assertEquals(data.packedFeatures.numCols(), data.labels.size().intValue()); + assertEquals(data.featureNormSquares.size().intValue(), data.labels.size().intValue()); } @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java index b2a20eb22..e4e7d5bf9 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java @@ -24,11 +24,11 @@ import org.apache.flink.ml.classification.linearsvc.LinearSVC; import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import 
org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -93,7 +93,9 @@ public void before() { trainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); } @@ -104,9 +106,11 @@ private void verifyPredictionResult( throws Exception { List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); for (Row predictionRow : predResult) { - DenseVector feature = ((Vector) predictionRow.getField(featuresCol)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.getField(featuresCol)).toDense(); double prediction = (Double) predictionRow.getField(predictionCol); - DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.getField(rawPredictionCol); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) < 0); @@ -196,7 +200,7 @@ public void testFitAndPredict() throws Exception { public void testInputTypeConversion() throws Exception { trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class}, TestUtils.getColumnDataTypes(trainDataTable)); LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); @@ -274,7 +278,9 @@ public void testMoreSubtaskThanData() throws Exception { trainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index 4f6c3fa35..b70c173ea 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -26,12 +26,12 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; +import 
org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.types.BasicType; import org.apache.flink.ml.servable.types.DataTypes; @@ -123,7 +123,9 @@ public void before() { binomialTrainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); @@ -131,7 +133,7 @@ public void before() { binomialTrainData.stream() .map( r -> { - DenseVector features = r.getFieldAs(0); + DenseIntDoubleVector features = r.getFieldAs(0); double label = r.getFieldAs(1); double weight = r.getFieldAs(2); return Row.of(features.toSparse(), label, weight); @@ -143,7 +145,7 @@ public void before() { binomialSparseTrainData, new RowTypeInfo( new TypeInformation[] { - SparseVectorTypeInfo.INSTANCE, + SparseIntDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, @@ -155,7 +157,9 @@ public void before() { multinomialTrainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); binomialDataDataFrame = @@ -175,9 +179,11 @@ private void verifyPredictionResult( throws Exception { List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); for (Row predictionRow : predResult) { - DenseVector feature = ((Vector) predictionRow.getField(featuresCol)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.getField(featuresCol)).toDense(); double prediction = (double) predictionRow.getField(predictionCol); - DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.getField(rawPredictionCol); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) > 0.5); @@ -195,9 +201,11 @@ private void verifyPredictionResult( int rawPredictionColIndex = output.getIndex(rawPredictionCol); for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) { - DenseVector feature = ((Vector) predictionRow.get(featuresColIndex)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.get(featuresColIndex)).toDense(); double prediction = (double) predictionRow.get(predictionColIndex); - DenseVector rawPrediction = (DenseVector) predictionRow.get(rawPredictionColIndex); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.get(rawPredictionColIndex); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) > 0.5); @@ -287,7 +295,7 @@ public void testFitAndPredict() throws Exception { public void testInputTypeConversion() throws Exception { binomialDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, binomialDataTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class}, TestUtils.getColumnDataTypes(binomialDataTable)); LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); @@ -435,7 +443,9 @@ public void testMoreSubtaskThanData() throws Exception { binomialTrainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + 
DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java index 0ac5ae74a..2489d594b 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -29,10 +29,10 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionWithFtrl; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.types.BasicType; import org.apache.flink.ml.servable.types.DataTypes; @@ -131,7 +131,7 @@ public void before() { testRows, new RowTypeInfo( new TypeInformation[] { - SparseVectorTypeInfo.INSTANCE, + SparseIntDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, @@ -359,9 +359,11 @@ private void verifyPredictionResult( throws Exception { List predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); for (Row predictionRow : predResult) { - DenseVector feature = ((Vector) predictionRow.getField(featuresCol)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.getField(featuresCol)).toDense(); double prediction = (double) predictionRow.getField(predictionCol); - DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.getField(rawPredictionCol); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) > 0.5); @@ -379,9 +381,11 @@ private void verifyPredictionResult( int rawPredictionColIndex = output.getIndex(rawPredictionCol); for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) { - DenseVector feature = ((Vector) predictionRow.get(featuresColIndex)).toDense(); + DenseIntDoubleVector feature = + ((IntDoubleVector) predictionRow.get(featuresColIndex)).toDense(); double prediction = (double) predictionRow.get(predictionColIndex); - DenseVector rawPrediction = (DenseVector) predictionRow.get(rawPredictionColIndex); + DenseIntDoubleVector rawPrediction = + (DenseIntDoubleVector) predictionRow.get(rawPredictionColIndex); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); assertTrue(rawPrediction.get(0) > 0.5); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java index cac6e4c91..f02a04482 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java @@ -21,8 +21,8 @@ import 
org.apache.flink.ml.classification.naivebayes.NaiveBayes; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; @@ -56,7 +56,7 @@ public class NaiveBayesTest extends AbstractTestBase { private StreamTableEnvironment tEnv; private Table trainTable; private Table predictTable; - private Map expectedOutput; + private Map expectedOutput; private NaiveBayes estimator; @Before @@ -82,7 +82,7 @@ public void before() { predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); expectedOutput = - new HashMap() { + new HashMap() { { put(Vectors.dense(0, 1.), 11.0); put(Vectors.dense(0, 0.), 11.0); @@ -109,13 +109,13 @@ public void before() { * @param predictionCol Name of the column in the table that contains the prediction result * @return A map containing the collected results */ - private static Map executeAndCollect( + private static Map executeAndCollect( Table table, String featuresCol, String predictionCol) { - Map map = new HashMap<>(); + Map map = new HashMap<>(); for (CloseableIterator it = table.execute().collect(); it.hasNext(); ) { Row row = it.next(); map.put( - ((Vector) row.getField(featuresCol)).toDense(), + ((IntDoubleVector) row.getField(featuresCol)).toDense(), (Double) row.getField(predictionCol)); } @@ -159,7 +159,7 @@ public void testParam() { public void testFitAndPredict() { NaiveBayesModel model = estimator.fit(trainTable); Table outputTable = model.transform(predictTable)[0]; - Map actualOutput = + Map actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -170,14 +170,15 @@ public void testInputTypeConversion() { predictTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class}, TestUtils.getColumnDataTypes(trainTable)); assertArrayEquals( - new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(predictTable)); + new Class[] {SparseIntDoubleVector.class}, + TestUtils.getColumnDataTypes(predictTable)); NaiveBayesModel model = estimator.fit(trainTable); Table outputTable = model.transform(predictTable)[0]; - Map actualOutput = + Map actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -194,7 +195,7 @@ public void testOutputSchema() { NaiveBayesModel model = estimator.fit(trainTable); Table outputTable = model.transform(predictTable)[0]; - Map actualOutput = + Map actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -282,7 +283,7 @@ public void testSaveLoad() throws Exception { Table outputTable = model.transform(predictTable)[0]; - Map actualOutput = + Map actualOutput = executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } @@ -319,7 +320,7 @@ public void testSetModelData() { Table outputTable = modelB.transform(predictTable)[0]; - Map actualOutput 
= + Map actualOutput = executeAndCollect(outputTable, modelB.getFeaturesCol(), modelB.getPredictionCol()); assertEquals(expectedOutput, actualOutput); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java index 3eb2bb047..035d8c791 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java @@ -33,8 +33,8 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.InMemorySinkFunction; import org.apache.flink.ml.util.InMemorySourceFunction; @@ -212,7 +212,8 @@ public void before() throws Exception { trainDenseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(DenseVector.class), Types.DOUBLE + TypeInformation.of(DenseIntDoubleVector.class), + Types.DOUBLE }, new String[] {"features", "label"}))); @@ -222,7 +223,8 @@ public void before() throws Exception { predictDenseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(DenseVector.class), Types.DOUBLE + TypeInformation.of(DenseIntDoubleVector.class), + Types.DOUBLE }, new String[] {"features", "label"}))); @@ -232,7 +234,7 @@ public void before() throws Exception { trainSparseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(SparseVector.class), + TypeInformation.of(SparseIntDoubleVector.class), Types.DOUBLE, Types.DOUBLE }, @@ -244,7 +246,8 @@ public void before() throws Exception { predictSparseSource, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(SparseVector.class), Types.DOUBLE + TypeInformation.of(SparseIntDoubleVector.class), + Types.DOUBLE }, new String[] {"features", "label"}))); @@ -252,7 +255,7 @@ public void before() throws Exception { tEnv.fromDataStream( env.fromElements( Row.of( - new DenseVector( + new DenseIntDoubleVector( new double[] { 0.41233679404769874, -0.18088118293232122 }), @@ -263,7 +266,7 @@ public void before() throws Exception { tEnv.fromDataStream( env.fromElements( Row.of( - new DenseVector( + new DenseIntDoubleVector( new double[] { 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01 @@ -334,7 +337,7 @@ private void waitModelDataUpdate(JobID jobID) throws InterruptedException { * * @param expectedRawInfo A list containing sets of expected result RawInfo. */ - private void predictAndAssert(List expectedRawInfo, boolean isSparse) + private void predictAndAssert(List expectedRawInfo, boolean isSparse) throws Exception { if (isSparse) { predictSparseSource.addAll(PREDICT_SPARSE_ROWS); @@ -343,7 +346,7 @@ private void predictAndAssert(List expectedRawInfo, boolean isSpars } List rawResult = outputSink.poll(isSparse ? 
PREDICT_SPARSE_ROWS.length : PREDICT_DENSE_ROWS.length); - List resultDetail = new ArrayList<>(rawResult.size()); + List resultDetail = new ArrayList<>(rawResult.size()); for (Row row : rawResult) { resultDetail.add(row.getFieldAs(3)); } @@ -416,14 +419,18 @@ public void testParam() { @Test public void testDenseFitAndPredict() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}), - new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.04481034155642882, 0.9551896584435712}), + new DenseIntDoubleVector( + new double[] {0.5353966697318491, 0.4646033302681509})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}), - new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307})); + new DenseIntDoubleVector( + new double[] {0.013104324065967066, 0.9868956759340329}), + new DenseIntDoubleVector( + new double[] {0.5095144380001769, 0.49048556199982307})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -450,14 +457,18 @@ public void testDenseFitAndPredict() throws Exception { @Test public void testSparseFitAndPredict() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.4452309884735286, 0.5547690115264714}), - new DenseVector(new double[] {0.5105551725414953, 0.4894448274585047})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.4452309884735286, 0.5547690115264714}), + new DenseIntDoubleVector( + new double[] {0.5105551725414953, 0.4894448274585047})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.40310431554310666, 0.5968956844568933}), - new DenseVector(new double[] {0.5249618837373886, 0.4750381162626114})); + new DenseIntDoubleVector( + new double[] {0.40310431554310666, 0.5968956844568933}), + new DenseIntDoubleVector( + new double[] {0.5249618837373886, 0.4750381162626114})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -483,14 +494,18 @@ public void testSparseFitAndPredict() throws Exception { @Test public void testFitAndPredictWithWeightCol() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.452491993753382, 0.547508006246618}), - new DenseVector(new double[] {0.5069192929506545, 0.4930807070493455})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.452491993753382, 0.547508006246618}), + new DenseIntDoubleVector( + new double[] {0.5069192929506545, 0.4930807070493455})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.41108882806164193, 0.5889111719383581}), - new DenseVector(new double[] {0.5247727600974581, 0.4752272399025419})); + new DenseIntDoubleVector( + new double[] {0.41108882806164193, 0.5889111719383581}), + new DenseIntDoubleVector( + new double[] {0.5247727600974581, 0.4752272399025419})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -521,20 +536,24 @@ public void testGenerateRandomModelData() throws Exception { LogisticRegressionModelDataUtil.generateRandomModelData(tEnv, 2, 2022); DataStream 
modelData = tEnv.toDataStream(modelDataTable); Row modelRow = (Row) IteratorUtils.toList(modelData.executeAndCollect()).get(0); - Assert.assertEquals(2, ((DenseVector) modelRow.getField(0)).size()); + Assert.assertEquals(2, ((DenseIntDoubleVector) modelRow.getField(0)).size().intValue()); Assert.assertEquals(0L, modelRow.getField(1)); } @Test public void testInitWithLogisticRegression() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.037327343811250024, 0.96267265618875}), - new DenseVector(new double[] {0.5684728224189707, 0.4315271775810293})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.037327343811250024, 0.96267265618875}), + new DenseIntDoubleVector( + new double[] {0.5684728224189707, 0.4315271775810293})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.007758574555505882, 0.9922414254444941}), - new DenseVector(new double[] {0.5257216567388069, 0.4742783432611931})); + new DenseIntDoubleVector( + new double[] {0.007758574555505882, 0.9922414254444941}), + new DenseIntDoubleVector( + new double[] {0.5257216567388069, 0.4742783432611931})); LogisticRegression logisticRegression = new LogisticRegression() .setLabelCol("label") @@ -590,14 +609,18 @@ public void testBatchSizeLessThanParallelism() { @Test public void testSaveAndReload() throws Exception { - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}), - new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.04481034155642882, 0.9551896584435712}), + new DenseIntDoubleVector( + new double[] {0.5353966697318491, 0.4646033302681509})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}), - new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307})); + new DenseIntDoubleVector( + new double[] {0.013104324065967066, 0.9868956759340329}), + new DenseIntDoubleVector( + new double[] {0.5095144380001769, 0.49048556199982307})); OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression() .setFeaturesCol("features") @@ -652,7 +675,8 @@ public void testGetModelData() throws Exception { LogisticRegressionModelData expectedModelData = new LogisticRegressionModelData( - new DenseVector(new double[] {0.2994527071464283, -0.1412541067743284}), + new DenseIntDoubleVector( + new double[] {0.2994527071464283, -0.1412541067743284}), 1L); Assert.assertArrayEquals( expectedModelData.coefficient.values, actualModelData.coefficient.values, 1e-5); @@ -662,19 +686,25 @@ public void testGetModelData() throws Exception { @Test public void testSetModelData() throws Exception { LogisticRegressionModelData modelData1 = - new LogisticRegressionModelData(new DenseVector(new double[] {0.085, -0.22}), 1L); + new LogisticRegressionModelData( + new DenseIntDoubleVector(new double[] {0.085, -0.22}), 1L); LogisticRegressionModelData modelData2 = - new LogisticRegressionModelData(new DenseVector(new double[] {0.075, -0.28}), 2L); + new LogisticRegressionModelData( + new DenseIntDoubleVector(new double[] {0.075, -0.28}), 2L); - final List expectedRawInfo1 = + final List expectedRawInfo1 = Arrays.asList( - new DenseVector(new double[] {0.6285496932692606, 0.3714503067307394}), - new DenseVector(new double[] {0.7588710471221473, 
0.24112895287785274})); - final List expectedRawInfo2 = + new DenseIntDoubleVector( + new double[] {0.6285496932692606, 0.3714503067307394}), + new DenseIntDoubleVector( + new double[] {0.7588710471221473, 0.24112895287785274})); + final List expectedRawInfo2 = Arrays.asList( - new DenseVector(new double[] {0.6673003248270917, 0.3326996751729083}), - new DenseVector(new double[] {0.8779865510655934, 0.12201344893440658})); + new DenseIntDoubleVector( + new double[] {0.6673003248270917, 0.3326996751729083}), + new DenseIntDoubleVector( + new double[] {0.8779865510655934, 0.12201344893440658})); InMemorySourceFunction modelDataSource = new InMemorySourceFunction<>(); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java index 292e793ed..0f5143250 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/AgglomerativeClusteringTest.java @@ -31,9 +31,9 @@ import org.apache.flink.ml.common.window.EventTimeTumblingWindows; import org.apache.flink.ml.common.window.GlobalWindows; import org.apache.flink.ml.common.window.ProcessingTimeTumblingWindows; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -71,7 +71,7 @@ public class AgglomerativeClusteringTest extends AbstractTestBase { private StreamExecutionEnvironment env; private Table inputDataTable; - private static final List INPUT_DATA = + private static final List INPUT_DATA = Arrays.asList( Vectors.dense(1, 1), Vectors.dense(1, 4), @@ -97,7 +97,7 @@ public class AgglomerativeClusteringTest extends AbstractTestBase { private static final double[] EUCLIDEAN_COMPLETE_MERGE_DISTANCES = new double[] {1, 1.5, 3, 3.3541019, 5}; - private static final List> EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT = + private static final List> EUCLIDEAN_WARD_NUM_CLUSTERS_AS_TWO_RESULT = Arrays.asList( new HashSet<>( Arrays.asList( @@ -107,38 +107,42 @@ public class AgglomerativeClusteringTest extends AbstractTestBase { Vectors.dense(4, 0))), new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4)))); - private static final List> EUCLIDEAN_WARD_THRESHOLD_AS_TWO_RESULT = + private static final List> EUCLIDEAN_WARD_THRESHOLD_AS_TWO_RESULT = Arrays.asList( new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), new HashSet<>(Collections.singletonList(Vectors.dense(1, 4))), new HashSet<>(Collections.singletonList(Vectors.dense(4, 4))), new HashSet<>(Arrays.asList(Vectors.dense(4, 1.5), Vectors.dense(4, 0)))); - private static final List> EUCLIDEAN_WARD_COUNT_FIVE_WINDOW_AS_TWO_RESULT = - Arrays.asList( - new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), - new HashSet<>( - Arrays.asList( - Vectors.dense(1, 4), - Vectors.dense(4, 4), - Vectors.dense(4, 1.5)))); - - private static final List> EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT = - Arrays.asList( - new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), - new 
HashSet<>(Collections.singletonList(Vectors.dense(1, 4))), - new HashSet<>(Arrays.asList(Vectors.dense(4, 0), Vectors.dense(4, 1.5))), - new HashSet<>(Collections.singletonList(Vectors.dense(4, 4)))); - - private static final List> EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT = - Arrays.asList( - new HashSet<>( - Arrays.asList( - Vectors.dense(1, 1), - Vectors.dense(1, 0), - Vectors.dense(4, 1.5), - Vectors.dense(4, 0))), - new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4)))); + private static final List> + EUCLIDEAN_WARD_COUNT_FIVE_WINDOW_AS_TWO_RESULT = + Arrays.asList( + new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), + new HashSet<>( + Arrays.asList( + Vectors.dense(1, 4), + Vectors.dense(4, 4), + Vectors.dense(4, 1.5)))); + + private static final List> + EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT = + Arrays.asList( + new HashSet<>(Arrays.asList(Vectors.dense(1, 1), Vectors.dense(1, 0))), + new HashSet<>(Collections.singletonList(Vectors.dense(1, 4))), + new HashSet<>( + Arrays.asList(Vectors.dense(4, 0), Vectors.dense(4, 1.5))), + new HashSet<>(Collections.singletonList(Vectors.dense(4, 4)))); + + private static final List> + EUCLIDEAN_AVERAGE_NUM_CLUSTERS_AS_TWO_RESULT = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(1, 1), + Vectors.dense(1, 0), + Vectors.dense(4, 1.5), + Vectors.dense(4, 0))), + new HashSet<>(Arrays.asList(Vectors.dense(1, 4), Vectors.dense(4, 4)))); private static final double TOLERANCE = 1e-7; @@ -303,18 +307,22 @@ public void testTransformWithCountTumblingWindows() throws Exception { public void testTransformWithEventTimeTumblingWindows() throws Exception { RowTypeInfo outputTypeInfo = new RowTypeInfo( - new TypeInformation[] {DenseVectorTypeInfo.INSTANCE, Types.INSTANT}, + new TypeInformation[] { + DenseIntDoubleVectorTypeInfo.INSTANCE, Types.INSTANT + }, new String[] {"features", "ts"}); Instant baseTime = Instant.now(); DataStream inputDataStream = env.fromCollection(INPUT_DATA) .setParallelism(1) - .map(x -> Row.of(x, baseTime.plusSeconds((long) x.get(0))), outputTypeInfo); + .map( + x -> Row.of(x, baseTime.plusSeconds((long) x.get(0).doubleValue())), + outputTypeInfo); Schema schema = Schema.newBuilder() - .column("features", DataTypes.of(DenseVectorTypeInfo.INSTANCE)) + .column("features", DataTypes.of(DenseIntDoubleVectorTypeInfo.INSTANCE)) .column("ts", DataTypes.TIMESTAMP_LTZ(3)) .watermark("ts", "ts - INTERVAL '5' SECOND") .build(); @@ -331,16 +339,17 @@ public void testTransformWithEventTimeTumblingWindows() throws Exception { Table[] outputs = agglomerativeClustering.transform(inputDataTable); List output = IteratorUtils.toList(tEnv.toDataStream(outputs[0]).executeAndCollect()); - List> actualGroups = + List> actualGroups = KMeansTest.groupFeaturesByPrediction( output, agglomerativeClustering.getFeaturesCol(), agglomerativeClustering.getPredictionCol()); boolean isAllSubSet = true; - for (Set expectedSet : EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT) { + for (Set expectedSet : + EUCLIDEAN_WARD_EVENT_TIME_WINDOW_AS_TWO_RESULT) { boolean isSubset = false; - for (Set actualSet : actualGroups) { + for (Set actualSet : actualGroups) { if (actualSet.containsAll(expectedSet)) { isSubset = true; break; @@ -440,13 +449,13 @@ private void verifyMergeInfo(double[] expectedDistances, Table mergeInfoTable) @SuppressWarnings("unchecked") public void verifyClusteringResult( - List> expected, + List> expected, Table outputTable, String featureCol, String predictionCol) throws Exception { List 
output = IteratorUtils.toList(tEnv.toDataStream(outputTable).executeAndCollect()); - List> actualGroups = + List> actualGroups = KMeansTest.groupFeaturesByPrediction(output, featureCol, predictionCol); assertTrue(CollectionUtils.isEqualCollection(expected, actualGroups)); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java index 8f1961726..e6fbd704b 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java @@ -22,9 +22,9 @@ import org.apache.flink.ml.clustering.kmeans.KMeansModel; import org.apache.flink.ml.clustering.kmeans.KMeansModelData; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; @@ -60,7 +60,7 @@ public class KMeansTest extends AbstractTestBase { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); - private static final List DATA = + private static final List DATA = Arrays.asList( Vectors.dense(0.0, 0.0), Vectors.dense(0.0, 0.3), @@ -70,7 +70,7 @@ public class KMeansTest extends AbstractTestBase { Vectors.dense(9.6, 0.0)); private StreamExecutionEnvironment env; private StreamTableEnvironment tEnv; - private static final List> expectedGroups = + private static final List> expectedGroups = Arrays.asList( new HashSet<>( Arrays.asList( @@ -100,11 +100,11 @@ public void before() { * @param predictionCol Name of the column in the table that contains the prediction result * @return A map containing the collected results */ - protected static List> groupFeaturesByPrediction( + protected static List> groupFeaturesByPrediction( List rows, String featuresCol, String predictionCol) { - Map> map = new HashMap<>(); + Map> map = new HashMap<>(); for (Row row : rows) { - DenseVector vector = ((Vector) row.getField(featuresCol)).toDense(); + DenseIntDoubleVector vector = ((IntDoubleVector) row.getField(featuresCol)).toDense(); int predict = (Integer) row.getField(predictionCol); map.putIfAbsent(predict, new HashSet<>()); map.get(predict).add(vector); @@ -152,7 +152,7 @@ public void testOutputSchema() { @Test public void testFewerDistinctPointsThanCluster() { - List data = + List data = Arrays.asList( Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1)); @@ -162,10 +162,10 @@ public void testFewerDistinctPointsThanCluster() { KMeansModel model = kmeans.fit(input); Table output = model.transform(input)[0]; - List> expectedGroups = + List> expectedGroups = Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1))); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -181,7 +181,7 @@ public void testFitAndPredict() { Arrays.asList("features", "prediction"), output.getResolvedSchema().getColumnNames()); List results = 
IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -191,7 +191,8 @@ public void testFitAndPredict() { public void testInputTypeConversion() { dataTable = TestUtils.convertDataTypesToSparseInt(tEnv, dataTable); assertArrayEquals( - new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(dataTable)); + new Class[] {SparseIntDoubleVector.class}, + TestUtils.getColumnDataTypes(dataTable)); KMeans kmeans = new KMeans().setMaxIter(2).setK(2); KMeansModel model = kmeans.fit(dataTable); @@ -201,7 +202,7 @@ public void testInputTypeConversion() { Arrays.asList("features", "prediction"), output.getResolvedSchema().getColumnNames()); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -226,7 +227,7 @@ public void testSaveLoadAndPredict() throws Exception { output.getResolvedSchema().getColumnNames()); List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); @@ -245,7 +246,7 @@ public void testGetModelData() throws Exception { List collectedModelData = IteratorUtils.toList(modelData.executeAndCollect()); assertEquals(1, collectedModelData.size()); - DenseVector[] centroids = collectedModelData.get(0).centroids; + DenseIntDoubleVector[] centroids = collectedModelData.get(0).centroids; assertEquals(2, centroids.length); Arrays.sort(centroids, Comparator.comparingDouble(vector -> vector.get(0))); assertArrayEquals(centroids[0].values, new double[] {0.1, 0.1}, 1e-5); @@ -261,7 +262,7 @@ public void testSetModelData() { Table output = modelB.transform(dataTable)[0]; List results = IteratorUtils.toList(output.execute().collect()); - List> actualGroups = + List> actualGroups = groupFeaturesByPrediction( results, kmeans.getFeaturesCol(), kmeans.getPredictionCol()); assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java index d85ab3510..0591403b2 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java @@ -31,10 +31,10 @@ import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.util.InMemorySinkFunction; import org.apache.flink.ml.util.InMemorySourceFunction; import 
@@ -77,8 +77,8 @@ public class OnlineKMeansTest extends TestLogger {
     @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
 
-    private static final DenseVector[] trainData1 =
-            new DenseVector[] {
+    private static final DenseIntDoubleVector[] trainData1 =
+            new DenseIntDoubleVector[] {
                 Vectors.dense(10.0, 0.0),
                 Vectors.dense(10.0, 0.3),
                 Vectors.dense(10.3, 0.0),
@@ -86,8 +86,8 @@ public class OnlineKMeansTest extends TestLogger {
                 Vectors.dense(-10.0, 0.6),
                 Vectors.dense(-10.6, 0.0)
             };
-    private static final DenseVector[] trainData2 =
-            new DenseVector[] {
+    private static final DenseIntDoubleVector[] trainData2 =
+            new DenseIntDoubleVector[] {
                 Vectors.dense(10.0, 100.0),
                 Vectors.dense(10.0, 100.3),
                 Vectors.dense(10.3, 100.0),
@@ -95,8 +95,8 @@ public class OnlineKMeansTest extends TestLogger {
                 Vectors.dense(-10.0, -100.6),
                 Vectors.dense(-10.6, -100.0)
             };
-    private static final DenseVector[] predictData =
-            new DenseVector[] {
+    private static final DenseIntDoubleVector[] predictData =
+            new DenseIntDoubleVector[] {
                 Vectors.dense(10.0, 10.0),
                 Vectors.dense(10.3, 10.0),
                 Vectors.dense(10.0, 10.3),
@@ -104,7 +104,7 @@ public class OnlineKMeansTest extends TestLogger {
                 Vectors.dense(-10.3, 10.0),
                 Vectors.dense(-10.0, 10.3)
             };
-    private static final List<Set<DenseVector>> expectedGroups1 =
+    private static final List<Set<DenseIntDoubleVector>> expectedGroups1 =
             Arrays.asList(
                     new HashSet<>(
                             Arrays.asList(
@@ -116,7 +116,7 @@ public class OnlineKMeansTest extends TestLogger {
                                     Vectors.dense(-10.0, 10.0),
                                     Vectors.dense(-10.3, 10.0),
                                     Vectors.dense(-10.0, 10.3))));
-    private static final List<Set<DenseVector>> expectedGroups2 =
+    private static final List<Set<DenseIntDoubleVector>> expectedGroups2 =
             Collections.singletonList(
                     new HashSet<>(
                             Arrays.asList(
@@ -133,8 +133,8 @@ public class OnlineKMeansTest extends TestLogger {
     private int currentModelDataVersion;
 
-    private InMemorySourceFunction<DenseVector> trainSource;
-    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySourceFunction<DenseIntDoubleVector> trainSource;
+    private InMemorySourceFunction<DenseIntDoubleVector> predictSource;
     private InMemorySinkFunction<Row> outputSink;
     private InMemorySinkFunction<KMeansModelData> modelDataSink;
@@ -179,10 +179,12 @@ public void before() throws Exception {
         offlineTrainTable = tEnv.fromDataStream(env.fromElements(trainData1)).as("features");
         onlineTrainTable =
-                tEnv.fromDataStream(env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE))
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseIntDoubleVectorTypeInfo.INSTANCE))
                         .as("features");
         onlinePredictTable =
-                tEnv.fromDataStream(env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE))
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseIntDoubleVectorTypeInfo.INSTANCE))
                         .as("features");
     }
@@ -247,11 +249,13 @@ private void waitModelDataUpdate(JobID jobID) throws InterruptedException {
      * @param predictionCol Name of the column in the table that contains the prediction result
      */
     private void predictAndAssert(
-            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            List<Set<DenseIntDoubleVector>> expectedGroups,
+            String featuresCol,
+            String predictionCol)
             throws Exception {
         predictSource.addAll(OnlineKMeansTest.predictData);
         List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
-        List<Set<DenseVector>> actualGroups =
+        List<Set<DenseIntDoubleVector>> actualGroups =
                 groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
         Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
     }
@@ -324,13 +328,13 @@ public void testInputTypeConversion() throws Exception {
         onlinePredictTable = TestUtils.convertDataTypesToSparseInt(tEnv, onlinePredictTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(offlineTrainTable));
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(onlineTrainTable));
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(onlinePredictTable));
 
         OnlineKMeans onlineKMeans =
@@ -406,7 +410,7 @@ public void testDecayFactor() throws Exception {
         KMeansModelData expectedModelData =
                 new KMeansModelData(
-                        new DenseVector[] {
+                        new DenseIntDoubleVector[] {
                             Vectors.dense(-10.2, -200.2 / 3), Vectors.dense(10.1, 200.3 / 3)
                         },
                         Vectors.dense(4.5, 4.5));
@@ -502,7 +506,9 @@ public void testGetModelData() throws Exception {
         KMeansModelData expectedModelData =
                 new KMeansModelData(
-                        new DenseVector[] {Vectors.dense(-10.2, 0.2), Vectors.dense(10.1, 0.1)},
+                        new DenseIntDoubleVector[] {
+                            Vectors.dense(-10.2, 0.2), Vectors.dense(10.1, 0.1)
+                        },
                         Vectors.dense(3, 3));
 
         assertArrayEquals(expectedModelData.weights.values, actualModelData.weights.values, 1e-5);
@@ -520,12 +526,14 @@ public void testSetModelData() throws Exception {
         KMeansModelData modelData1 =
                 new KMeansModelData(
-                        new DenseVector[] {Vectors.dense(10.1, 0.1), Vectors.dense(-10.2, 0.2)},
+                        new DenseIntDoubleVector[] {
+                            Vectors.dense(10.1, 0.1), Vectors.dense(-10.2, 0.2)
+                        },
                         Vectors.dense(0.0, 0.0));
         KMeansModelData modelData2 =
                 new KMeansModelData(
-                        new DenseVector[] {
+                        new DenseIntDoubleVector[] {
                             Vectors.dense(10.1, 100.1), Vectors.dense(-10.2, -100.2)
                         },
                         Vectors.dense(0.0, 0.0));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java
index 22ce2deba..19d403548 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.common.lossfunc;
 
 import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 
 import org.junit.Test;
@@ -31,8 +31,8 @@ public class BinaryLogisticLossTest {
     private static final LabeledPointWithWeight dataPoint =
             new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0);
-    private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
-    private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+    private static final DenseIntDoubleVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+    private static final DenseIntDoubleVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
     private static final double TOLERANCE = 1e-7;
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
index 1cd165ecf..a759962c7 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.common.lossfunc;
 
 import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 
 import org.junit.Test;
@@ -33,8 +33,8 @@ public class HingeLossTest {
             new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, -1.0), 1.0, 2.0);
     private static final LabeledPointWithWeight dataPoint2 =
             new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, 1.0), 1.0, 2.0);
-    private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
-    private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+    private static final DenseIntDoubleVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+    private static final DenseIntDoubleVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
     private static final double TOLERANCE = 1e-7;
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java
index ee2d03021..e420662db 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.common.lossfunc;
 
 import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 
 import org.junit.Test;
@@ -31,8 +31,8 @@ public class LeastSquareLossTest {
     private static final LabeledPointWithWeight dataPoint =
             new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0);
-    private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
-    private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+    private static final DenseIntDoubleVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+    private static final DenseIntDoubleVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
     private static final double TOLERANCE = 1e-7;
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java
index 83d688337..2ce653110 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java
@@ -18,7 +18,7 @@
 package org.apache.flink.ml.common.optimizer;
 
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 
 import org.apache.commons.lang3.RandomUtils;
 import org.junit.Test;
@@ -29,7 +29,8 @@ public class RegularizationUtilsTest {
     private static final double learningRate = 0.1;
     private static final double TOLERANCE = 1e-7;
 
-    private static final DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0});
+    private static final DenseIntDoubleVector coefficient =
+            new DenseIntDoubleVector(new double[] {1.0, -2.0, 0});
 
     @Test
     public void testRegularization() {
@@ -40,7 +41,7 @@ public class RegularizationUtilsTest {
     }
 
     private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) {
-        DenseVector clonedCoefficient = coefficient.clone();
+        DenseIntDoubleVector clonedCoefficient = coefficient.clone();
         RegularizationUtils.regularize(clonedCoefficient, reg, elasticNet, learningRate);
         assertArrayEquals(expectedCoefficient, clonedCoefficient.values, TOLERANCE);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
index 027be970c..521f67b21 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
@@ -18,8 +18,8 @@
 package org.apache.flink.ml.common.util;
 
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 
 import org.junit.Test;
@@ -33,13 +33,13 @@ public class VectorUtilsTest {
 
     @Test
     public void testSelectByIndices() {
-        DenseVector denseVector = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0);
+        DenseIntDoubleVector denseVector = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0);
         assertArrayEquals(
                 Vectors.dense(2.0, 4.0).toArray(),
                 VectorUtils.selectByIndices(denseVector, new int[] {1, 3}).toArray(),
                 EPS);
 
-        SparseVector sparseVector =
+        SparseIntDoubleVector sparseVector =
                 Vectors.sparse(5, new int[] {1, 2, 3}, new double[] {2.0, 3.0, 4.0});
         assertArrayEquals(
                 Vectors.sparse(3, new int[] {1, 2}, new double[] {2.0, 4.0}).toArray(),
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
index 0c146a3d2..7a2f90a10 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/evaluation/BinaryClassificationEvaluatorTest.java
@@ -20,7 +20,7 @@
 
 import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluator;
 import org.apache.flink.ml.evaluation.binaryclassification.BinaryClassificationEvaluatorParams;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -191,7 +191,7 @@ public void testEvaluate() {
     public void testInputTypeConversion() {
         inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
         assertArrayEquals(
-                new Class[] {Integer.class, SparseVector.class},
+                new Class[] {Integer.class, SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(inputDataTable));
 
         BinaryClassificationEvaluator eval =
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
index eb06b0c6e..17a502972 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/BinarizerTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.binarizer.Binarizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -65,11 +65,11 @@ public class BinarizerTest extends AbstractTestBase {
 
     private static final Double[] EXPECTED_VALUE_OUTPUT = new Double[] {0.0, 1.0, 1.0};
 
-    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT =
             Arrays.asList(
                     Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0), Vectors.dense(1.0, 1.0));
 
-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
                     Vectors.sparse(17, new int[] {9}, new double[] {1.0}),
                     Vectors.sparse(17, new int[] {0, 2}, new double[] {1.0, 1.0}),
@@ -90,8 +90,8 @@ private void verifyOutputResult(Table output, String[] outputCols) throws Exception {
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
 
         List<Double> doubleValues = new ArrayList<>(results.size());
-        List<Vector> sparseVectorValues = new ArrayList<>(results.size());
-        List<Vector> denseVectorValues = new ArrayList<>(results.size());
+        List<IntDoubleVector> sparseVectorValues = new ArrayList<>(results.size());
+        List<IntDoubleVector> denseVectorValues = new ArrayList<>(results.size());
         for (Row row : results) {
             doubleValues.add(row.getFieldAs(outputCols[0]));
             denseVectorValues.add(row.getFieldAs(outputCols[1]));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
index 32d58ccc4..c751e604b 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
@@ -21,7 +21,7 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
 import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -68,7 +68,7 @@ public class CountVectorizerTest extends AbstractTestBase {
                             Row.of((Object) new String[] {"e", "f"}),
                             Row.of((Object) new String[] {"a", "c", "a"})));
 
-    private static final List<SparseVector> EXPECTED_OUTPUT =
+    private static final List<SparseIntDoubleVector> EXPECTED_OUTPUT =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.sparse(
@@ -101,15 +101,15 @@ public void before() {
     }
 
     private static void verifyPredictionResult(
-            Table output, String outputCol, List<SparseVector> expected) throws Exception {
+            Table output, String outputCol, List<SparseIntDoubleVector> expected) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
-        DataStream<SparseVector> stream =
+        DataStream<SparseIntDoubleVector> stream =
                 tEnv.toDataStream(output)
                         .map(
-                                (MapFunction<Row, SparseVector>)
-                                        row -> (SparseVector) row.getField(outputCol));
+                                (MapFunction<Row, SparseIntDoubleVector>)
+                                        row -> (SparseIntDoubleVector) row.getField(outputCol));
-        List<SparseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        List<SparseIntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expected, result, TestUtils::compare);
     }
@@ -244,7 +244,7 @@ public void testFitOnEmptyData() {
 
     @Test
     public void testMinMaxDF() throws Exception {
-        List<SparseVector> expectedOutput =
+        List<SparseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.sparse(
@@ -280,7 +280,7 @@ public void testMinTF() throws Exception {
-        List<SparseVector> expectedOutput =
+        List<SparseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.sparse(
@@ -305,7 +305,7 @@ public void testBinary() throws Exception {
-        List<SparseVector> expectedOutput =
+        List<SparseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.sparse(
@@ -336,7 +336,7 @@ public void testBinary() throws Exception {
     @Test
     public void testVocabularySize() throws Exception {
-        List<SparseVector> expectedOutput =
+        List<SparseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.sparse(
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
index 36baea68d..7c1f71d0b 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/DCTTest.java
@@ -19,8 +19,8 @@
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.dct.DCT;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -48,7 +48,7 @@ public class DCTTest extends AbstractTestBase {
     private StreamExecutionEnvironment env;
     private StreamTableEnvironment tEnv;
 
-    private static final List<Vector> inputData =
+    private static final List<IntDoubleVector> inputData =
             Arrays.asList(Vectors.dense(1.0, 1.0, 1.0, 1.0), Vectors.dense(1.0, 0.0, -1.0, 0.0));
 
     private static final List<Row> expectedForwardOutputData =
@@ -126,7 +126,8 @@ public void testTransformInverse() {
     public void testInputTypeConversion() throws Exception {
         inputTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(inputTable));
+                new Class[] {SparseIntDoubleVector.class},
+                TestUtils.getColumnDataTypes(inputTable));
 
         DCT dct = new DCT();
         Table outputTable = dct.transform(inputTable)[0];
@@ -159,21 +160,21 @@ private static void verifyTransformResult(
         actualOutputData.sort(
                 Comparator.comparingLong(
                         x ->
-                                ((Vector) Objects.requireNonNull(x.getField(inputCol)))
+                                ((IntDoubleVector) Objects.requireNonNull(x.getField(inputCol)))
                                         .toDense()
                                         .hashCode()));
 
         expectedOutputData.sort(
                 Comparator.comparingLong(
                         x ->
-                                ((Vector) Objects.requireNonNull(x.getField(0)))
+                                ((IntDoubleVector) Objects.requireNonNull(x.getField(0)))
                                         .toDense()
                                         .hashCode()));
 
         assertEquals(actualOutputData.size(), expectedOutputData.size());
         for (int i = 0; i < actualOutputData.size(); i++) {
-            Vector actualVector = actualOutputData.get(i).getFieldAs(outputCol);
-            Vector expectedVector = expectedOutputData.get(i).getFieldAs(1);
+            IntDoubleVector actualVector = actualOutputData.get(i).getFieldAs(outputCol);
+            IntDoubleVector expectedVector = expectedOutputData.get(i).getFieldAs(1);
             assertArrayEquals(expectedVector.toArray(), actualVector.toArray(), 1e-3);
         }
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
index 4b55f0dde..894b510e4 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ElementwiseProductTest.java
@@ -19,8 +19,8 @@
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.elementwiseproduct.ElementwiseProduct;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -90,28 +90,30 @@ private void verifyOutputResult(Table output, String outputCol, boolean isSparse)
         for (Row result : results) {
             if (result.getField(0) == (Object) 0) {
                 if (isSparse) {
-                    SparseVector sparseVector = (SparseVector) result.getField(outputCol);
-                    assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_1, sparseVector.size());
+                    SparseIntDoubleVector sparseVector =
+                            (SparseIntDoubleVector) result.getField(outputCol);
+                    assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_1, sparseVector.size().intValue());
                     assertArrayEquals(EXPECTED_OUTPUT_SPARSE_VEC_INDICES_1, sparseVector.indices);
                     assertArrayEquals(
                             EXPECTED_OUTPUT_SPARSE_VEC_VALUES_1, sparseVector.values, 1.0e-5);
                 } else {
                     assertArrayEquals(
                             EXPECTED_OUTPUT_DENSE_VEC_ARRAY_1,
-                            ((DenseVector) result.getField(outputCol)).values,
+                            ((DenseIntDoubleVector) result.getField(outputCol)).values,
                             1.0e-5);
                 }
             } else if (result.getField(0) == (Object) 1) {
                 if (isSparse) {
-                    SparseVector sparseVector = (SparseVector) result.getField(outputCol);
-                    assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_2, sparseVector.size());
+                    SparseIntDoubleVector sparseVector =
+                            (SparseIntDoubleVector) result.getField(outputCol);
+                    assertEquals(EXPECTED_OUTPUT_SPARSE_VEC_SIZE_2, sparseVector.size().intValue());
                     assertArrayEquals(EXPECTED_OUTPUT_SPARSE_VEC_INDICES_2, sparseVector.indices);
                     assertArrayEquals(
                             EXPECTED_OUTPUT_SPARSE_VEC_VALUES_2, sparseVector.values, 1.0e-5);
                 } else {
                     assertArrayEquals(
                             EXPECTED_OUTPUT_DENSE_VEC_ARRAY_2,
-                            ((DenseVector) result.getField(outputCol)).values,
+                            ((DenseIntDoubleVector) result.getField(outputCol)).values,
                             1.0e-5);
                 }
             } else if (result.getField(0) == (Object) 2) {
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
index fae16692d..c636486c6 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -48,9 +48,9 @@ public class FeatureHasherTest extends AbstractTestBase {
     private static final List<Row> INPUT_DATA =
             Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false));
 
-    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
+    private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_1 =
             Vectors.sparse(1000, new int[] {607, 635, 913}, new double[] {1.0, 1.0, 1.0});
-    private static final SparseVector EXPECTED_OUTPUT_DATA_2 =
+    private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_2 =
             Vectors.sparse(1000, new int[] {242, 869, 913}, new double[] {1.0, 1.0, 1.0});
 
     @Before
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
index 9bed73ec8..287b5ec7a 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
@@ -21,7 +21,7 @@
 import org.apache.flink.ml.feature.idf.IDF;
 import org.apache.flink.ml.feature.idf.IDFModel;
 import org.apache.flink.ml.feature.idf.IDFModelData;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -54,12 +54,12 @@ public class IDFTest extends AbstractTestBase {
     private StreamTableEnvironment tEnv;
     private Table inputTable;
 
-    private static final List<DenseVector> expectedOutput =
+    private static final List<DenseIntDoubleVector> expectedOutput =
             Arrays.asList(
                     Vectors.dense(0, 0, 0, 0.5753641),
                     Vectors.dense(0, 0, 1.3862943, 0.8630462),
                     Vectors.dense(0, 0, 0, 0));
-    private static final List<DenseVector> expectedOutputMinDocFreqAsTwo =
+    private static final List<DenseIntDoubleVector> expectedOutputMinDocFreqAsTwo =
             Arrays.asList(
                     Vectors.dense(0, 0, 0, 0.5753641),
                     Vectors.dense(0, 0, 0, 0.8630462),
@@ -71,7 +71,7 @@ public void before() {
         env = TestUtils.getExecutionEnvironment();
         tEnv = StreamTableEnvironment.create(env);
 
-        List<DenseVector> input =
+        List<DenseIntDoubleVector> input =
                 Arrays.asList(
                         Vectors.dense(0, 1, 0, 2),
                         Vectors.dense(0, 1, 2, 3),
@@ -81,11 +81,12 @@ public void before() {
 
     @SuppressWarnings("unchecked")
     private void verifyPredictionResult(
-            List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
+            List<DenseIntDoubleVector> expectedOutput, Table output, String predictionCol)
+            throws Exception {
         List<Row> collectedResult =
                 IteratorUtils.toList(
                         tEnv.toDataStream(output.select($(predictionCol))).executeAndCollect());
-        List<DenseVector> actualOutputs = new ArrayList<>(expectedOutput.size());
+        List<DenseIntDoubleVector> actualOutputs = new ArrayList<>(expectedOutput.size());
         collectedResult.forEach(x -> actualOutputs.add((x.getFieldAs(0))));
 
         actualOutputs.sort(TestUtils::compare);
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
index 0b0cc9957..59a672ac2 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/InteractionTest.java
@@ -19,9 +19,9 @@
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.interaction.Interaction;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -63,20 +63,20 @@ public class InteractionTest extends AbstractTestBase {
                             Vectors.sparse(17, new int[] {0, 2, 14}, new double[] {5.0, 4.0, 1.0})),
                     Row.of(3, null, null, null));
 
-    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT =
             Arrays.asList(
-                    new DenseVector(new double[] {3.0, 4.0, 6.0, 8.0}),
-                    new DenseVector(new double[] {12.0, 16.0, 20.0, 48.0, 64.0, 80.0}));
+                    new DenseIntDoubleVector(new double[] {3.0, 4.0, 6.0, 8.0}),
+                    new DenseIntDoubleVector(new double[] {12.0, 16.0, 20.0, 48.0, 64.0, 80.0}));
 
-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
-                    new SparseVector(
+                    new SparseIntDoubleVector(
                             68,
                             new int[] {0, 3, 9, 17, 20, 26, 34, 37, 43, 51, 54, 60},
                             new double[] {
                                 3.0, 6.0, 21.0, 4.0, 8.0, 28.0, 6.0, 12.0, 42.0, 8.0, 16.0, 56.0
                             }),
-                    new SparseVector(
+                    new SparseIntDoubleVector(
                             102,
                             new int[] {
                                 0, 2, 14, 17, 19, 31, 34, 36, 48, 51, 53, 65, 68, 70, 82, 85, 87, 99
@@ -94,14 +94,14 @@ public void before() {
         inputDataTable = tEnv.fromDataStream(dataStream).as("f0", "f1", "f2", "f3");
     }
 
-    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
-            throws Exception {
+    private void verifyOutputResult(
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Row> stream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
-        List<Vector> resultVec = new ArrayList<>(results.size());
+        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
         for (Row row : results) {
             if (row.getField(outputCol) != null) {
                 resultVec.add(row.getFieldAs(outputCol));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
index e6959baae..99786659c 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
@@ -22,7 +22,7 @@
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
 import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -138,7 +138,8 @@ private void verifyPredictionResult(
                 collectedResult,
                 (o1, o2) ->
                         TestUtils.compare(
-                                (DenseVector) o1.getField(0), (DenseVector) o2.getField(0)));
+                                (DenseIntDoubleVector) o1.getField(0),
+                                (DenseIntDoubleVector) o2.getField(0)));
     }
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
index 6b8edf05e..037115660 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
@@ -21,7 +21,8 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
 import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
@@ -94,14 +95,14 @@ public class MaxAbsScalerTest {
                             Row.of(Vectors.sparse(4, new int[] {}, new double[] {})),
                             Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 2.0}))));
 
-    private static final List<Vector> EXPECTED_DATA =
+    private static final List<IntDoubleVector> EXPECTED_DATA =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.dense(0.25, 0.1, 1.0),
                             Vectors.dense(0.5, 0.125, 0.5),
                             Vectors.dense(0.75, 0.225, 1.0)));
 
-    private static final List<Vector> EXPECTED_SPARSE_DATA =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_DATA =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.sparse(4, new int[] {0, 1}, new double[] {1.0, 0.5}),
@@ -124,7 +125,7 @@ public void before() {
     }
 
     private static void verifyPredictionResult(
-            Table output, String outputCol, List<Vector> expectedData) throws Exception {
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
@@ -134,7 +135,7 @@ private static void verifyPredictionResult(
                         (MapFunction<Row, Vector>) row -> row.getFieldAs(outputCol),
                         VectorTypeInfo.INSTANCE);
-        List<Vector> result = IteratorUtils.toList(stream.executeAndCollect());
+        List<IntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expectedData, result, TestUtils::compare);
     }
@@ -237,7 +238,8 @@ public void testGetModelData() throws Exception {
         DataStream<Row> output = tEnv.toDataStream(modelData);
         List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
         assertEquals(
-                new DenseVector(new double[] {200.0, 400.0, 0.0}), modelRows.get(0).getField(0));
+                new DenseIntDoubleVector(new double[] {200.0, 400.0, 0.0}),
+                modelRows.get(0).getField(0));
     }
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
index d92f062d0..442a4c3a7 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinHashLSHTest.java
@@ -21,9 +21,9 @@
 import org.apache.flink.ml.feature.lsh.MinHashLSH;
 import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
 import org.apache.flink.ml.feature.lsh.MinHashLSHModelData;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -90,10 +90,10 @@ private static List<Row> convertToOutputFormat(List<double[][]> arrays) {
         return arrays.stream()
                 .map(
                         array -> {
-                            DenseVector[] denseVectors =
+                            DenseIntDoubleVector[] denseVectors =
                                     Arrays.stream(array)
                                             .map(Vectors::dense)
-                                            .toArray(DenseVector[]::new);
+                                            .toArray(DenseIntDoubleVector[]::new);
                             return Row.of((Object) denseVectors);
                         })
                 .collect(Collectors.toList());
@@ -133,7 +133,7 @@ public void before() {
         Schema schema =
                 Schema.newBuilder()
                         .column("f0", DataTypes.INT())
-                        .column("f1", DataTypes.of(SparseVector.class))
+                        .column("f1", DataTypes.of(SparseIntDoubleVector.class))
                         .build();
         DataStream<Row> dataStream = env.fromCollection(inputRows);
@@ -144,8 +144,9 @@ public void testHashFunction() {
         MinHashLSHModelData lshModelData =
                 new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
-        Vector vec = Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
-        DenseVector[] result = lshModelData.hashFunction(vec);
+        IntDoubleVector vec =
+                Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
+        DenseIntDoubleVector[] result = lshModelData.hashFunction(vec);
         Assert.assertEquals(3, result.length);
         Assert.assertEquals(Vectors.dense(1.), result[0]);
         Assert.assertEquals(Vectors.dense(5.), result[1]);
@@ -158,9 +159,10 @@ public void testHashFunctionEqualWithSparseDenseVector() {
         // least non-zero index.
         MinHashLSHModelData lshModelData = MinHashLSHModelData.generateModelData(3, 1, 10, 2022L);
         new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
-        Vector vec = Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
-        DenseVector[] denseResults = lshModelData.hashFunction(vec.toDense());
-        DenseVector[] sparseResults = lshModelData.hashFunction(vec.toSparse());
+        IntDoubleVector vec =
+                Vectors.sparse(10, new int[] {2, 3, 5, 7}, new double[] {1., 1., 1., 1.});
+        DenseIntDoubleVector[] denseResults = lshModelData.hashFunction(vec.toDense());
+        DenseIntDoubleVector[] sparseResults = lshModelData.hashFunction(vec.toSparse());
         Assert.assertArrayEquals(denseResults, sparseResults);
     }
@@ -168,7 +170,7 @@ public void testHashFunctionWithEmptyVector() {
         MinHashLSHModelData lshModelData =
                 new MinHashLSHModelData(3, 1, new int[] {0, 1, 3}, new int[] {1, 2, 0});
-        Vector vec = Vectors.sparse(10, new int[] {}, new double[] {});
+        IntDoubleVector vec = Vectors.sparse(10, new int[] {}, new double[] {});
         lshModelData.hashFunction(vec);
     }
@@ -375,7 +377,7 @@ public void testApproxNearestNeighbors() {
         MinHashLSHModel lshModel = lsh.fit(inputTable);
 
         List<Row> expected = Arrays.asList(Row.of(0, .75), Row.of(1, .75));
-        Vector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1.0, 1.0});
+        IntDoubleVector key = Vectors.sparse(6, new int[] {1, 3}, new double[] {1.0, 1.0});
         Table output =
                 lshModel.approxNearestNeighbors(inputTable, key, 2).select($("id"), $("distCol"));
         List<Row> results = IteratorUtils.toList(output.execute().collect());
@@ -408,7 +410,7 @@ public void testApproxSimilarityJoin() {
         Schema schema =
                 Schema.newBuilder()
                         .column("f0", DataTypes.INT())
-                        .column("f1", DataTypes.of(SparseVector.class))
+                        .column("f1", DataTypes.of(SparseIntDoubleVector.class))
                         .build();
         Table dataB = tEnv.fromDataStream(env.fromCollection(inputRowsB), schema).as("id", "vec");
@@ -426,9 +428,9 @@ public void testApproxSimilarityJoin() {
                         .thenComparingDouble(r -> r.getFieldAs(2)));
     }
 
-    private static class DenseVectorArrayComparator implements Comparator<DenseVector[]> {
+    private static class DenseVectorArrayComparator implements Comparator<DenseIntDoubleVector[]> {
         @Override
-        public int compare(DenseVector[] o1, DenseVector[] o2) {
+        public int compare(DenseIntDoubleVector[] o1, DenseIntDoubleVector[] o2) {
             if (o1.length != o2.length) {
                 return o1.length - o2.length;
             }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
index 324516359..f75b5f8a7 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -21,8 +21,8 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
 import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -70,7 +70,7 @@ public class MinMaxScalerTest extends AbstractTestBase {
                             Row.of(Vectors.dense(50.0, 40.0)),
                             Row.of(Vectors.dense(100.0, 50.0))));
     private static final double EPS = 1.0e-5;
-    private static final List<DenseVector> EXPECTED_DATA =
+    private static final List<DenseIntDoubleVector> EXPECTED_DATA =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.dense(0.25, 0.1),
@@ -86,15 +86,15 @@ public void before() {
     }
 
     private static void verifyPredictionResult(
-            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+            Table output, String outputCol, List<DenseIntDoubleVector> expected) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
-        DataStream<DenseVector> stream =
+        DataStream<DenseIntDoubleVector> stream =
                 tEnv.toDataStream(output)
                         .map(
-                                (MapFunction<Row, DenseVector>)
-                                        row -> (DenseVector) row.getField(outputCol));
+                                (MapFunction<Row, DenseIntDoubleVector>)
+                                        row -> (DenseIntDoubleVector) row.getField(outputCol));
-        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        List<DenseIntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expected, result, TestUtils::compare);
     }
@@ -158,9 +158,10 @@ public void testInputTypeConversion() throws Exception {
         trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable);
         predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
+                new Class[] {SparseIntDoubleVector.class},
+                TestUtils.getColumnDataTypes(trainDataTable));
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(predictDataTable));
 
         MinMaxScaler minMaxScaler = new MinMaxScaler();
@@ -202,8 +203,11 @@ public void testGetModelData() throws Exception {
                 modelData.getResolvedSchema().getColumnNames());
         DataStream<Row> output = tEnv.toDataStream(modelData);
         List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
-        assertEquals(new DenseVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
-        assertEquals(new DenseVector(new double[] {200.0, 400.0}), modelRows.get(0).getField(1));
+        assertEquals(
+                new DenseIntDoubleVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
+        assertEquals(
+                new DenseIntDoubleVector(new double[] {200.0, 400.0}),
+                modelRows.get(0).getField(1));
     }
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
index 13e28c68f..7dfd1fa44 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/NormalizerTest.java
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.feature;
 
 import org.apache.flink.ml.feature.normalizer.Normalizer;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -56,7 +56,7 @@ public class NormalizerTest extends AbstractTestBase {
                     Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1),
                     Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {0.1, 0.2, 0.3})));
 
-    private static final List<Vector> EXPECTED_DENSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_DENSE_OUTPUT =
             Arrays.asList(
                     Vectors.dense(
                             0.17386300895299714,
@@ -73,7 +73,7 @@ public class NormalizerTest extends AbstractTestBase {
                             0.4608889965995767,
                             0.3705186051094636));
 
-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
                     Vectors.sparse(
                             5,
@@ -96,14 +96,14 @@ public void before() {
         inputDataTable = tEnv.fromDataStream(dataStream).as("denseVec", "sparseVec");
     }
 
-    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
-            throws Exception {
+    private void verifyOutputResult(
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Row> stream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
-        List<Vector> resultVec = new ArrayList<>(results.size());
+        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
         for (Row row : results) {
             if (row.getField(outputCol) != null) {
                 resultVec.add(row.getFieldAs(outputCol));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
index e5a6726a8..c588d0e13 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java
@@ -23,7 +23,7 @@
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel;
 import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -59,7 +59,7 @@ public class OneHotEncoderTest extends AbstractTestBase {
     private StreamTableEnvironment tEnv;
     private Table trainTable;
     private Table predictTable;
-    private Map<Double, Vector>[] expectedOutput;
+    private Map<Double, IntDoubleVector>[] expectedOutput;
     private OneHotEncoder estimator;
 
     @Before
@@ -77,7 +77,7 @@ public void before() {
         expectedOutput =
                 new HashMap[] {
-                    new HashMap<Double, Vector>() {
+                    new HashMap<Double, IntDoubleVector>() {
                         {
                             put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
                             put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
@@ -99,9 +99,9 @@ public void before() {
      * @param outputCols Name of the output columns containing one-hot encoding result
      * @return An array of map containing the collected results
      */
-    private static Map<Double, Vector>[] executeAndCollect(
+    private static Map<Double, IntDoubleVector>[] executeAndCollect(
             Table table, String[] inputCols, String[] outputCols) {
-        Map<Double, Vector>[] maps = new HashMap[inputCols.length];
+        Map<Double, IntDoubleVector>[] maps = new HashMap[inputCols.length];
         for (int i = 0; i < inputCols.length; i++) {
             maps[i] = new HashMap<>();
         }
@@ -110,7 +110,7 @@ private static Map<Double, Vector>[] executeAndCollect(
             for (int i = 0; i < inputCols.length; i++) {
                 maps[i].put(
                         ((Number) row.getField(inputCols[i])).doubleValue(),
-                        (Vector) row.getField(outputCols[i]));
+                        (IntDoubleVector) row.getField(outputCols[i]));
             }
         }
         return maps;
@@ -143,7 +143,7 @@ public void testParam() {
     public void testFitAndPredict() {
         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -158,7 +158,7 @@ public void testInputTypeConversion() throws Exception {
         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -169,7 +169,7 @@ public void testDropLast() {
         expectedOutput =
                 new HashMap[] {
-                    new HashMap<Double, Vector>() {
+                    new HashMap<Double, IntDoubleVector>() {
                         {
                             put(0.0, Vectors.sparse(3, new int[] {0}, new double[] {1.0}));
                             put(1.0, Vectors.sparse(3, new int[] {1}, new double[] {1.0}));
@@ -180,7 +180,7 @@ public void testDropLast() {
         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -196,7 +196,7 @@ public void testInputDataType() {
         expectedOutput =
                 new HashMap[] {
-                    new HashMap<Double, Vector>() {
+                    new HashMap<Double, IntDoubleVector>() {
                         {
                             put(0.0, Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
                             put(1.0, Vectors.sparse(2, new int[] {1}, new double[] {1.0}));
@@ -207,7 +207,7 @@ public void testInputDataType() {
         OneHotEncoderModel model = estimator.fit(trainTable);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -276,7 +276,7 @@ public void testSaveLoad() throws Exception {
                         tempFolder.newFolder().getAbsolutePath(),
                         OneHotEncoderModel::load);
         Table outputTable = model.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
@@ -301,7 +301,7 @@ public void testSetModelData() {
         ParamUtils.updateExistingParams(modelB, modelA.getParamMap());
 
         Table outputTable = modelB.transform(predictTable)[0];
-        Map<Double, Vector>[] actualOutput =
+        Map<Double, IntDoubleVector>[] actualOutput =
                 executeAndCollect(outputTable, modelB.getInputCols(), modelB.getOutputCols());
         assertArrayEquals(expectedOutput, actualOutput);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
index 303713cd1..fdb1515f8 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OnlineStandardScalerTest.java
@@ -32,10 +32,10 @@
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScaler;
 import org.apache.flink.ml.feature.standardscaler.OnlineStandardScalerModel;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -115,7 +115,10 @@ public void before() {
                                 inputStream,
                                 Schema.newBuilder()
                                         .column("f0", DataTypes.BIGINT())
-                                        .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE))
+                                        .column(
+                                                "f1",
+                                                DataTypes.RAW(
+                                                        DenseIntDoubleVectorTypeInfo.INSTANCE))
                                         .build())
                         .as("id", "input");
@@ -136,7 +139,7 @@ public Row map(Row value) throws Exception {
                                 },
                                 new RowTypeInfo(
                                         new TypeInformation<?>[] {
-                                            Types.LONG, DenseVectorTypeInfo.INSTANCE
+                                            Types.LONG, DenseIntDoubleVectorTypeInfo.INSTANCE
                                         },
                                         new String[] {"id", "input"}))
"input"})) .setParallelism(1); @@ -155,7 +158,10 @@ public Row map(Row value) throws Exception { inputStreamWithEventTime, Schema.newBuilder() .column("f0", DataTypes.BIGINT()) - .column("f1", DataTypes.RAW(DenseVectorTypeInfo.INSTANCE)) + .column( + "f1", + DataTypes.RAW( + DenseIntDoubleVectorTypeInfo.INSTANCE)) .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") .watermark("rowtime", "SOURCE_WATERMARK()") .build()) @@ -295,7 +301,7 @@ public void testFitAndPredictWithGlobalWindow() throws Exception { .as("input"); // Tests withMean option. - List expectedResWithMean = + List expectedResWithMean = Arrays.asList( Vectors.dense(-2.8, 8, 1), Vectors.dense(1.1, -6, 1), @@ -304,7 +310,7 @@ public void testFitAndPredictWithGlobalWindow() throws Exception { verifyPredictionResult(expectedResWithMean, output, standardScaler.getOutputCol()); // Tests withStd option. - List expectedResWithStd = + List expectedResWithStd = Arrays.asList( Vectors.dense(-1.0231819, 1.2480754, 0.5773502), Vectors.dense(0.5729819, -0.6933752, 0.5773503), @@ -313,7 +319,7 @@ public void testFitAndPredictWithGlobalWindow() throws Exception { verifyPredictionResult(expectedResWithStd, output, standardScaler.getOutputCol()); // Tests withMean, withStd Option. - List expectedResWithMeanAndStd = + List expectedResWithMeanAndStd = Arrays.asList( Vectors.dense(-1.1459637, 1.1094004, 0.5773503), Vectors.dense(0.45020003, -0.8320503, 0.5773503), @@ -423,13 +429,14 @@ private void verifyUsedModelVersion( @SuppressWarnings("unchecked") private void verifyPredictionResult( - List expectedOutput, Table output, String predictionCol) throws Exception { + List expectedOutput, Table output, String predictionCol) + throws Exception { List collectedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); - List predictions = new ArrayList<>(collectedResult.size()); + List predictions = new ArrayList<>(collectedResult.size()); for (Row r : collectedResult) { - Vector vec = (Vector) r.getField(predictionCol); + IntDoubleVector vec = (IntDoubleVector) r.getField(predictionCol); predictions.add(vec.toDense()); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java index 5454eabff..9b06c0860 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/PolynomialExpansionTest.java @@ -19,7 +19,7 @@ package org.apache.flink.ml.feature; import org.apache.flink.ml.feature.polynomialexpansion.PolynomialExpansion; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -56,19 +56,19 @@ public class PolynomialExpansionTest extends AbstractTestBase { Vectors.dense(2.0, 3.0), Vectors.sparse(5, new int[] {1, 4}, new double[] {2.0, 1.0}))); - private static final List EXPECTED_DENSE_OUTPUT = + private static final List EXPECTED_DENSE_OUTPUT = Arrays.asList( Vectors.dense(1.0, 1.0, 2.0, 2.0, 4.0, 3.0, 3.0, 6.0, 9.0), Vectors.dense(2.0, 4.0, 3.0, 6.0, 9.0)); - private static final List EXPECTED_DENSE_OUTPUT_WITH_DEGREE_3 = + private static final List EXPECTED_DENSE_OUTPUT_WITH_DEGREE_3 = Arrays.asList( Vectors.dense( 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 4.0, 4.0, 8.0, 3.0, 3.0, 3.0, 6.0, 6.0, 12.0, 9.0, 9.0, 18.0, 27.0), 
-    private static final List<Vector> EXPECTED_SPARSE_OUTPUT =
+    private static final List<IntDoubleVector> EXPECTED_SPARSE_OUTPUT =
             Arrays.asList(
                     Vectors.sparse(
                             55,
@@ -87,14 +87,14 @@ public void before() {
         inputDataTable = tEnv.fromDataStream(dataStream).as("denseVec", "sparseVec");
     }
 
-    private void verifyOutputResult(Table output, String outputCol, List<Vector> expectedData)
-            throws Exception {
+    private void verifyOutputResult(
+            Table output, String outputCol, List<IntDoubleVector> expectedData) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
         DataStream<Row> stream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(stream.executeAndCollect());
-        List<Vector> resultVec = new ArrayList<>(results.size());
+        List<IntDoubleVector> resultVec = new ArrayList<>(results.size());
         for (Row row : results) {
             if (row.getField(outputCol) != null) {
                 resultVec.add(row.getFieldAs(outputCol));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
index e8179024a..a97f57a84 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
@@ -21,8 +21,8 @@
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.ml.feature.robustscaler.RobustScaler;
 import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
@@ -79,7 +79,7 @@ public class RobustScalerTest extends AbstractTestBase {
                             Row.of(Vectors.dense(99.0, -99.0))));
     private static final double EPS = 1.0e-5;
 
-    private static final List<DenseVector> EXPECTED_OUTPUT =
+    private static final List<DenseIntDoubleVector> EXPECTED_OUTPUT =
             new ArrayList<>(
                     Arrays.asList(
                             Vectors.dense(0.75, -0.75),
@@ -95,15 +95,15 @@ public void before() {
     }
 
     private static void verifyPredictionResult(
-            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+            Table output, String outputCol, List<DenseIntDoubleVector> expected) throws Exception {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
-        DataStream<DenseVector> stream =
+        DataStream<DenseIntDoubleVector> stream =
                 tEnv.toDataStream(output)
                         .map(
-                                (MapFunction<Row, DenseVector>)
-                                        row -> (DenseVector) row.getField(outputCol));
+                                (MapFunction<Row, DenseIntDoubleVector>)
+                                        row -> (DenseIntDoubleVector) row.getField(outputCol));
-        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        List<DenseIntDoubleVector> result = IteratorUtils.toList(stream.executeAndCollect());
         compareResultCollections(expected, result, TestUtils::compare);
     }
@@ -162,9 +162,10 @@ public void testInputTypeConversion() throws Exception {
                         tEnv, trainDataTable.select(Expressions.$("input")));
         predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
         assertArrayEquals(
-                new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
+                new Class[] {SparseIntDoubleVector.class},
+                TestUtils.getColumnDataTypes(trainDataTable));
         assertArrayEquals(
-                new Class[] {SparseVector.class},
+                new Class[] {SparseIntDoubleVector.class},
                 TestUtils.getColumnDataTypes(predictDataTable));
 
         RobustScaler robustScaler = new RobustScaler();
@@ -217,7 +218,7 @@ public void testWithCentering() throws Exception {
         RobustScaler robustScaler = new RobustScaler().setWithCentering(true);
         RobustScalerModel model = robustScaler.fit(trainDataTable);
         Table output = model.transform(predictDataTable)[0];
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(-0.25, 0.25),
@@ -231,7 +232,7 @@ public void testWithoutScaling() throws Exception {
         RobustScaler robustScaler = new RobustScaler().setWithCentering(true).setWithScaling(false);
         RobustScalerModel model = robustScaler.fit(trainDataTable);
         Table output = model.transform(predictDataTable)[0];
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(-1, 1),
@@ -273,7 +274,7 @@ public void testZeroRange() throws Exception {
                         Row.of(2, Vectors.dense(1.0, 1.0)),
                         Row.of(3, Vectors.dense(1.0, 1.0)),
                         Row.of(4, Vectors.dense(4.0, 4.0))));
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(0.0, -0.0),
@@ -297,7 +298,7 @@ public void testNaNData() throws Exception {
                         Row.of(3, Vectors.dense(2.0, -2.0)),
                        Row.of(4, Vectors.dense(3.0, -3.0)),
                         Row.of(5, Vectors.dense(4.0, -4.0))));
-        List<DenseVector> expectedOutput =
+        List<DenseIntDoubleVector> expectedOutput =
                 new ArrayList<>(
                         Arrays.asList(
                                 Vectors.dense(0.0, Double.NaN),
@@ -322,11 +323,11 @@ public void testGetModelData() throws Exception {
                 Arrays.asList("medians", "ranges"), modelData.getResolvedSchema().getColumnNames());
         DataStream<Row> output = tEnv.toDataStream(modelData);
         List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
-        DenseVector medians = (DenseVector) modelRows.get(0).getField(0);
-        DenseVector ranges = (DenseVector) modelRows.get(0).getField(1);
+        DenseIntDoubleVector medians = (DenseIntDoubleVector) modelRows.get(0).getField(0);
+        DenseIntDoubleVector ranges = (DenseIntDoubleVector) modelRows.get(0).getField(1);
 
-        DenseVector expectedMedians = Vectors.dense(4.0, -4.0);
-        DenseVector expectedRanges = Vectors.dense(4.0, 4.0);
+        DenseIntDoubleVector expectedMedians = Vectors.dense(4.0, -4.0);
+        DenseIntDoubleVector expectedRanges = Vectors.dense(4.0, 4.0);
         assertEquals(expectedMedians, medians);
         assertEquals(expectedRanges, ranges);
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
index 7d08c70ce..5c18ee727 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/StandardScalerTest.java
@@ -21,9 +21,9 @@
 import org.apache.flink.ml.feature.standardscaler.StandardScaler;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
 import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
-import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.DenseIntDoubleVector;
+import org.apache.flink.ml.linalg.IntDoubleVector;
+import org.apache.flink.ml.linalg.SparseIntDoubleVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
@@ -61,19 +61,19 @@ public class StandardScalerTest extends AbstractTestBase {
                     Row.of(Vectors.dense(1.4, -5, 1)),
                     Row.of(Vectors.dense(2, -1, -2)));
 
-    private final List<DenseVector> expectedResWithMean =
+    private final List<DenseIntDoubleVector> expectedResWithMean =
             Arrays.asList(
                     Vectors.dense(-2.8, 8, 1),
Vectors.dense(1.1, -6, 1), Vectors.dense(1.7, -2, -2)); - private final List expectedResWithStd = + private final List expectedResWithStd = Arrays.asList( Vectors.dense(-1.0231819, 1.2480754, 0.5773502), Vectors.dense(0.5729819, -0.6933752, 0.5773503), Vectors.dense(0.8185455, -0.1386750, -1.1547005)); - private final List expectedResWithMeanAndStd = + private final List expectedResWithMeanAndStd = Arrays.asList( Vectors.dense(-1.1459637, 1.1094004, 0.5773503), Vectors.dense(0.45020003, -0.8320503, 0.5773503), @@ -92,13 +92,14 @@ public void before() { @SuppressWarnings("unchecked") private void verifyPredictionResult( - List expectedOutput, Table output, String predictionCol) throws Exception { + List expectedOutput, Table output, String predictionCol) + throws Exception { List collectedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); - List predictions = new ArrayList<>(collectedResult.size()); + List predictions = new ArrayList<>(collectedResult.size()); for (Row r : collectedResult) { - Vector vec = (Vector) r.getField(predictionCol); + IntDoubleVector vec = (IntDoubleVector) r.getField(predictionCol); predictions.add(vec.toDense()); } @@ -179,7 +180,8 @@ public void testFitAndPredictWithMeanAndStd() throws Exception { public void testInputTypeConversion() throws Exception { denseTable = TestUtils.convertDataTypesToSparseInt(tEnv, denseTable); assertArrayEquals( - new Class[] {SparseVector.class}, TestUtils.getColumnDataTypes(denseTable)); + new Class[] {SparseIntDoubleVector.class}, + TestUtils.getColumnDataTypes(denseTable)); StandardScaler standardScaler = new StandardScaler().setWithMean(true); Table output = standardScaler.fit(denseTable).transform(denseTable)[0]; @@ -257,7 +259,7 @@ public void testSparseInput() throws Exception { Row.of(Vectors.sparse(3, new int[] {0, 2}, new double[] {1.4, 1}))); Table sparseTable = tEnv.fromDataStream(env.fromCollection(sparseInput)).as("input"); - final List expectedResWithStd = + final List expectedResWithStd = Arrays.asList( Vectors.dense(-1.2653836, 1, 0), Vectors.dense(0, 2, -1.30930734), diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java index f722a78a3..af4551d55 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector; import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; import org.apache.flink.ml.util.TestUtils; @@ -512,11 +512,13 @@ private void verifyOutputResult(Table table, int... 
expectedIndices) throws Exce CloseableIterator rowIterator = tEnv.toDataStream(table).executeAndCollect(); while (rowIterator.hasNext()) { Row row = rowIterator.next(); - assertEquals(expectedIndices.length, ((Vector) row.getField("output")).size()); + assertEquals( + expectedIndices.length, + ((IntDoubleVector) (row.getField("output"))).size().intValue()); for (int i = 0; i < expectedIndices.length; i++) { assertEquals( - ((Vector) row.getField("features")).get(expectedIndices[i]), - ((Vector) row.getField("output")).get(i), + ((IntDoubleVector) row.getField("features")).get(expectedIndices[i]), + ((IntDoubleVector) row.getField("output")).get(i), EPS); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java index 43a893ce5..35aca0606 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector; import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -74,7 +75,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)), Row.of(Vectors.sparse(6, new int[] {0, 3, 4}, new double[] {0.1, 0.3, 0.5}))); - private static final List EXPECTED_OUTPUT = + private static final List EXPECTED_OUTPUT = Arrays.asList( Vectors.dense(1.0, 4.0, 6.0), Vectors.dense(0.1, 0.4, 0.6), @@ -98,7 +99,7 @@ public void before() { } private static void verifyPredictionResult( - Table output, String outputCol, List expected) throws Exception { + Table output, String outputCol, List expected) throws Exception { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); DataStream stream = @@ -106,7 +107,7 @@ private static void verifyPredictionResult( .map( (MapFunction) row -> (Vector) row.getField(outputCol), VectorTypeInfo.INSTANCE); - List result = IteratorUtils.toList(stream.executeAndCollect()); + List result = IteratorUtils.toList(stream.executeAndCollect()); compareResultCollections(expected, result, TestUtils::compare); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java index f70d95fc4..2280b957d 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java @@ -20,8 +20,8 @@ import org.apache.flink.ml.common.param.HasHandleInvalid; import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -111,15 +111,15 @@ public class VectorAssemblerTest 
extends AbstractTestBase { Vectors.sparse( 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0}))); - private static final SparseVector EXPECTED_OUTPUT_DATA_1 = + private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_1 = Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0}); - private static final DenseVector EXPECTED_OUTPUT_DATA_2 = + private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_2 = Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); - private static final DenseVector EXPECTED_OUTPUT_DATA_3 = + private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_3 = Vectors.dense(2.0, 2.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); - private static final DenseVector EXPECTED_OUTPUT_DATA_4 = + private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_4 = Vectors.dense(Double.NaN, Double.NaN, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); - private static final DenseVector EXPECTED_OUTPUT_DATA_5 = + private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_5 = Vectors.dense(2.0, 2.1, Double.NaN, 0.0, 1.0, 2.0, 3.0, 4.0); @Before @@ -373,7 +373,10 @@ public void testInputTypeConversion() throws Exception { inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable); assertArrayEquals( new Class[] { - Integer.class, SparseVector.class, Integer.class, SparseVector.class + Integer.class, + SparseIntDoubleVector.class, + Integer.class, + SparseIntDoubleVector.class }, TestUtils.getColumnDataTypes(inputDataTable)); @@ -409,7 +412,8 @@ public void testNumber2Vector() throws Exception { List results = IteratorUtils.toList(dataStream.executeAndCollect()); for (Row result : results) { if (result.getField(2) != null) { - assertEquals(result.getField(2), ((DenseVector) result.getField(4)).values[0]); + assertEquals( + result.getField(2), ((DenseIntDoubleVector) result.getField(4)).values[0]); } } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java index fb9838001..9bb31757f 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java @@ -19,8 +19,8 @@ package org.apache.flink.ml.feature; import org.apache.flink.ml.feature.vectorslicer.VectorSlicer; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; @@ -59,12 +59,12 @@ public class VectorSlicerTest extends AbstractTestBase { Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1), Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {0.1, 0.2, 0.3}))); - private static final DenseVector EXPECTED_OUTPUT_DATA_1 = Vectors.dense(2.1, 3.1, 2.3); - private static final DenseVector EXPECTED_OUTPUT_DATA_2 = Vectors.dense(2.3, 4.1, 1.3); + private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_1 = Vectors.dense(2.1, 3.1, 2.3); + private static final DenseIntDoubleVector EXPECTED_OUTPUT_DATA_2 = Vectors.dense(2.3, 4.1, 1.3); - private static final SparseVector EXPECTED_OUTPUT_DATA_3 = + private static final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_3 = Vectors.sparse(3, new int[] {1}, new double[] {0.1}); - private static final SparseVector EXPECTED_OUTPUT_DATA_4 = + private static 
final SparseIntDoubleVector EXPECTED_OUTPUT_DATA_4 = Vectors.sparse(3, new int[] {1, 2}, new double[] {0.1, 0.2}); @Before diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java index 1c051e17e..bd7c078ff 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java @@ -21,9 +21,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import org.apache.flink.ml.linalg.Vectors; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; import org.apache.flink.ml.regression.linearregression.LinearRegression; import org.apache.flink.ml.regression.linearregression.LinearRegressionModel; import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData; @@ -93,7 +93,9 @@ public void before() { trainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); } @@ -173,7 +175,7 @@ public void testFitAndPredict() throws Exception { public void testInputTypeConversion() throws Exception { trainDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, trainDataTable); assertArrayEquals( - new Class[] {SparseVector.class, Integer.class, Integer.class}, + new Class[] {SparseIntDoubleVector.class, Integer.class, Integer.class}, TestUtils.getColumnDataTypes(trainDataTable)); LinearRegression linearRegression = new LinearRegression().setWeightCol("weight"); @@ -246,7 +248,9 @@ public void testMoreSubtaskThanData() throws Exception { trainData, new RowTypeInfo( new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + DenseIntDoubleVectorTypeInfo.INSTANCE, + Types.DOUBLE, + Types.DOUBLE }, new String[] {"features", "label", "weight"}))); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java index 1b2270852..891a8f50e 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.stats; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.stats.anovatest.ANOVATest; import org.apache.flink.ml.util.TestUtils; @@ -328,13 +328,13 @@ private static void verifyTransformationResult(Table output, Row expected) throw Row result = results.get(0); assertEquals(3, result.getArity()); assertArrayEquals( - ((Vector) expected.getField(0)).toArray(), - ((Vector) result.getField(0)).toArray(), + ((IntDoubleVector) expected.getField(0)).toArray(), + ((IntDoubleVector) result.getField(0)).toArray(), EPS); assertArrayEquals((long[]) expected.getField(1), (long[]) result.getField(1)); assertArrayEquals( - ((Vector) expected.getField(2)).toArray(), - ((Vector) result.getField(2)).toArray(), + ((IntDoubleVector) 
expected.getField(2)).toArray(), + ((IntDoubleVector) result.getField(2)).toArray(), EPS); } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java index faf8fe576..a79e82463 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.stats; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.stats.fvaluetest.FValueTest; import org.apache.flink.ml.util.TestUtils; @@ -346,13 +346,13 @@ private static void verifyTransformationResult(Table output, Row expected) throw Row result = results.get(0); assertEquals(3, result.getArity()); assertArrayEquals( - ((Vector) expected.getField(0)).toArray(), - ((Vector) result.getField(0)).toArray(), + ((IntDoubleVector) expected.getField(0)).toArray(), + ((IntDoubleVector) result.getField(0)).toArray(), EPS); assertArrayEquals((long[]) expected.getField(1), (long[]) result.getField(1)); assertArrayEquals( - ((Vector) expected.getField(2)).toArray(), - ((Vector) result.getField(2)).toArray(), + ((IntDoubleVector) expected.getField(2)).toArray(), + ((IntDoubleVector) result.getField(2)).toArray(), EPS); } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java index 136231ac8..b7a205a60 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java @@ -18,18 +18,18 @@ package org.apache.flink.ml.common.feature; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** Utility class to represent a data point that contains features, label and weight. */ public class LabeledPointWithWeight { - private Vector features; + private IntDoubleVector features; private double label; private double weight; - public LabeledPointWithWeight(Vector features, double label, double weight) { + public LabeledPointWithWeight(IntDoubleVector features, double label, double weight) { this.features = features; this.label = label; this.weight = weight; @@ -37,11 +37,11 @@ public LabeledPointWithWeight(Vector features, double label, double weight) { public LabeledPointWithWeight() {} - public Vector getFeatures() { + public IntDoubleVector getFeatures() { return features; } - public void setFeatures(Vector features) { + public void setFeatures(IntDoubleVector features) { this.features = features; } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java index 643f1e0d8..6bc96aebc 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java @@ -27,67 +27,70 @@ public class BLAS { dev.ludovic.netlib.JavaBLAS.getInstance(); /** \sum_i |x_i| . */ - public static double asum(DenseVector x) { + public static double asum(DenseIntDoubleVector x) { return JAVA_BLAS.dasum(x.size(), x.values, 0, 1); } /** y += a * x . 
*/ - public static void axpy(double a, Vector x, DenseVector y) { - Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); + public static void axpy(double a, IntDoubleVector x, DenseIntDoubleVector y) { + Preconditions.checkArgument( + x.size().intValue() == y.size().intValue(), "Vector size mismatched."); axpy(a, x, y, x.size()); } /** y += a * x for the first k dimensions, with the other dimensions unchanged. */ - public static void axpy(double a, Vector x, DenseVector y, int k) { + public static void axpy(double a, IntDoubleVector x, DenseIntDoubleVector y, int k) { Preconditions.checkArgument(x.size() >= k && y.size() >= k); - if (x instanceof SparseVector) { - axpy(a, (SparseVector) x, y, k); + if (x instanceof SparseIntDoubleVector) { + axpy(a, (SparseIntDoubleVector) x, y, k); } else { - axpy(a, (DenseVector) x, y, k); + axpy(a, (DenseIntDoubleVector) x, y, k); } } /** Computes the hadamard product of the two vectors (y = y \hdot x). */ - public static void hDot(Vector x, Vector y) { - Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); - if (x instanceof SparseVector) { - if (y instanceof SparseVector) { - hDot((SparseVector) x, (SparseVector) y); + public static void hDot(IntDoubleVector x, IntDoubleVector y) { + Preconditions.checkArgument( + x.size().intValue() == y.size().intValue(), "Vector size mismatched."); + if (x instanceof SparseIntDoubleVector) { + if (y instanceof SparseIntDoubleVector) { + hDot((SparseIntDoubleVector) x, (SparseIntDoubleVector) y); } else { - hDot((SparseVector) x, (DenseVector) y); + hDot((SparseIntDoubleVector) x, (DenseIntDoubleVector) y); } } else { - if (y instanceof SparseVector) { - hDot((DenseVector) x, (SparseVector) y); + if (y instanceof SparseIntDoubleVector) { + hDot((DenseIntDoubleVector) x, (SparseIntDoubleVector) y); } else { - hDot((DenseVector) x, (DenseVector) y); + hDot((DenseIntDoubleVector) x, (DenseIntDoubleVector) y); } } } /** Computes the dot of the two vectors (y \dot x). 
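A note on the `.intValue()` calls introduced in these size checks, not part of the patch itself: once `size()` returns a boxed `Integer`, `==` compares object identity rather than value, and only coincidentally succeeds inside the small-integer cache. A minimal illustration of the pitfall being avoided:

    Integer a = 1000;
    Integer b = 1000;
    boolean identity = (a == b);                    // false: two distinct boxed objects
    boolean byValue = a.intValue() == b.intValue(); // true: compares the primitive values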
*/ - public static double dot(Vector x, Vector y) { - Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); - if (x instanceof SparseVector) { - if (y instanceof SparseVector) { - return dot((SparseVector) x, (SparseVector) y); + public static double dot(IntDoubleVector x, IntDoubleVector y) { + Preconditions.checkArgument( + x.size().intValue() == y.size().intValue(), "Vector size mismatched."); + if (x instanceof SparseIntDoubleVector) { + if (y instanceof SparseIntDoubleVector) { + return dot((SparseIntDoubleVector) x, (SparseIntDoubleVector) y); } else { - return dot((DenseVector) y, (SparseVector) x); + return dot((DenseIntDoubleVector) y, (SparseIntDoubleVector) x); } } else { - if (y instanceof SparseVector) { - return dot((DenseVector) x, (SparseVector) y); + if (y instanceof SparseIntDoubleVector) { + return dot((DenseIntDoubleVector) x, (SparseIntDoubleVector) y); } else { - return dot((DenseVector) x, (DenseVector) y); + return dot((DenseIntDoubleVector) x, (DenseIntDoubleVector) y); } } } - private static double dot(DenseVector x, DenseVector y) { + private static double dot(DenseIntDoubleVector x, DenseIntDoubleVector y) { return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1); } - private static double dot(DenseVector x, SparseVector y) { + private static double dot(DenseIntDoubleVector x, SparseIntDoubleVector y) { double dotValue = 0.0; for (int i = 0; i < y.indices.length; ++i) { dotValue += y.values[i] * x.values[y.indices[i]]; @@ -95,7 +98,7 @@ private static double dot(DenseVector x, SparseVector y) { return dotValue; } - private static double dot(SparseVector x, SparseVector y) { + private static double dot(SparseIntDoubleVector x, SparseIntDoubleVector y) { double dotValue = 0; int p0 = 0; int p1 = 0; @@ -114,27 +117,30 @@ private static double dot(SparseVector x, SparseVector y) { } /** \sqrt(\sum_i x_i * x_i) . */ - public static double norm2(Vector x) { - if (x instanceof DenseVector) { - return norm2((DenseVector) x); + public static double norm2(IntDoubleVector x) { + if (x instanceof DenseIntDoubleVector) { + return norm2((DenseIntDoubleVector) x); + } else { + return norm2((SparseIntDoubleVector) x); } - return norm2((SparseVector) x); } - private static double norm2(DenseVector x) { + private static double norm2(DenseIntDoubleVector x) { return JAVA_BLAS.dnrm2(x.size(), x.values, 1); } - private static double norm2(SparseVector x) { + private static double norm2(SparseIntDoubleVector x) { return JAVA_BLAS.dnrm2(x.values.length, x.values, 1); } /** Calculates the p-norm of the vector x. */ - public static double norm(Vector x, double p) { + public static double norm(IntDoubleVector x, double p) { Preconditions.checkArgument(p >= 1.0, "p value must >= 1.0, but the current p is : " + p); double norm = 0.0; double[] data = - (x instanceof DenseVector) ? ((DenseVector) x).values : ((SparseVector) x).values; + (x instanceof DenseIntDoubleVector) + ? ((DenseIntDoubleVector) x).values + : ((SparseIntDoubleVector) x).values; if (p == 1.0) { for (double datum : data) { @@ -157,11 +163,11 @@ public static double norm(Vector x, double p) { } /** x = x * a . 
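Taken together, the retyped BLAS entry points are called exactly as before, just with the renamed vector classes. A minimal usage sketch assuming the classes in this patch, illustrative rather than part of the diff:

    DenseIntDoubleVector y = Vectors.dense(1.0, 2.0, 3.0);
    SparseIntDoubleVector x = Vectors.sparse(3, new int[] {0, 2}, new double[] {4.0, 5.0});
    BLAS.axpy(2.0, x, y);        // y becomes [9.0, 2.0, 13.0]
    double d = BLAS.dot(x, y);   // picks the dense/sparse overload: 4*9 + 5*13 = 101
    double n = BLAS.norm2(x);    // sqrt(4*4 + 5*5)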
*/ - public static void scal(double a, Vector x) { - if (x instanceof DenseVector) { - JAVA_BLAS.dscal(x.size(), a, ((DenseVector) x).values, 1); + public static void scal(double a, IntDoubleVector x) { + if (x instanceof DenseIntDoubleVector) { + JAVA_BLAS.dscal(x.size(), a, ((DenseIntDoubleVector) x).values, 1); } else { - double[] values = ((SparseVector) x).values; + double[] values = ((SparseIntDoubleVector) x).values; JAVA_BLAS.dscal(values.length, a, values, 1); } } @@ -180,9 +186,9 @@ public static void gemv( double alpha, DenseMatrix matrix, boolean transMatrix, - DenseVector x, + DenseIntDoubleVector x, double beta, - DenseVector y) { + DenseIntDoubleVector y) { Preconditions.checkArgument( transMatrix ? (matrix.numRows() == x.size() && matrix.numCols() == y.size()) @@ -203,11 +209,11 @@ public static void gemv( 1); } - private static void axpy(double a, DenseVector x, DenseVector y, int k) { + private static void axpy(double a, DenseIntDoubleVector x, DenseIntDoubleVector y, int k) { JAVA_BLAS.daxpy(k, a, x.values, 1, y.values, 1); } - private static void axpy(double a, SparseVector x, DenseVector y, int k) { + private static void axpy(double a, SparseIntDoubleVector x, DenseIntDoubleVector y, int k) { for (int i = 0; i < x.indices.length; i++) { int index = x.indices[i]; if (index >= k) { @@ -217,7 +223,7 @@ private static void axpy(double a, SparseVector x, DenseVector y, int k) { } } - private static void hDot(SparseVector x, SparseVector y) { + private static void hDot(SparseIntDoubleVector x, SparseIntDoubleVector y) { int idx = 0; int idy = 0; while (idx < x.indices.length && idy < y.indices.length) { @@ -238,7 +244,7 @@ private static void hDot(SparseVector x, SparseVector y) { } } - private static void hDot(SparseVector x, DenseVector y) { + private static void hDot(SparseIntDoubleVector x, DenseIntDoubleVector y) { int idx = 0; for (int i = 0; i < y.size(); i++) { if (idx < x.indices.length && x.indices[idx] == i) { @@ -250,13 +256,13 @@ private static void hDot(SparseVector x, DenseVector y) { } } - private static void hDot(DenseVector x, SparseVector y) { + private static void hDot(DenseIntDoubleVector x, SparseIntDoubleVector y) { for (int i = 0; i < y.values.length; i++) { y.values[i] *= x.values[y.indices[i]]; } } - private static void hDot(DenseVector x, DenseVector y) { + private static void hDot(DenseIntDoubleVector x, DenseIntDoubleVector y) { for (int i = 0; i < x.values.length; i++) { y.values[i] *= x.values[i]; } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseIntDoubleVector.java similarity index 71% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseIntDoubleVector.java index e26a93a4f..699fc269a 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/DenseIntDoubleVector.java @@ -20,36 +20,36 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.typeinfo.TypeInfo; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfoFactory; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfoFactory; import java.util.Arrays; -/** A dense vector of double values. */ -@TypeInfo(DenseVectorTypeInfoFactory.class) +/** A dense vector with int as keys and double as values. 
*/ +@TypeInfo(DenseIntDoubleVectorTypeInfoFactory.class) @PublicEvolving -public class DenseVector implements Vector { +public class DenseIntDoubleVector implements IntDoubleVector { public final double[] values; - public DenseVector(double[] values) { + public DenseIntDoubleVector(double[] values) { this.values = values; } - public DenseVector(int size) { + public DenseIntDoubleVector(int size) { this.values = new double[size]; } @Override - public int size() { + public Integer size() { return values.length; } @Override - public double get(int i) { + public Double get(Integer i) { return values[i]; } @Override - public void set(int i, double value) { + public void set(Integer i, Double value) { values[i] = value; } @@ -59,12 +59,12 @@ public double[] toArray() { } @Override - public DenseVector toDense() { + public DenseIntDoubleVector toDense() { return this; } @Override - public SparseVector toSparse() { + public SparseIntDoubleVector toSparse() { int numNonZeros = 0; for (double value : values) { if (value != 0.0) { @@ -84,7 +84,7 @@ public SparseVector toSparse() { k++; } - return new SparseVector(size(), nonZeroIndices, numZeroValues); + return new SparseIntDoubleVector(size(), nonZeroIndices, numZeroValues); } @Override @@ -94,10 +94,10 @@ public String toString() { @Override public boolean equals(Object obj) { - if (!(obj instanceof DenseVector)) { + if (!(obj instanceof DenseIntDoubleVector)) { return false; } - return Arrays.equals(values, ((DenseVector) obj).values); + return Arrays.equals(values, ((DenseIntDoubleVector) obj).values); } @Override @@ -106,7 +106,7 @@ public int hashCode() { } @Override - public DenseVector clone() { - return new DenseVector(values.clone()); + public DenseIntDoubleVector clone() { + return new DenseIntDoubleVector(values.clone()); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java new file mode 100644 index 000000000..2bed89d96 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/IntDoubleVector.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfoFactory; + +/** A vector with int as keys and double as values. 
*/ +@TypeInfo(VectorTypeInfoFactory.class) +@PublicEvolving +public interface IntDoubleVector extends Vector { + + @Override + double[] toArray(); + + @Override + DenseIntDoubleVector toDense(); + + @Override + SparseIntDoubleVector toSparse(); + + @Override + IntDoubleVector clone(); +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java new file mode 100644 index 000000000..e0265f86f --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/LongDoubleVector.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +/** A vector with long as keys and double as values. */ +public interface LongDoubleVector extends Vector { + @Override + default double[] toArray() { + throw new UnsupportedOperationException( + "LongDoubleVector cannot be converted to dense array."); + } + + @Override + default Vector toDense() { + throw new UnsupportedOperationException( + "LongDoubleVector cannot be converted to dense array."); + } + + @Override + SparseLongDoubleVector toSparse(); + + @Override + LongDoubleVector clone(); +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseIntDoubleVector.java similarity index 86% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseIntDoubleVector.java index 54e707f63..f1b656673 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseIntDoubleVector.java @@ -20,21 +20,21 @@ import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.api.common.typeinfo.TypeInfo; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfoFactory; import org.apache.flink.util.Preconditions; import java.util.Arrays; import java.util.Objects; -/** A sparse vector of double values. */ -@TypeInfo(SparseVectorTypeInfoFactory.class) +/** A sparse vector with int as keys and double as values. 
*/ +@TypeInfo(SparseIntDoubleVectorTypeInfoFactory.class) @PublicEvolving -public class SparseVector implements Vector { +public class SparseIntDoubleVector implements IntDoubleVector { public final int n; public int[] indices; public double[] values; - public SparseVector(int n, int[] indices, double[] values) { + public SparseIntDoubleVector(int n, int[] indices, double[] values) { this.n = n; this.indices = indices; this.values = values; @@ -45,12 +45,12 @@ public SparseVector(int n, int[] indices, double[] values) { } @Override - public int size() { + public Integer size() { return n; } @Override - public double get(int i) { + public Double get(Integer i) { int pos = Arrays.binarySearch(indices, i); if (pos >= 0) { return values[pos]; @@ -59,7 +59,7 @@ public double get(int i) { } @Override - public void set(int i, double value) { + public void set(Integer i, Double value) { int pos = Arrays.binarySearch(indices, i); if (pos >= 0) { values[pos] = value; @@ -88,12 +88,16 @@ public double[] toArray() { } @Override - public DenseVector toDense() { - return new DenseVector(toArray()); + public DenseIntDoubleVector toDense() { + double[] result = new double[n]; + for (int i = 0; i < indices.length; i++) { + result[indices[i]] = values[i]; + } + return new DenseIntDoubleVector(result); } @Override - public SparseVector toSparse() { + public SparseIntDoubleVector toSparse() { return this; } @@ -105,7 +109,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) { return false; } - SparseVector that = (SparseVector) o; + SparseIntDoubleVector that = (SparseIntDoubleVector) o; return n == that.n && Arrays.equals(indices, that.indices) && Arrays.equals(values, that.values); @@ -199,7 +203,7 @@ public String toString() { } @Override - public SparseVector clone() { - return new SparseVector(n, indices.clone(), values.clone()); + public SparseIntDoubleVector clone() { + return new SparseIntDoubleVector(n, indices.clone(), values.clone()); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java new file mode 100644 index 000000000..9f7ca4bbe --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/SparseLongDoubleVector.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
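The rewritten `toDense()` above fills the dense array directly from the sparse indices instead of delegating to `toArray()`. A quick round-trip sketch of the two conversions, illustrative only:

    DenseIntDoubleVector dense = Vectors.dense(0.0, 7.0, 0.0, 3.0);
    SparseIntDoubleVector sparse = dense.toSparse(); // indices [1, 3], values [7.0, 3.0]
    DenseIntDoubleVector back = sparse.toDense();    // rebuilds [0.0, 7.0, 0.0, 3.0]
    assert dense.equals(back);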
+ */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseLongDoubleVectorTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; + +/** + * A sparse vector with long as keys and double as values. + * + *
<p>
TODO: Add processing logic for {@link SparseLongDoubleVector} for existing algorithms. + */ +@TypeInfo(SparseLongDoubleVectorTypeInfoFactory.class) +@PublicEvolving +public class SparseLongDoubleVector implements LongDoubleVector { + + public final long n; + public long[] indices; + public double[] values; + + public SparseLongDoubleVector(long n, long[] indices, double[] values) { + this.n = n; + this.indices = indices; + this.values = values; + if (!isIndicesSorted()) { + sortIndices(); + } + validateSortedData(); + } + + @Override + public Long size() { + return n; + } + + @Override + public Double get(Long i) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + return values[pos]; + } + return 0.; + } + + @Override + public void set(Long i, Double value) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + values[pos] = value; + } else if (value != 0.0) { + Preconditions.checkArgument(i < n, "Index out of bounds: " + i); + long[] indices = new long[this.indices.length + 1]; + double[] values = new double[this.indices.length + 1]; + System.arraycopy(this.indices, 0, indices, 0, -pos - 1); + System.arraycopy(this.values, 0, values, 0, -pos - 1); + indices[-pos - 1] = i; + values[-pos - 1] = value; + System.arraycopy(this.indices, -pos - 1, indices, -pos, this.indices.length + pos + 1); + System.arraycopy(this.values, -pos - 1, values, -pos, this.indices.length + pos + 1); + this.indices = indices; + this.values = values; + } + } + + @Override + public SparseLongDoubleVector toSparse() { + return this; + } + + /** + * Checks whether input data is validate. + * + *

<p>This function does the following checks:
+ *
+ * <ul>
+ *   <li>The indices array and values array are of the same size.
+ *   <li>vector indices are in valid range.
+ *   <li>vector indices are unique.
+ * </ul>
+ *
+ * <p>
This function works as expected only when indices are sorted. + */ + private void validateSortedData() { + Preconditions.checkArgument( + indices.length == values.length, + "Indices size and values size should be the same."); + if (this.indices.length > 0) { + Preconditions.checkArgument( + this.indices[0] >= 0 && this.indices[this.indices.length - 1] < this.n, + "Index out of bound."); + } + for (int i = 1; i < this.indices.length; i++) { + Preconditions.checkArgument( + this.indices[i] > this.indices[i - 1], "Indices duplicated."); + } + } + + private boolean isIndicesSorted() { + for (int i = 1; i < this.indices.length; i++) { + if (this.indices[i] < this.indices[i - 1]) { + return false; + } + } + return true; + } + + /** Sorts the indices and values. */ + private void sortIndices() { + sortImpl(this.indices, this.values, 0, this.indices.length - 1); + } + + /** Sorts the indices and values using quick sort. */ + private static void sortImpl(long[] indices, double[] values, int low, int high) { + int pivotPos = (low + high) / 2; + long pivot = indices[pivotPos]; + swapIndexAndValue(indices, values, pivotPos, high); + + int pos = low - 1; + for (int i = low; i <= high; i++) { + if (indices[i] <= pivot) { + pos++; + swapIndexAndValue(indices, values, pos, i); + } + } + if (high > pos + 1) { + sortImpl(indices, values, pos + 1, high); + } + if (pos - 1 > low) { + sortImpl(indices, values, low, pos - 1); + } + } + + private static void swapIndexAndValue(long[] indices, double[] values, int index1, int index2) { + long tempIndex = indices[index1]; + indices[index1] = indices[index2]; + indices[index2] = tempIndex; + double tempValue = values[index1]; + values[index1] = values[index2]; + values[index2] = tempValue; + } + + @Override + public SparseLongDoubleVector clone() { + return new SparseLongDoubleVector(n, indices.clone(), values.clone()); + } +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java index c63f1d629..df9e09fb3 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vector.java @@ -24,29 +24,29 @@ import java.io.Serializable; -/** A vector of double values. */ +/** A vector representation of numbers. */ @TypeInfo(VectorTypeInfoFactory.class) @PublicEvolving -public interface Vector extends Serializable { +public interface Vector, V extends Number> extends Serializable { /** Gets the size of the vector. */ - int size(); + K size(); /** Gets the value of the ith element. */ - double get(int i); + V get(K i); /** Sets the value of the ith element. */ - void set(int i, double value); + void set(K i, V value); - /** Converts the instance to a double array. */ - double[] toArray(); + /** Converts the instance to a primitive array. */ + Object toArray(); /** Converts the instance to a dense vector. */ - DenseVector toDense(); + Vector toDense(); /** Converts the instance to a sparse vector. */ - SparseVector toSparse(); + Vector toSparse(); /** Makes a deep copy of the vector. 
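The `SparseLongDoubleVector` constructor above sorts out-of-order indices (carrying their values along) and then validates them, so `get` can rely on binary search afterwards. Illustrative, not part of the patch:

    SparseLongDoubleVector v = new SparseLongDoubleVector(
            10L, new long[] {7L, 2L}, new double[] {70.0, 20.0});
    double x = v.get(2L); // 20.0, since indices were reordered to [2, 7] at construction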
*/ - Vector clone(); + Vector clone(); } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java index bb78ef2ce..d8867daf3 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/VectorWithNorm.java @@ -25,15 +25,15 @@ /** A vector with its norm. */ @TypeInfo(VectorWithNormTypeInfoFactory.class) public class VectorWithNorm { - public final Vector vector; + public final IntDoubleVector vector; public final double l2Norm; - public VectorWithNorm(Vector vector) { + public VectorWithNorm(IntDoubleVector vector) { this(vector, BLAS.norm2(vector)); } - public VectorWithNorm(Vector vector, double l2Norm) { + public VectorWithNorm(IntDoubleVector vector, double l2Norm) { this.vector = vector; this.l2Norm = l2Norm; } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java index 99c508080..8d10a53b5 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java @@ -25,12 +25,12 @@ public class Vectors { /** Creates a dense vector from its values. */ - public static DenseVector dense(double... values) { - return new DenseVector(values); + public static DenseIntDoubleVector dense(double... values) { + return new DenseIntDoubleVector(values); } /** Creates a sparse vector from its values. */ - public static SparseVector sparse(int size, int[] indices, double[] values) { - return new SparseVector(size, indices, values); + public static SparseIntDoubleVector sparse(int size, int[] indices, double[] values) { + return new SparseIntDoubleVector(size, indices, values); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorSerializer.java similarity index 74% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorSerializer.java index 5b6f984aa..88cf6247c 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorSerializer.java @@ -24,15 +24,15 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.util.Bits; import java.io.IOException; import java.util.Arrays; import java.util.Objects; -/** Specialized serializer for {@link DenseVector}. */ -public final class DenseVectorSerializer extends TypeSerializer { +/** Specialized serializer for {@link DenseIntDoubleVector}. 
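Per the signatures above, `Vector` is now parameterized over a key and a value type, roughly `Vector<K extends Number, V extends Number>` (the exact bounds are an assumption here), with `IntDoubleVector` binding `K = Integer, V = Double` and `LongDoubleVector` binding `K = Long, V = Double`. A sketch of how the bound methods read at a call site:

    IntDoubleVector v = Vectors.dense(1.0, 2.0);
    Integer size = v.size(); // size() now returns the boxed key type
    Double first = v.get(0); // get()/set() take boxed keys and return boxed values
    v.set(1, 5.0);           // autoboxing keeps existing call sites unchanged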
*/ +public final class DenseIntDoubleVectorSerializer extends TypeSerializer { private static final long serialVersionUID = 1L; @@ -46,22 +46,22 @@ public boolean isImmutableType() { } @Override - public TypeSerializer duplicate() { - return new DenseVectorSerializer(); + public TypeSerializer duplicate() { + return new DenseIntDoubleVectorSerializer(); } @Override - public DenseVector createInstance() { - return new DenseVector(EMPTY); + public DenseIntDoubleVector createInstance() { + return new DenseIntDoubleVector(EMPTY); } @Override - public DenseVector copy(DenseVector from) { - return new DenseVector(Arrays.copyOf(from.values, from.values.length)); + public DenseIntDoubleVector copy(DenseIntDoubleVector from) { + return new DenseIntDoubleVector(Arrays.copyOf(from.values, from.values.length)); } @Override - public DenseVector copy(DenseVector from, DenseVector reuse) { + public DenseIntDoubleVector copy(DenseIntDoubleVector from, DenseIntDoubleVector reuse) { if (from.values.length == reuse.values.length) { System.arraycopy(from.values, 0, reuse.values, 0, from.values.length); return reuse; @@ -75,7 +75,7 @@ public int getLength() { } @Override - public void serialize(DenseVector vector, DataOutputView target) throws IOException { + public void serialize(DenseIntDoubleVector vector, DataOutputView target) throws IOException { if (vector == null) { throw new IllegalArgumentException("The vector must not be null."); } @@ -93,11 +93,11 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept } @Override - public DenseVector deserialize(DataInputView source) throws IOException { + public DenseIntDoubleVector deserialize(DataInputView source) throws IOException { int len = source.readInt(); double[] values = new double[len]; readDoubleArray(values, source, len); - return new DenseVector(values); + return new DenseIntDoubleVector(values); } // Reads `len` double values from `source` into `dst`. @@ -116,7 +116,8 @@ private void readDoubleArray(double[] dst, DataInputView source, int len) throws } @Override - public DenseVector deserialize(DenseVector reuse, DataInputView source) throws IOException { + public DenseIntDoubleVector deserialize(DenseIntDoubleVector reuse, DataInputView source) + throws IOException { int len = source.readInt(); if (len == reuse.values.length) { readDoubleArray(reuse.values, source, len); @@ -125,7 +126,7 @@ public DenseVector deserialize(DenseVector reuse, DataInputView source) throws I double[] values = new double[len]; readDoubleArray(values, source, len); - return new DenseVector(values); + return new DenseIntDoubleVector(values); } @Override @@ -137,28 +138,28 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public boolean equals(Object o) { - return o instanceof DenseVectorSerializer; + return o instanceof DenseIntDoubleVectorSerializer; } @Override public int hashCode() { - return Objects.hashCode(DenseVectorSerializer.class); + return Objects.hashCode(DenseIntDoubleVectorSerializer.class); } // ------------------------------------------------------------------------ @Override - public TypeSerializerSnapshot snapshotConfiguration() { + public TypeSerializerSnapshot snapshotConfiguration() { return new DenseVectorSerializerSnapshot(); } /** Serializer configuration snapshot for compatibility and format evolution. 
*/ @SuppressWarnings("WeakerAccess") public static final class DenseVectorSerializerSnapshot - extends SimpleTypeSerializerSnapshot { + extends SimpleTypeSerializerSnapshot { public DenseVectorSerializerSnapshot() { - super(DenseVectorSerializer::new); + super(DenseIntDoubleVectorSerializer::new); } } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfo.java similarity index 71% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfo.java index 72f9a47d9..1ceca7bb0 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfo.java @@ -21,15 +21,15 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; -/** A {@link TypeInformation} for the {@link DenseVector} type. */ -public class DenseVectorTypeInfo extends TypeInformation { +/** A {@link TypeInformation} for the {@link DenseIntDoubleVector} type. */ +public class DenseIntDoubleVectorTypeInfo extends TypeInformation { private static final long serialVersionUID = 1L; - public static final DenseVectorTypeInfo INSTANCE = new DenseVectorTypeInfo(); + public static final DenseIntDoubleVectorTypeInfo INSTANCE = new DenseIntDoubleVectorTypeInfo(); - public DenseVectorTypeInfo() {} + public DenseIntDoubleVectorTypeInfo() {} @Override public int getArity() { @@ -42,8 +42,8 @@ public int getTotalFields() { } @Override - public Class getTypeClass() { - return DenseVector.class; + public Class getTypeClass() { + return DenseIntDoubleVector.class; } @Override @@ -62,8 +62,8 @@ public boolean isKeyType() { } @Override - public TypeSerializer createSerializer(ExecutionConfig executionConfig) { - return new DenseVectorSerializer(); + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return new DenseIntDoubleVectorSerializer(); } // -------------------------------------------------------------------------------------------- @@ -75,12 +75,12 @@ public int hashCode() { @Override public boolean equals(Object obj) { - return obj instanceof DenseVectorTypeInfo; + return obj instanceof DenseIntDoubleVectorTypeInfo; } @Override public boolean canEqual(Object obj) { - return obj instanceof DenseVectorTypeInfo; + return obj instanceof DenseIntDoubleVectorTypeInfo; } @Override diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfoFactory.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfoFactory.java similarity index 81% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfoFactory.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfoFactory.java index 367f70638..e5920fa23 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfoFactory.java +++ 
b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseIntDoubleVectorTypeInfoFactory.java @@ -21,21 +21,21 @@ import org.apache.flink.api.common.typeinfo.TypeInfoFactory; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; -import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; import java.lang.reflect.Type; import java.util.Map; /** * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link - * DenseVector}. + * DenseIntDoubleVector}. */ -public class DenseVectorTypeInfoFactory extends TypeInfoFactory { +public class DenseIntDoubleVectorTypeInfoFactory extends TypeInfoFactory { @Override - public TypeInformation createTypeInfo( + public TypeInformation createTypeInfo( Type t, Map> genericParameters) { - return new DenseVectorTypeInfo(); + return new DenseIntDoubleVectorTypeInfo(); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorSerializer.java similarity index 76% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorSerializer.java index a20a4ffff..93c6679b3 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorSerializer.java @@ -24,13 +24,14 @@ import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.io.IOException; import java.util.Arrays; -/** Specialized serializer for {@link SparseVector}. */ -public final class SparseVectorSerializer extends TypeSerializerSingleton { +/** Specialized serializer for {@link SparseIntDoubleVector}. */ +public final class SparseIntDoubleVectorSerializer + extends TypeSerializerSingleton { private static final long serialVersionUID = 1L; @@ -38,7 +39,8 @@ public final class SparseVectorSerializer extends TypeSerializerSingleton snapshotConfiguration() { + public TypeSerializerSnapshot snapshotConfiguration() { return new SparseVectorSerializerSnapshot(); } /** Serializer configuration snapshot for compatibility and format evolution. 
*/ @SuppressWarnings("WeakerAccess") public static final class SparseVectorSerializerSnapshot - extends SimpleTypeSerializerSnapshot { + extends SimpleTypeSerializerSnapshot { public SparseVectorSerializerSnapshot() { super(() -> INSTANCE); diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfo.java similarity index 70% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfo.java index 06686f088..bdf6430ee 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfo.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfo.java @@ -22,11 +22,12 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; -/** A {@link TypeInformation} for the {@link SparseVector} type. */ -public class SparseVectorTypeInfo extends TypeInformation { - public static final SparseVectorTypeInfo INSTANCE = new SparseVectorTypeInfo(); +/** A {@link TypeInformation} for the {@link SparseIntDoubleVector} type. */ +public class SparseIntDoubleVectorTypeInfo extends TypeInformation { + public static final SparseIntDoubleVectorTypeInfo INSTANCE = + new SparseIntDoubleVectorTypeInfo(); @Override public boolean isBasicType() { @@ -49,8 +50,8 @@ public int getTotalFields() { } @Override - public Class getTypeClass() { - return SparseVector.class; + public Class getTypeClass() { + return SparseIntDoubleVector.class; } @Override @@ -59,8 +60,8 @@ public boolean isKeyType() { } @Override - public TypeSerializer createSerializer(ExecutionConfig executionConfig) { - return SparseVectorSerializer.INSTANCE; + public TypeSerializer createSerializer(ExecutionConfig executionConfig) { + return SparseIntDoubleVectorSerializer.INSTANCE; } @Override @@ -70,7 +71,7 @@ public String toString() { @Override public boolean equals(Object obj) { - return obj instanceof SparseVectorTypeInfo; + return obj instanceof SparseIntDoubleVectorTypeInfo; } @Override @@ -80,6 +81,6 @@ public int hashCode() { @Override public boolean canEqual(Object obj) { - return obj instanceof SparseVectorTypeInfo; + return obj instanceof SparseIntDoubleVectorTypeInfo; } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfoFactory.java similarity index 80% rename from flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java rename to flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfoFactory.java index 01c10367e..07446d34f 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorTypeInfoFactory.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseIntDoubleVectorTypeInfoFactory.java @@ -22,19 +22,19 @@ import org.apache.flink.api.common.typeinfo.TypeInfoFactory; import 
org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.lang.reflect.Type; import java.util.Map; /** * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link - * SparseVector}. + * SparseIntDoubleVector}. */ -public class SparseVectorTypeInfoFactory extends TypeInfoFactory { +public class SparseIntDoubleVectorTypeInfoFactory extends TypeInfoFactory { @Override - public TypeInformation createTypeInfo( + public TypeInformation createTypeInfo( Type type, Map> map) { - return new SparseVectorTypeInfo(); + return new SparseIntDoubleVectorTypeInfo(); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java new file mode 100644 index 000000000..f1cb239d0 --- /dev/null +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorSerializer.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@link SparseLongDoubleVector}. 
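The serializer that follows writes the logical dimension n as a long, the number of non-zeros as an int, and then len interleaved (long index, double value) pairs, so each entry costs 16 bytes. A round-trip sketch of that format, patterned after the existing SparseVectorTest (illustrative only, not part of the patch):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.ml.linalg.SparseLongDoubleVector;
import org.apache.flink.ml.linalg.typeinfo.SparseLongDoubleVectorSerializer;

class SparseLongDoubleVectorRoundTrip {
    static SparseLongDoubleVector roundTrip() throws IOException {
        SparseLongDoubleVector vector =
                new SparseLongDoubleVector(10L, new long[] {0L, 7L}, new double[] {0.5, 1.5});
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        SparseLongDoubleVectorSerializer.INSTANCE.serialize(
                vector, new DataOutputViewStreamWrapper(out));
        // 8 (n) + 4 (len) + 2 * 16 (entries) = 44 bytes written; deserializing
        // them restores n, indices and values unchanged.
        return SparseLongDoubleVectorSerializer.INSTANCE.deserialize(
                new DataInputViewStreamWrapper(new ByteArrayInputStream(out.toByteArray())));
    }
}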
*/
+public final class SparseLongDoubleVectorSerializer
+        extends TypeSerializerSingleton<SparseLongDoubleVector> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY_DOUBLE_ARRAY = new double[0];
+
+    private static final long[] EMPTY_LONG_ARRAY = new long[0];
+
+    public static final SparseLongDoubleVectorSerializer INSTANCE =
+            new SparseLongDoubleVectorSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public SparseLongDoubleVector createInstance() {
+        return new SparseLongDoubleVector(0, EMPTY_LONG_ARRAY, EMPTY_DOUBLE_ARRAY);
+    }
+
+    @Override
+    public SparseLongDoubleVector copy(SparseLongDoubleVector from) {
+        return new SparseLongDoubleVector(
+                from.n,
+                Arrays.copyOf(from.indices, from.indices.length),
+                Arrays.copyOf(from.values, from.values.length));
+    }
+
+    @Override
+    public SparseLongDoubleVector copy(SparseLongDoubleVector from, SparseLongDoubleVector reuse) {
+        if (from.values.length == reuse.values.length && from.n == reuse.n) {
+            System.arraycopy(from.values, 0, reuse.values, 0, from.values.length);
+            System.arraycopy(from.indices, 0, reuse.indices, 0, from.indices.length);
+            return reuse;
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(SparseLongDoubleVector vector, DataOutputView target) throws IOException {
+        if (vector == null) {
+            throw new IllegalArgumentException("The vector must not be null.");
+        }
+
+        target.writeLong(vector.n);
+        final int len = vector.values.length;
+        target.writeInt(len);
+        // TODO: optimize the serialization/deserialization process of SparseLongDoubleVectorSerializer.
+        for (int i = 0; i < len; i++) {
+            target.writeLong(vector.indices[i]);
+            target.writeDouble(vector.values[i]);
+        }
+    }
+
+    // Reads `len` long values from `source` into `indices` and `len` double values from `source`
+    // into `values`.
+    private void readSparseVectorArrays(
+            long[] indices, double[] values, DataInputView source, int len) throws IOException {
+        for (int i = 0; i < len; i++) {
+            indices[i] = source.readLong();
+            values[i] = source.readDouble();
+        }
+    }
+
+    @Override
+    public SparseLongDoubleVector deserialize(DataInputView source) throws IOException {
+        long n = source.readLong();
+        int len = source.readInt();
+        long[] indices = new long[len];
+        double[] values = new double[len];
+        readSparseVectorArrays(indices, values, source, len);
+        return new SparseLongDoubleVector(n, indices, values);
+    }
+
+    @Override
+    public SparseLongDoubleVector deserialize(SparseLongDoubleVector reuse, DataInputView source)
+            throws IOException {
+        long n = source.readLong();
+        int len = source.readInt();
+        if (reuse.n == n && reuse.values.length == len) {
+            readSparseVectorArrays(reuse.indices, reuse.values, source, len);
+            return reuse;
+        }
+
+        long[] indices = new long[len];
+        double[] values = new double[len];
+        readSparseVectorArrays(indices, values, source, len);
+        return new SparseLongDoubleVector(n, indices, values);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        long n = source.readLong();
+        int len = source.readInt();
+
+        target.writeLong(n);
+        target.writeInt(len);
+
+        target.write(source, len * 16); // one long index (8 bytes) plus one double value (8 bytes) per entry
+    }
+
+    @Override
+    public TypeSerializerSnapshot<SparseLongDoubleVector> snapshotConfiguration() {
+        return new SparseLongDoubleVectorSerializerSnapshot();
+    }
+
+    /** Serializer configuration snapshot for compatibility and format evolution.
*/
+    @SuppressWarnings("WeakerAccess")
+    public static final class SparseLongDoubleVectorSerializerSnapshot
+            extends SimpleTypeSerializerSnapshot<SparseLongDoubleVector> {
+
+        public SparseLongDoubleVectorSerializerSnapshot() {
+            super(() -> INSTANCE);
+        }
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java
new file mode 100644
index 000000000..d574043eb
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfo.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.SparseLongDoubleVector;
+
+/** A {@link TypeInformation} for the {@link SparseLongDoubleVector} type. */
+public class SparseLongDoubleVectorTypeInfo extends TypeInformation<SparseLongDoubleVector> {
+    public static final SparseLongDoubleVectorTypeInfo INSTANCE =
+            new SparseLongDoubleVectorTypeInfo();
+
+    @Override
+    public boolean isBasicType() {
+        return false;
+    }
+
+    @Override
+    public boolean isTupleType() {
+        return false;
+    }
+
+    @Override
+    public int getArity() {
+        return 3;
+    }
+
+    @Override
+    public int getTotalFields() {
+        return 3;
+    }
+
+    @Override
+    public Class<SparseLongDoubleVector> getTypeClass() {
+        return SparseLongDoubleVector.class;
+    }
+
+    @Override
+    public boolean isKeyType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<SparseLongDoubleVector> createSerializer(
+            ExecutionConfig executionConfig) {
+        return SparseLongDoubleVectorSerializer.INSTANCE;
+    }
+
+    @Override
+    public String toString() {
+        return "SparseLongDoubleVectorType";
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        return obj instanceof SparseLongDoubleVectorTypeInfo;
+    }
+
+    @Override
+    public int hashCode() {
+        return getClass().hashCode();
+    }
+
+    @Override
+    public boolean canEqual(Object obj) {
+        return obj instanceof SparseLongDoubleVectorTypeInfo;
+    }
+}
diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java
new file mode 100644
index 000000000..fb8753382
--- /dev/null
+++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseLongDoubleVectorTypeInfoFactory.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeinfo.TypeInfoFactory; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; + +import java.lang.reflect.Type; +import java.util.Map; + +/** + * Used by {@link TypeExtractor} to create a {@link TypeInformation} for implementations of {@link + * SparseLongDoubleVector}. + */ +public class SparseLongDoubleVectorTypeInfoFactory extends TypeInfoFactory { + @Override + public TypeInformation createTypeInfo( + Type type, Map> map) { + return new SparseLongDoubleVectorTypeInfo(); + } +} diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java index 51b9307aa..d635d2296 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorSerializer.java @@ -24,23 +24,29 @@ import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; import org.apache.flink.ml.linalg.Vector; import java.io.IOException; -/** Specialized serializer for {@link Vector}. */ +/** Specialized serializer for {@link IntDoubleVector}. 
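On the VectorSerializer changes just below: serialization dispatches on a leading tag byte, with 0 for DenseIntDoubleVector, 1 for SparseIntDoubleVector, and now 2 for SparseLongDoubleVector. Because tag 2 is appended rather than renumbering the existing tags, bytes written before this patch stay readable. An equivalent read-side dispatch, sketched as a switch for clarity (the patch itself uses if/else chains; the serializer fields are those declared in the class below):

// Illustrative fragment of the deserialize body, not part of the patch:
byte tag = source.readByte();
switch (tag) {
    case 0:
        return denseVectorSerializer.deserialize(source);
    case 1:
        return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.deserialize(source);
    case 2: // new in this patch
        return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.deserialize(source);
    default:
        throw new IOException("Unknown vector type tag: " + tag);
}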
*/ public final class VectorSerializer extends TypeSerializerSingleton { private static final long serialVersionUID = 1L; private static final double[] EMPTY = new double[0]; - private final DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer(); + private final DenseIntDoubleVectorSerializer denseVectorSerializer = + new DenseIntDoubleVectorSerializer(); - private static final SparseVectorSerializer SPARSE_VECTOR_SERIALIZER = - SparseVectorSerializer.INSTANCE; + private static final SparseIntDoubleVectorSerializer SPARSE_INT_DOUBLE_VECTOR_SERIALIZER = + SparseIntDoubleVectorSerializer.INSTANCE; + + private static final SparseLongDoubleVectorSerializer SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER = + SparseLongDoubleVectorSerializer.INSTANCE; @Override public boolean isImmutableType() { @@ -49,25 +55,32 @@ public boolean isImmutableType() { @Override public Vector createInstance() { - return new DenseVector(EMPTY); + return new DenseIntDoubleVector(EMPTY); } @Override public Vector copy(Vector from) { - if (from instanceof DenseVector) { - return denseVectorSerializer.copy((DenseVector) from); + if (from instanceof DenseIntDoubleVector) { + return denseVectorSerializer.copy((DenseIntDoubleVector) from); + } else if (from instanceof SparseIntDoubleVector) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.copy((SparseIntDoubleVector) from); } else { - return SPARSE_VECTOR_SERIALIZER.copy((SparseVector) from); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.copy((SparseLongDoubleVector) from); } } @Override public Vector copy(Vector from, Vector reuse) { assert from.getClass() == reuse.getClass(); - if (from instanceof DenseVector) { - return denseVectorSerializer.copy((DenseVector) from, (DenseVector) reuse); + if (from instanceof DenseIntDoubleVector) { + return denseVectorSerializer.copy( + (DenseIntDoubleVector) from, (DenseIntDoubleVector) reuse); + } else if (from instanceof SparseIntDoubleVector) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.copy( + (SparseIntDoubleVector) from, (SparseIntDoubleVector) reuse); } else { - return SPARSE_VECTOR_SERIALIZER.copy((SparseVector) from, (SparseVector) reuse); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.copy( + (SparseLongDoubleVector) from, (SparseLongDoubleVector) reuse); } } @@ -78,12 +91,15 @@ public int getLength() { @Override public void serialize(Vector vector, DataOutputView target) throws IOException { - if (vector instanceof DenseVector) { + if (vector instanceof DenseIntDoubleVector) { target.writeByte(0); - denseVectorSerializer.serialize((DenseVector) vector, target); - } else { + denseVectorSerializer.serialize((DenseIntDoubleVector) vector, target); + } else if (vector instanceof SparseIntDoubleVector) { target.writeByte(1); - SPARSE_VECTOR_SERIALIZER.serialize((SparseVector) vector, target); + SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.serialize((SparseIntDoubleVector) vector, target); + } else { + target.writeByte(2); + SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.serialize((SparseLongDoubleVector) vector, target); } } @@ -92,20 +108,25 @@ public Vector deserialize(DataInputView source) throws IOException { byte type = source.readByte(); if (type == 0) { return denseVectorSerializer.deserialize(source); + } else if (type == 1) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } else { - return SPARSE_VECTOR_SERIALIZER.deserialize(source); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } } @Override public Vector deserialize(Vector reuse, DataInputView source) throws IOException { byte type = 
source.readByte(); - assert type == 0 && reuse instanceof DenseVector - || type == 1 && reuse instanceof SparseVector; + assert type == 0 && reuse instanceof DenseIntDoubleVector + || type == 1 && reuse instanceof SparseIntDoubleVector + || type == 2 && reuse instanceof SparseLongDoubleVector; if (type == 0) { return denseVectorSerializer.deserialize(source); + } else if (type == 1) { + return SPARSE_INT_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } else { - return SPARSE_VECTOR_SERIALIZER.deserialize(source); + return SPARSE_LONG_DOUBLE_VECTOR_SERIALIZER.deserialize(source); } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java index 672dedf63..71b7a28d3 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorTypeInfo.java @@ -22,9 +22,10 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vector; -/** A {@link TypeInformation} for the {@link Vector} type. */ +/** A {@link TypeInformation} for the {@link IntDoubleVector} type. */ public class VectorTypeInfo extends TypeInformation { public static final VectorTypeInfo INSTANCE = new VectorTypeInfo(); diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java index 92d1de165..8f8c4a21c 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/VectorWithNormSerializer.java @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.VectorWithNorm; import java.io.IOException; @@ -50,18 +50,18 @@ public TypeSerializer duplicate() { @Override public VectorWithNorm createInstance() { - return new VectorWithNorm(new DenseVector(EMPTY)); + return new VectorWithNorm(new DenseIntDoubleVector(EMPTY)); } @Override public VectorWithNorm copy(VectorWithNorm from) { - Vector vector = vectorSerializer.copy(from.vector); + IntDoubleVector vector = (IntDoubleVector) vectorSerializer.copy(from.vector); return new VectorWithNorm(vector, from.l2Norm); } @Override public VectorWithNorm copy(VectorWithNorm from, VectorWithNorm reuse) { - Vector vector = vectorSerializer.copy(from.vector, reuse.vector); + IntDoubleVector vector = (IntDoubleVector) vectorSerializer.copy(from.vector, reuse.vector); return new VectorWithNorm(vector, from.l2Norm); } @@ -78,7 +78,7 @@ public void serialize(VectorWithNorm from, DataOutputView dataOutputView) throws @Override public VectorWithNorm deserialize(DataInputView dataInputView) throws IOException { - Vector vector = vectorSerializer.deserialize(dataInputView); + IntDoubleVector vector = 
(IntDoubleVector) vectorSerializer.deserialize(dataInputView); double l2NormSquare = dataInputView.readDouble(); return new VectorWithNorm(vector, l2NormSquare); } @@ -86,7 +86,8 @@ public VectorWithNorm deserialize(DataInputView dataInputView) throws IOExceptio @Override public VectorWithNorm deserialize(VectorWithNorm reuse, DataInputView dataInputView) throws IOException { - Vector vector = vectorSerializer.deserialize(reuse.vector, dataInputView); + IntDoubleVector vector = + (IntDoubleVector) vectorSerializer.deserialize(reuse.vector, dataInputView); double l2NormSquare = dataInputView.readDouble(); return new VectorWithNorm(vector, l2NormSquare); } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java index 7fbf32b96..179fda831 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/param/VectorParam.java @@ -18,30 +18,30 @@ package org.apache.flink.ml.param; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.SparseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseIntDoubleVector; import java.util.List; import java.util.Map; /** Class for the Vector parameter. */ -public class VectorParam extends Param { +public class VectorParam extends Param { public VectorParam( String name, String description, - Vector defaultValue, - ParamValidator validator) { - super(name, Vector.class, description, defaultValue, validator); + IntDoubleVector defaultValue, + ParamValidator validator) { + super(name, IntDoubleVector.class, description, defaultValue, validator); } - public VectorParam(String name, String description, Vector defaultValue) { + public VectorParam(String name, String description, IntDoubleVector defaultValue) { this(name, description, defaultValue, ParamValidators.alwaysTrue()); } @Override - public Vector jsonDecode(Object object) { + public IntDoubleVector jsonDecode(Object object) { Map vecValues = (Map) object; if (vecValues.size() == 1) { List list = (List) vecValues.get("values"); @@ -49,7 +49,7 @@ public Vector jsonDecode(Object object) { for (int i = 0; i < values.length; ++i) { values[i] = list.get(i); } - return new DenseVector(values); + return new DenseIntDoubleVector(values); } else if (vecValues.size() == 3) { List valuesList = (List) vecValues.get("values"); List indicesList = (List) vecValues.get("indices"); @@ -60,7 +60,7 @@ public Vector jsonDecode(Object object) { values[i] = valuesList.get(i); indices[i] = indicesList.get(i); } - return new SparseVector(n, indices, values); + return new SparseIntDoubleVector(n, indices, values); } else { throw new UnsupportedOperationException("Vector parameter is invalid."); } diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java index 5fff534a6..cfa553d41 100644 --- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java +++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java @@ -27,7 +27,7 @@ public class BLASTest { private static final double TOLERANCE = 1e-7; - private static final DenseVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5); + private static final 
DenseIntDoubleVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5); private static final DenseMatrix inputDenseMat = new DenseMatrix(2, 5, new double[] {1, -2, 3, 4, -5, 1, -2, 3, 4, -5}); @@ -39,13 +39,14 @@ public void testAsum() { @Test public void testAxpy() { // Tests axpy(dense, dense). - DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5); + DenseIntDoubleVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5); BLAS.axpy(1, inputDenseVec, anotherDenseVec); double[] expectedResult = new double[] {2, 0, 6, 8, 0}; assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); // Tests axpy(sparse, dense). - SparseVector sparseVec = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); + SparseIntDoubleVector sparseVec = + Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); BLAS.axpy(2, sparseVec, anotherDenseVec); expectedResult = new double[] {4, 0, 12, 8, 10}; assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); @@ -54,13 +55,14 @@ public void testAxpy() { @Test public void testAxpyK() { // Tests axpy(dense, dense, k). - DenseVector anotherDenseVec = Vectors.dense(1, 2, 3); + DenseIntDoubleVector anotherDenseVec = Vectors.dense(1, 2, 3); BLAS.axpy(1, inputDenseVec, anotherDenseVec, 3); double[] expectedResult = new double[] {2, 0, 6}; assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); // Tests axpy(sparse, dense, k). - SparseVector sparseVec = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); + SparseIntDoubleVector sparseVec = + Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5, 6, 7); BLAS.axpy(2, sparseVec, anotherDenseVec, 5); expectedResult = new double[] {3, 2, 9, 4, 15, 6, 7}; @@ -69,10 +71,10 @@ public void testAxpyK() { @Test public void testDot() { - DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5); - SparseVector sparseVector1 = + DenseIntDoubleVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5); + SparseIntDoubleVector sparseVector1 = Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {1., 1., 4.}); - SparseVector sparseVector2 = + SparseIntDoubleVector sparseVector2 = Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {1., 2., 1.}); // Tests dot(dense, dense). 
assertEquals(-3, BLAS.dot(inputDenseVec, anotherDenseVec), TOLERANCE); @@ -88,7 +90,8 @@ public void testDot() { public void testNorm2() { assertEquals(Math.sqrt(55), BLAS.norm2(inputDenseVec), TOLERANCE); - SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); + SparseIntDoubleVector sparseVector = + Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); assertEquals(Math.sqrt(35), BLAS.norm2(sparseVector), TOLERANCE); } @@ -96,7 +99,8 @@ public void testNorm2() { public void testNorm() { assertEquals(Math.sqrt(55), BLAS.norm(inputDenseVec, 2.0), TOLERANCE); - SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); + SparseIntDoubleVector sparseVector = + Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); assertEquals(5.0, BLAS.norm(sparseVector, Double.POSITIVE_INFINITY), TOLERANCE); assertEquals(5.348481241239363, BLAS.norm(sparseVector, 3.0), TOLERANCE); @@ -109,7 +113,7 @@ public void testScal() { double[] expectedDenseResult = new double[] {2, -4, 6, 8, -10}; assertArrayEquals(expectedDenseResult, inputDenseVec.values, TOLERANCE); - SparseVector inputSparseVector = + SparseIntDoubleVector inputSparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); BLAS.scal(1.5, inputSparseVector); @@ -122,7 +126,7 @@ public void testScal() { @Test public void testGemv() { - DenseVector anotherDenseVec = Vectors.dense(1.0, 2.0); + DenseIntDoubleVector anotherDenseVec = Vectors.dense(1.0, 2.0); BLAS.gemv(-2.0, inputDenseMat, false, inputDenseVec, 0.0, anotherDenseVec); double[] expectedResult = new double[] {96.0, -60.0}; assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); @@ -131,16 +135,18 @@ public void testGemv() { @Test public void testHDot() { // Tests hDot(sparse, sparse). - SparseVector sparseVec1 = Vectors.sparse(5, new int[] {0, 2, 3}, new double[] {1, 3, 5}); - SparseVector sparseVec2 = Vectors.sparse(5, new int[] {0, 1, 4}, new double[] {1, 3, 5}); + SparseIntDoubleVector sparseVec1 = + Vectors.sparse(5, new int[] {0, 2, 3}, new double[] {1, 3, 5}); + SparseIntDoubleVector sparseVec2 = + Vectors.sparse(5, new int[] {0, 1, 4}, new double[] {1, 3, 5}); BLAS.hDot(sparseVec1, sparseVec2); - assertEquals(5, sparseVec2.size()); + assertEquals(5, sparseVec2.size().intValue()); assertArrayEquals(new int[] {0, 1, 4}, sparseVec2.indices); assertArrayEquals(new double[] {1, 0, 0}, sparseVec2.values, TOLERANCE); // Tests hDot(dense, dense). - DenseVector denseVec1 = Vectors.dense(1, 2, 3, 4, 5); - DenseVector denseVec2 = Vectors.dense(1, 2, 3, 4, 5); + DenseIntDoubleVector denseVec1 = Vectors.dense(1, 2, 3, 4, 5); + DenseIntDoubleVector denseVec2 = Vectors.dense(1, 2, 3, 4, 5); BLAS.hDot(denseVec1, denseVec2); double[] expectedResult = new double[] {1, 4, 9, 16, 25}; assertArrayEquals(expectedResult, denseVec2.values, TOLERANCE); @@ -151,9 +157,9 @@ public void testHDot() { assertArrayEquals(expectedResult, denseVec1.values, TOLERANCE); // Tests hDot(dense, sparse). 
- DenseVector denseVec3 = Vectors.dense(1, 2, 3, 4, 5); + DenseIntDoubleVector denseVec3 = Vectors.dense(1, 2, 3, 4, 5); BLAS.hDot(denseVec3, sparseVec1); - assertEquals(5, sparseVec1.size()); + assertEquals(5, sparseVec1.size().intValue()); assertArrayEquals(new int[] {0, 2, 3}, sparseVec1.indices); assertArrayEquals(new double[] {1, 9, 20}, sparseVec1.values, TOLERANCE); } diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java index 427403d17..3ffc429ea 100644 --- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java +++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java @@ -23,15 +23,15 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -/** Tests the behavior of {@link DenseVector}. */ +/** Tests the behavior of {@link DenseIntDoubleVector}. */ public class DenseVectorTest { private static final double TOLERANCE = 1e-7; @Test public void testClone() { - DenseVector denseVec = Vectors.dense(1, 2, 3); - DenseVector clonedDenseVec = denseVec.clone(); + DenseIntDoubleVector denseVec = Vectors.dense(1, 2, 3); + DenseIntDoubleVector clonedDenseVec = denseVec.clone(); assertArrayEquals(clonedDenseVec.values, new double[] {1, 2, 3}, TOLERANCE); clonedDenseVec.values[0] = -1; @@ -41,10 +41,10 @@ public void testClone() { @Test public void testGetAndSet() { - DenseVector denseVec = Vectors.dense(1, 2, 3); + DenseIntDoubleVector denseVec = Vectors.dense(1, 2, 3); assertEquals(1, denseVec.get(0), TOLERANCE); - denseVec.set(0, 2); + denseVec.set(0, 2.0); assertEquals(2, denseVec.get(0), TOLERANCE); } } diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java index 916963f18..e928c6b14 100644 --- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java +++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java @@ -20,7 +20,7 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.typeinfo.SparseVectorSerializer; +import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorSerializer; import org.apache.commons.io.output.ByteArrayOutputStream; import org.junit.Assert; @@ -32,7 +32,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -/** Tests the behavior of {@link SparseVector}. */ +/** Tests the behavior of {@link SparseIntDoubleVector}. 
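One API shift visible in the BLAS test updates above: size() now returns a boxed value, hence the explicit size().intValue() unboxing in the assertions. The likely shape, inferred from the call sites rather than spelled out in this patch, is a size accessor generic over the index type, roughly:

// Illustrative only: Integer for int-indexed vectors, Long for long-indexed ones.
interface VectorSizeSketch<K extends Number> {
    K size();
}
// With SparseIntDoubleVector#size() returning Integer, a plain
// assertEquals(5, vec.size()) would no longer pick the primitive overload,
// so the tests unbox explicitly before comparing.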
*/ public class SparseVectorTest { private static final double TOLERANCE = 1e-7; @@ -42,7 +42,7 @@ public void testConstructor() { int[] indices = new int[] {0, 2, 3}; double[] values = new double[] {0.1, 0.3, 0.4}; - SparseVector vector = Vectors.sparse(n, indices, values); + SparseIntDoubleVector vector = Vectors.sparse(n, indices, values); assertEquals(n, vector.n); assertArrayEquals(indices, vector.indices); assertArrayEquals(values, vector.values, 1e-5); @@ -67,13 +67,13 @@ public void testDuplicateIndex() { @Test public void testAllZeroVector() { int n = 4; - SparseVector vector = Vectors.sparse(n, new int[0], new double[0]); + SparseIntDoubleVector vector = Vectors.sparse(n, new int[0], new double[0]); assertArrayEquals(vector.toArray(), new double[n], 1e-5); } @Test public void testUnsortedIndex() { - SparseVector vector; + SparseIntDoubleVector vector; vector = Vectors.sparse(4, new int[] {2}, new double[] {0.3}); assertEquals(4, vector.n); @@ -115,8 +115,8 @@ public void testSerializer() throws IOException { int n = 4; int[] indices = new int[] {0, 2, 3}; double[] values = new double[] {0.1, 0.3, 0.4}; - SparseVector vector = Vectors.sparse(n, indices, values); - SparseVectorSerializer serializer = SparseVectorSerializer.INSTANCE; + SparseIntDoubleVector vector = Vectors.sparse(n, indices, values); + SparseIntDoubleVectorSerializer serializer = SparseIntDoubleVectorSerializer.INSTANCE; ByteArrayOutputStream bOutput = new ByteArrayOutputStream(1024); DataOutputViewStreamWrapper output = new DataOutputViewStreamWrapper(bOutput); @@ -125,7 +125,7 @@ public void testSerializer() throws IOException { byte[] b = bOutput.toByteArray(); ByteArrayInputStream bInput = new ByteArrayInputStream(b); DataInputViewStreamWrapper input = new DataInputViewStreamWrapper(bInput); - SparseVector vector2 = serializer.deserialize(input); + SparseIntDoubleVector vector2 = serializer.deserialize(input); assertEquals(vector.n, vector2.n); assertArrayEquals(vector.indices, vector2.indices); @@ -134,9 +134,9 @@ public void testSerializer() throws IOException { @Test public void testClone() { - SparseVector sparseVec = Vectors.sparse(3, new int[] {0, 2}, new double[] {1, 3}); - SparseVector clonedSparseVec = sparseVec.clone(); - assertEquals(3, clonedSparseVec.size()); + SparseIntDoubleVector sparseVec = Vectors.sparse(3, new int[] {0, 2}, new double[] {1, 3}); + SparseIntDoubleVector clonedSparseVec = sparseVec.clone(); + assertEquals(3, clonedSparseVec.size().intValue()); assertArrayEquals(clonedSparseVec.indices, new int[] {0, 2}); assertArrayEquals(clonedSparseVec.values, new double[] {1, 3}, TOLERANCE); @@ -150,7 +150,7 @@ public void testClone() { @Test public void testGetAndSet() { - SparseVector sparseVec = Vectors.sparse(4, new int[] {2}, new double[] {0.3}); + SparseIntDoubleVector sparseVec = Vectors.sparse(4, new int[] {2}, new double[] {0.3}); assertEquals(0, sparseVec.get(0), TOLERANCE); assertEquals(0.3, sparseVec.get(2), TOLERANCE); diff --git a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java index 25b45b089..31d6cf691 100644 --- a/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java +++ b/flink-ml-servable-core/src/test/java/org/apache/flink/ml/linalg/VectorWithNormTest.java @@ -26,12 +26,13 @@ public class VectorWithNormTest { @Test public void testL2Norm() { - DenseVector denseVector = Vectors.dense(1, 2, 3); + 
DenseIntDoubleVector denseVector = Vectors.dense(1, 2, 3); VectorWithNorm denseVectorWithNorm = new VectorWithNorm(denseVector); assertEquals(denseVector, denseVectorWithNorm.vector); assertEquals(Math.sqrt(14), denseVectorWithNorm.l2Norm, 1e-7); - SparseVector sparseVector = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 2, 3}); + SparseIntDoubleVector sparseVector = + Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 2, 3}); VectorWithNorm sparseVectorWithNorm = new VectorWithNorm(sparseVector); assertEquals(sparseVector, sparseVectorWithNorm.vector); assertEquals(Math.sqrt(14), sparseVectorWithNorm.l2Norm, 1e-7); diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java index ebaec117b..ade944646 100644 --- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java @@ -21,8 +21,8 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorSerializer; import org.apache.flink.util.Preconditions; import java.io.IOException; @@ -33,7 +33,7 @@ /** Model data of {@link LogisticRegressionModelServable}. 
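The encode/decode pair in the hunks below writes the coefficient through DenseIntDoubleVectorSerializer, followed by startIndex, endIndex and modelVersion as longs. A round-trip sketch under that layout (decode is package-private, so such a check would live in the same package; illustrative only):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

import org.apache.flink.ml.linalg.Vectors;

class ModelDataRoundTrip {
    static LogisticRegressionModelData roundTrip() throws IOException {
        // The two-argument constructor fills startIndex = 0 and endIndex = coefficient.size().
        LogisticRegressionModelData data =
                new LogisticRegressionModelData(Vectors.dense(0.1, -0.2), 7L);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        data.encode(out);
        return LogisticRegressionModelData.decode(new ByteArrayInputStream(out.toByteArray()));
    }
}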
*/ public class LogisticRegressionModelData { - public DenseVector coefficient; + public DenseIntDoubleVector coefficient; public long startIndex; @@ -43,12 +43,12 @@ public class LogisticRegressionModelData { public LogisticRegressionModelData() {} - public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) { + public LogisticRegressionModelData(DenseIntDoubleVector coefficient, long modelVersion) { this(coefficient, 0L, coefficient.size(), modelVersion); } public LogisticRegressionModelData( - DenseVector coefficient, long startIndex, long endIndex, long modelVersion) { + DenseIntDoubleVector coefficient, long startIndex, long endIndex, long modelVersion) { this.coefficient = coefficient; this.startIndex = startIndex; this.endIndex = endIndex; @@ -65,7 +65,7 @@ public void encode(OutputStream outputStream) throws IOException { DataOutputViewStreamWrapper dataOutputViewStreamWrapper = new DataOutputViewStreamWrapper(outputStream); - DenseVectorSerializer serializer = new DenseVectorSerializer(); + DenseIntDoubleVectorSerializer serializer = new DenseIntDoubleVectorSerializer(); serializer.serialize(coefficient, dataOutputViewStreamWrapper); dataOutputViewStreamWrapper.writeLong(startIndex); dataOutputViewStreamWrapper.writeLong(endIndex); @@ -82,8 +82,8 @@ static LogisticRegressionModelData decode(InputStream inputStream) throws IOExce DataInputViewStreamWrapper dataInputViewStreamWrapper = new DataInputViewStreamWrapper(inputStream); - DenseVectorSerializer serializer = new DenseVectorSerializer(); - DenseVector coefficient = serializer.deserialize(dataInputViewStreamWrapper); + DenseIntDoubleVectorSerializer serializer = new DenseIntDoubleVectorSerializer(); + DenseIntDoubleVector coefficient = serializer.deserialize(dataInputViewStreamWrapper); long startIndex = dataInputViewStreamWrapper.readLong(); long endIndex = dataInputViewStreamWrapper.readLong(); long modelVersion = dataInputViewStreamWrapper.readLong(); @@ -103,7 +103,7 @@ public static LogisticRegressionModelData mergeSegments( dim < Integer.MAX_VALUE, "The dimension of logistic regression model is larger than INT.MAX. 
Please consider using distributed inference."); int intDim = (int) dim; - DenseVector mergedCoefficient = new DenseVector(intDim); + DenseIntDoubleVector mergedCoefficient = new DenseIntDoubleVector(intDim); for (LogisticRegressionModelData segment : segments) { int startIndex = (int) segment.startIndex; int endIndex = (int) segment.endIndex; diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java index 6662b6ccf..468392215 100644 --- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java @@ -20,8 +20,8 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.servable.api.DataFrame; @@ -61,12 +61,12 @@ public LogisticRegressionModelServable() { @Override public DataFrame transform(DataFrame input) { List predictionResults = new ArrayList<>(); - List rawPredictionResults = new ArrayList<>(); + List rawPredictionResults = new ArrayList<>(); int featuresColIndex = input.getIndex(getFeaturesCol()); for (Row row : input.collect()) { - Vector features = (Vector) row.get(featuresColIndex); - Tuple2 dataPoint = transform(features); + IntDoubleVector features = (IntDoubleVector) row.get(featuresColIndex); + Tuple2 dataPoint = transform(features); predictionResults.add(dataPoint.f0); rawPredictionResults.add(dataPoint.f1); } @@ -114,7 +114,7 @@ public static LogisticRegressionModelServable load(String path) throws IOExcepti * @param feature The input feature. * @return The prediction label and the raw probabilities. */ - protected Tuple2 transform(Vector feature) { + protected Tuple2 transform(IntDoubleVector feature) { double dotValue = BLAS.dot(feature, modelData.coefficient); double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); return Tuple2.of(dotValue >= 0 ? 1. 
: 0., Vectors.dense(1 - prob, prob)); From 7325477e07a4ecf482c7937709e8a5dac8c3a470 Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Wed, 7 Jun 2023 09:24:43 +0800 Subject: [PATCH 10/18] support allreduce aggregator for double[] --- .../classification/linearsvc/LinearSVC.java | 2 +- .../LogisticRegression.java | 2 +- .../LogisticRegressionWithFtrl.java | 44 ++++++--- .../common/lossfunc/BinaryLogisticLoss.java | 17 ++-- .../flink/ml/common/lossfunc/HingeLoss.java | 18 ++-- .../ml/common/lossfunc/LeastSquareLoss.java | 14 +-- .../apache/flink/ml/common/optimizer/SGD.java | 2 +- .../flink/ml/common/ps/ServerAgent.java | 5 +- .../flink/ml/common/ps/ServerOperator.java | 9 +- .../flink/ml/common/ps/WorkerOperator.java | 3 +- .../ml/common/ps/message/AllReduceM.java | 50 ++++++++++- .../flink/ml/common/ps/message/Message.java | 4 +- .../ml/common/ps/message/MessageData.java | 75 ++++++++++++++++ .../ml/common/ps/message/MessageUtils.java | 13 +++ .../flink/ml/common/ps/message/Meta.java | 83 +++++++++++++++++ .../common/ps/training/ComputeGradients.java | 27 +++--- .../ml/common/ps/training/ComputeIndices.java | 14 +-- .../ml/common/ps/training/TrainingUtils.java | 9 +- .../linearregression/LinearRegression.java | 2 +- .../LogisticRegressionWithFtrlTest.java | 90 +++++++++++++++---- .../feature/LabeledLargePointWithWeight.java | 80 ++++++++--------- .../feature/LabeledPointWithWeight.java | 35 ++------ 22 files changed, 436 insertions(+), 162 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java index 7fce6fe20..d03101ed4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java @@ -85,7 +85,7 @@ public LinearSVCModel fit(Table... inputs) { DataStream initModelData = DataStreamUtils.reduce( - trainData.map(x -> x.getFeatures().size()), + trainData.map(x -> (Integer) x.features.size()), (ReduceFunction) (t0, t1) -> { Preconditions.checkState( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 41e5ec8f9..22ae3a81d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -93,7 +93,7 @@ public LogisticRegressionModel fit(Table... 
inputs) { DataStream initModelData = DataStreamUtils.reduce( - trainData.map(x -> x.getFeatures().size()), + trainData.map(x -> (Integer) x.features.size()), (ReduceFunction) (t0, t1) -> { Preconditions.checkState( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java index 2c30740b3..6db78b766 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -21,11 +21,10 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; -import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; import org.apache.flink.ml.common.ps.training.ComputeGradients; import org.apache.flink.ml.common.ps.training.ComputeIndices; @@ -36,6 +35,9 @@ import org.apache.flink.ml.common.ps.training.SerializableConsumer; import org.apache.flink.ml.common.ps.training.TrainingUtils; import org.apache.flink.ml.common.ps.updater.FTRL; +import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.LongDoubleVector; +import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; @@ -78,10 +80,10 @@ public LogisticRegressionModel fit(Table... inputs) { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); - DataStream trainData = + DataStream trainData = tEnv.toDataStream(inputs[0]) .map( - (MapFunction) + (MapFunction) dataPoint -> { double weight = getWeightCol() == null @@ -100,9 +102,9 @@ public LogisticRegressionModel fit(Table... inputs) { throw new RuntimeException( "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); } - Tuple2 features = + Vector features = dataPoint.getFieldAs(getFeaturesCol()); - return new LabeledLargePointWithWeight( + return new LabeledPointWithWeight( features, label, weight); }); @@ -112,17 +114,32 @@ public LogisticRegressionModel fit(Table... 
inputs) { } else { modelDim = DataStreamUtils.reduce( - trainData.map(x -> x.features.f0[x.features.f0.length - 1]), + trainData.map( + x -> { + Vector feature = x.features; + long dim; + if (feature instanceof IntDoubleVector) { + dim = + ((IntDoubleVector) feature) + .size() + .intValue(); + } else { + dim = + ((LongDoubleVector) feature) + .size() + .longValue(); + } + return dim; + }), (ReduceFunction) Math::max) - .map((MapFunction) value -> value + 1); + .map((MapFunction) value -> value); } - MiniBatchMLSession mlSession = + MiniBatchMLSession mlSession = new MiniBatchMLSession<>( - getGlobalBatchSize(), - TypeInformation.of(LabeledLargePointWithWeight.class)); + getGlobalBatchSize(), TypeInformation.of(LabeledPointWithWeight.class)); - IterationStageList> iterationStages = + IterationStageList> iterationStages = new IterationStageList<>(mlSession); iterationStages .addStage(new ComputeIndices()) @@ -136,8 +153,7 @@ public LogisticRegressionModel fit(Table... inputs) { (SerializableSupplier) () -> mlSession.pushIndices, (SerializableSupplier) () -> mlSession.pushValues)) .setTerminationCriteria( - (SerializableFunction< - MiniBatchMLSession, Boolean>) + (SerializableFunction, Boolean>) o -> o.iterationId >= getMaxIter()); FTRL ftrl = new FTRL( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java index ce7482626..23c8b438d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java @@ -23,6 +23,7 @@ import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** The loss function for binary logistic loss. See {@link LogisticRegression} for example. 
*/ @Internal @@ -33,9 +34,9 @@ private BinaryLogisticLoss() {} @Override public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - return dataPoint.getWeight() * Math.log(1 + Math.exp(-dot * labelScaled)); + double dot = BLAS.dot((IntDoubleVector) dataPoint.features, coefficient); + double labelScaled = 2 * dataPoint.label - 1; + return dataPoint.weight * Math.log(1 + Math.exp(-dot * labelScaled)); } @Override @@ -43,11 +44,11 @@ public void computeGradient( LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient, DenseIntDoubleVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - double multiplier = - dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); - BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient, dataPoint.getFeatures().size()); + IntDoubleVector feature = (IntDoubleVector) dataPoint.features; + double dot = BLAS.dot(feature, coefficient); + double labelScaled = 2 * dataPoint.label - 1; + double multiplier = dataPoint.weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); + BLAS.axpy(multiplier, feature, cumGradient, feature.size()); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java index 06a104aa7..38f7a43a8 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java @@ -23,6 +23,7 @@ import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; /** * The loss function for hinge loss. See {@link LinearSVC} for example. 
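A quick numeric check of the BinaryLogisticLoss hunks above, since the label handling is the part worth double-checking in this rewrite: labels in {0, 1} are rescaled to {-1, +1} before the margin is computed. With label = 1, weight = 1 and a zero score, the loss should be ln 2 and the gradient multiplier -0.5:

class BinaryLogisticLossCheck {
    public static void main(String[] args) {
        double dot = 0.0, label = 1.0, weight = 1.0;
        double labelScaled = 2 * label - 1; // maps {0, 1} to {-1, +1}
        double loss = weight * Math.log(1 + Math.exp(-dot * labelScaled)); // ln 2 ≈ 0.6931
        // cumGradient is then updated by multiplier * features via BLAS.axpy.
        double multiplier = weight * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); // -0.5
        System.out.println(loss + " " + multiplier);
    }
}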
@@ -37,9 +38,9 @@ private HingeLoss() {} @Override public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot); + double dot = BLAS.dot((IntDoubleVector) dataPoint.features, coefficient); + double labelScaled = 2 * dataPoint.label - 1; + return dataPoint.weight * Math.max(0, 1 - labelScaled * dot); } @Override @@ -47,14 +48,11 @@ public void computeGradient( LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient, DenseIntDoubleVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; + IntDoubleVector feature = (IntDoubleVector) dataPoint.features; + double dot = BLAS.dot(feature, coefficient); + double labelScaled = 2 * dataPoint.label - 1; if (1 - labelScaled * dot > 0) { - BLAS.axpy( - -labelScaled * dataPoint.getWeight(), - dataPoint.getFeatures(), - cumGradient, - dataPoint.getFeatures().size()); + BLAS.axpy(-labelScaled * dataPoint.weight, feature, cumGradient, feature.size()); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java index 7b943491f..d76093189 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java @@ -22,6 +22,7 @@ import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.IntDoubleVector; import org.apache.flink.ml.regression.linearregression.LinearRegression; /** The loss function for least square loss. See {@link LinearRegression} for example. 
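The HingeLoss hunks above follow the same label rescaling, but the gradient is non-zero only inside the margin. With labelScaled = +1 and dot = 0.3 the margin term 1 - labelScaled * dot = 0.7 is positive, giving loss 0.7 and a feature scaling of -1; once labelScaled * dot >= 1 both vanish:

class HingeLossCheck {
    public static void main(String[] args) {
        double dot = 0.3, label = 1.0, weight = 1.0;
        double labelScaled = 2 * label - 1; // +1
        double loss = weight * Math.max(0, 1 - labelScaled * dot); // 0.7
        // BLAS.axpy is skipped entirely when the example sits outside the margin.
        double multiplier = (1 - labelScaled * dot > 0) ? -labelScaled * weight : 0.0; // -1.0
        System.out.println(loss + " " + multiplier);
    }
}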
*/ @@ -33,8 +34,8 @@ private LeastSquareLoss() {} @Override public double computeLoss(LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - return dataPoint.getWeight() * 0.5 * Math.pow(dot - dataPoint.getLabel(), 2); + double dot = BLAS.dot((IntDoubleVector) dataPoint.features, coefficient); + return dataPoint.weight * 0.5 * Math.pow(dot - dataPoint.label, 2); } @Override @@ -42,11 +43,12 @@ public void computeGradient( LabeledPointWithWeight dataPoint, DenseIntDoubleVector coefficient, DenseIntDoubleVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + IntDoubleVector feature = (IntDoubleVector) dataPoint.features; + double dot = BLAS.dot(feature, coefficient); BLAS.axpy( - (dot - dataPoint.getLabel()) * dataPoint.getWeight(), - dataPoint.getFeatures(), + (dot - dataPoint.label) * dataPoint.weight, + (IntDoubleVector) dataPoint.features, cumGradient, - dataPoint.getFeatures().size()); + feature.size()); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java index b94b55f9d..471d24cdb 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java @@ -277,7 +277,7 @@ public void onEpochWatermarkIncremented( for (LabeledPointWithWeight dataPoint : miniBatchData) { totalLoss += lossFunc.computeLoss(dataPoint, coefficient); lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper); - totalWeight += dataPoint.getWeight(); + totalWeight += dataPoint.weight; } setTotalLoss(totalLoss); setTotalWeight(totalWeight); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java index 4b072e0a2..f849890d4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -29,6 +29,7 @@ import java.util.Arrays; import java.util.Iterator; +import java.util.function.BiFunction; /** ServerAgent resides on each worker. It serves as an agent for workers to talk with servers. */ public class ServerAgent { @@ -88,7 +89,7 @@ void pull(long[] indices) { *

Note that the values pushed by this function are not going to update the model, but just * perform an all reduce operation. */ - void allReducePush(double[] values) { + void allReducePush(double[] values, BiFunction aggregator) { final int MIN_MESSAGE_SIZE = 1024; int numServers = partitioner.numServers; int messageSize = Math.max(MIN_MESSAGE_SIZE, values.length / numServers + 1); @@ -101,7 +102,7 @@ void allReducePush(double[] values) { } else { segment = Arrays.copyOfRange(values, s, e); } - AllReduceM allReduceM = new AllReduceM(serverId, workerId, segment); + AllReduceM allReduceM = new AllReduceM(serverId, workerId, segment, aggregator); output.collect(new StreamRecord<>(Tuple2.of(serverId, allReduceM.toBytes()))); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index 31fd23441..a9250baa4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -32,6 +32,7 @@ import org.apache.flink.ml.common.ps.message.PullIndexM; import org.apache.flink.ml.common.ps.message.PulledValueM; import org.apache.flink.ml.common.ps.message.PushKvM; +import org.apache.flink.ml.common.ps.training.IterationStageList; import org.apache.flink.ml.common.ps.updater.ModelUpdater; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -81,6 +82,8 @@ public class ServerOperator extends AbstractStreamOperator> implements OneInputStreamOperator, Tuple2>, IterationListener> { + /** Iteration stage list. */ + private final IterationStageList iterationStageList; /** Number of workers to communicate with. */ private final int numWorkers; /** The logic to answer push/pull request from workers. */ @@ -104,9 +107,11 @@ public class ServerOperator extends AbstractStreamOperator pendingPulls; public ServerOperator( + IterationStageList iterationStageList, int numWorkers, ModelUpdater modelUpdater, OutputTag> modelOutputTag) { + this.iterationStageList = iterationStageList; this.numWorkers = numWorkers; this.modelUpdater = modelUpdater; this.modelOutputTag = modelOutputTag; @@ -383,9 +388,7 @@ private void processAllReduceRequest(byte[] request) { reducedResult = receivedValues; } else { Preconditions.checkArgument(reducedResult.length == receivedValues.length); - for (int i = 0; i < reducedResult.length; i++) { - reducedResult[i] += receivedValues[i]; - } + reducedResult = allReduceM.aggregator.apply(receivedValues, reducedResult); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java index 7fc0011bb..da7584989 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -269,7 +269,8 @@ private int processTrainingStage( // We are not incrementing nextStageToExecute here, since we will need to pull // values from servers. 
AllReduceStage allReduceStage = (AllReduceStage) stage; - serverAgent.allReducePush(allReduceStage.valuesSupplier.get()); + serverAgent.allReducePush( + allReduceStage.valuesSupplier.get(), allReduceStage.valuesAggregator); return nextStageToExecute; } else if (stage instanceof PushStage) { PushStage pushStage = (PushStage) stage; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java index d53af554e..e5f249800 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java @@ -21,6 +21,12 @@ import org.apache.flink.ml.util.Bits; import org.apache.flink.util.Preconditions; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.function.BiFunction; + import static org.apache.flink.ml.common.ps.message.MessageType.ALL_REDUCE_VALUE; /** The message to apply all-reduce among workers. */ @@ -28,11 +34,17 @@ public class AllReduceM implements Message { public final int serverId; public final int workerId; public final double[] values; + public final BiFunction<double[], double[], double[]> aggregator; - public AllReduceM(int serverId, int workerId, double[] values) { + public AllReduceM( + int serverId, + int workerId, + double[] values, + BiFunction<double[], double[], double[]> aggregator) { this.serverId = serverId; this.workerId = workerId; this.values = values; + this.aggregator = aggregator; } public static AllReduceM fromBytes(byte[] bytes) { @@ -46,16 +58,21 @@ public static AllReduceM fromBytes(byte[] bytes) { int workerId = Bits.getInt(bytes, offset); offset += Integer.BYTES; double[] values = MessageUtils.getDoubleArray(bytes, offset); - return new AllReduceM(psId, workerId, values); + offset += MessageUtils.getDoubleArraySizeInBytes(values); + + BiFunction<double[], double[], double[]> aggregator = deserializeFunction(bytes, offset); + return new AllReduceM(psId, workerId, values, aggregator); } @Override public byte[] toBytes() { + byte[] serializedFunctionInBytes = serializeFunction(aggregator); int numBytes = Character.BYTES + Integer.BYTES + Integer.BYTES - + MessageUtils.getDoubleArraySizeInBytes(values); + + MessageUtils.getDoubleArraySizeInBytes(values) + + serializedFunctionInBytes.length; byte[] buffer = new byte[numBytes]; int offset = 0; Bits.putChar(buffer, offset, ALL_REDUCE_VALUE.type); @@ -66,7 +83,34 @@ public byte[] toBytes() { Bits.putInt(buffer, offset, this.workerId); offset += Integer.BYTES; MessageUtils.putDoubleArray(values, buffer, offset); + offset += MessageUtils.getDoubleArraySizeInBytes(values); + System.arraycopy( + serializedFunctionInBytes, 0, buffer, offset, serializedFunctionInBytes.length); return buffer; } + + private static byte[] serializeFunction(BiFunction<double[], double[], double[]> aggregator) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { + oos.writeObject(aggregator); + oos.flush(); + } catch (IOException e) { + throw new RuntimeException( + "Failed to serialize the aggregator. It must be java-serializable.", e); + } + return baos.toByteArray(); + } + + @SuppressWarnings("unchecked") + private static BiFunction<double[], double[], double[]> deserializeFunction( + byte[] bytes, int offset) { + ByteArrayInputStream bais = new ByteArrayInputStream(bytes, offset, bytes.length - offset); + try (ObjectInputStream ois = new ObjectInputStream(bais)) { + return (BiFunction<double[], double[], double[]>) ois.readObject(); + } catch (IOException | ClassNotFoundException e) { + throw new RuntimeException("Failed to deserialize the aggregator.", e); + } + } }
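For context, a sketch of the kind of aggregator a caller can now hand to allReducePush. Since AllReduceM java-serializes the function, the lambda must also implement Serializable; the element-wise sum below is illustrative, not mandated by this patch:

import java.io.Serializable;
import java.util.function.BiFunction;

public class AggregatorSketch {
    public static void main(String[] args) {
        // The server calls aggregator.apply(receivedValues, reducedResult), so the
        // function must fold an incoming segment into the running result.
        BiFunction<double[], double[], double[]> sum =
                (BiFunction<double[], double[], double[]> & Serializable)
                        (incoming, reduced) -> {
                            for (int i = 0; i < incoming.length; i++) {
                                reduced[i] += incoming[i];
                            }
                            return reduced;
                        };
        double[] reduced = sum.apply(new double[] {1, 2}, new double[] {3, 4});
        System.out.println(reduced[0] + ", " + reduced[1]); // prints 4.0, 6.0
    }
}

diff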
--git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java index 39bafbc13..5d684889c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java @@ -18,6 +18,8 @@ package org.apache.flink.ml.common.ps.message; +import java.io.IOException; + /** * The message to be passed between worker node and server node. * @@ -31,5 +33,5 @@ public interface Message { * *

Note that the first two bytes of the result buffer is reserved for {@link MessageType}. */ - byte[] toBytes(); + byte[] toBytes() throws IOException; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java new file mode 100644 index 000000000..54b81919f --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java @@ -0,0 +1,75 @@ +// package org.apache.flink.ml.common.ps.message; +// +// import org.apache.flink.api.common.typeutils.TypeSerializer; +// import org.apache.flink.api.java.tuple.Tuple2; +// import org.apache.flink.core.memory.DataInputViewStreamWrapper; +// import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +// import org.apache.flink.ml.util.Bits; +// +// import java.io.ByteArrayInputStream; +// import java.io.ByteArrayOutputStream; +// import java.io.IOException; +// import java.nio.ByteBuffer; +// import java.util.ArrayList; +// import java.util.List; +// +/// ** +// * Message body. +// */ +// public class MessageData { +// byte[] bytes; +// ByteBuffer byteBuffer; +// int offset = 0; +// +// public MessageData(Meta meta, int messageSize) { +// byteBuffer = ByteBuffer.allocate(messageSize); +// } +// +// /** +// * Adds data for generics. +// */ +// public void addData(long[] keys, V[] values, TypeSerializer serializer) throws IOException +// { +// ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); +// DataOutputViewStreamWrapper dataOutputViewStreamWrapper = new +// DataOutputViewStreamWrapper(byteArrayOutputStream); +// for (int i = 0; i < values.length; i ++) { +// serializer.serialize(values[i], dataOutputViewStreamWrapper); +// } +// byte[] serializedValues = byteArrayOutputStream.toByteArray(); +// +// } +// +// public void addData(long[] keys, double[] values) { +// offset = MessageUtils.putLongDoubleArray(Tuple2.of(keys, values), bytes, offset); +// } +// +// /** Gets a long-double array from the byte array starting from the given offset. 
*/ +// public static V[] getGenericArray(byte[] bytes, int offset, TypeSerializer +// typeSerializer) +// throws IOException { +// int n = Bits.getInt(bytes, offset); +// Object[] result = new Object[n]; +// ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes, offset, +// bytes.length - offset); +// DataInputViewStreamWrapper dataInputViewStreamWrapper = new +// DataInputViewStreamWrapper(byteArrayInputStream); +// for (int i = 0; i < n; i ++) { +// result[i] = typeSerializer.deserialize(dataInputViewStreamWrapper); +// } +// +// return (V[]) result; +// } +// +// public static byte[] getSerializedBytes(V[] values, TypeSerializer typeSerializer) throws +// IOException { +// ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); +// DataOutputViewStreamWrapper dataOutputViewStreamWrapper = new +// DataOutputViewStreamWrapper(byteArrayOutputStream); +// for (int i = 0; i < values.length; i ++) { +// typeSerializer.serialize(values[i], dataOutputViewStreamWrapper); +// } +// byte[] serializedValues = byteArrayOutputStream.toByteArray(); +// return serializedValues; +// } +// } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java index d2a628870..e79f07b39 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java @@ -18,12 +18,25 @@ package org.apache.flink.ml.common.ps.message; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.util.Bits; /** Utility functions for processing messages. */ public class MessageUtils { + public static TypeInformation getKeyType(V key) { + if (key instanceof Integer) { + return Types.INT; + } else if (key instanceof Long) { + return Types.LONG; + } else { + throw new UnsupportedOperationException( + String.format("Unsupported key type: %s.", key.getClass().getSimpleName())); + } + } + /** Retrieves the message type from the byte array. */ public static MessageType getMessageType(byte[] bytes) { char type = Bits.getChar(bytes, 0); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java new file mode 100644 index 000000000..864ced3f0 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java @@ -0,0 +1,83 @@ +// package org.apache.flink.ml.common.ps.message; +// +// import org.apache.flink.ml.util.Bits; +// +/// ** +// * Meta information of a message. +// */ +// public class Meta { +// /** +// * Index of the sender of this message. +// */ +// public int sender; +// /** +// * Index of the receiver of this message. +// */ +// public int receiver; +// /** +// * Whether this is a push message. +// */ +// public boolean push; +// /** +// * Whether this is a pull message. +// */ +// public boolean pull; +// /** +// * The size of data in bytes. +// */ +// public int dataSize; +// +// public Meta(int sender, int receiver, boolean push, boolean pull, int dataSize) { +// this.sender = sender; +// this.receiver = receiver; +// this.push = push; +// this.pull = pull; +// this.dataSize = dataSize; +// } +// +// /** +// * Empty constructor to make it as a pojo. 
+// */ +// public Meta() {} +// +// /** +// * Restores meta instance from a given byte array starting from the given offset. +// */ +// public Meta fromBytes(byte[] bytes, int offset) { +// Meta meta = new Meta(); +// meta.sender = Bits.getInt(bytes, offset); +// offset += Integer.BYTES; +// meta.receiver = Bits.getInt(bytes, offset); +// offset += Integer.BYTES; +// meta.push = Bits.getChar(bytes, offset) == (char) 1; +// offset += Character.BYTES; +// meta.pull = Bits.getChar(bytes, offset) == (char) 1; +// offset += Character.BYTES; +// meta.dataSize = Bits.getInt(bytes, offset); +// return meta; +// } +// +// /** +// * Writes a meta instance to a given byte array starting from the given offset. +// */ +// public int toBytes(byte[] bytes, int offset) { +// Bits.putInt(bytes, offset, sender); +// offset += Integer.BYTES; +// Bits.putInt(bytes, offset, receiver); +// offset += Integer.BYTES; +// +// Bits.putChar(bytes, offset, push? (char) 1: (char) 0); +// offset += Character.BYTES; +// Bits.putChar(bytes, offset, pull? (char) 1: (char) 0); +// offset += Character.BYTES; +// +// Bits.putInt(bytes, offset, dataSize); +// offset += Integer.BYTES; +// return offset; +// } +// +// public static int getSizeInBytes() { +// return Integer.BYTES + Integer.BYTES + Character.BYTES + Character.BYTES + Character.BYTES + +// Integer.BYTES; +// } +// } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java index ebeca86d7..bf20268c5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java @@ -19,9 +19,10 @@ package org.apache.flink.ml.common.ps.training; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.common.lossfunc.LossFunc; import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; import org.apache.flink.ml.linalg.Vectors; import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; @@ -30,8 +31,7 @@ import java.util.List; /** An iteration stage that uses the pulled model values and batch data to compute the gradients. 
*/ -public class ComputeGradients - extends ProcessStage> { +public class ComputeGradients extends ProcessStage> { private final LossFunc lossFunc; public ComputeGradients(LossFunc lossFunc) { @@ -39,8 +39,7 @@ public ComputeGradients(LossFunc lossFunc) { } @Override - public void process(MiniBatchMLSession session) - throws IOException { + public void process(MiniBatchMLSession session) throws IOException { long[] indices = ComputeIndices.getSortedIndices(session.batchData); double[] modelValues = session.pulledValues; double[] gradients = computeGradient(session.batchData, Tuple2.of(indices, modelValues)); @@ -50,7 +49,7 @@ public void process(MiniBatchMLSession session) } private double[] computeGradient( - List batchData, Tuple2 modelData) { + List batchData, Tuple2 modelData) { long[] modelIndices = modelData.f0; double[] modelValues = modelData.f1; Long2DoubleOpenHashMap modelInMap = new Long2DoubleOpenHashMap(modelIndices.length); @@ -59,12 +58,13 @@ private double[] computeGradient( } Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(modelIndices.length); - for (LabeledLargePointWithWeight dataPoint : batchData) { - double dot = dot(dataPoint.features, modelInMap); + for (LabeledPointWithWeight dataPoint : batchData) { + SparseLongDoubleVector feature = (SparseLongDoubleVector) dataPoint.features; + double dot = dot(feature, modelInMap); double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight; - long[] featureIndices = dataPoint.features.f0; - double[] featureValues = dataPoint.features.f1; + long[] featureIndices = feature.indices; + double[] featureValues = feature.values; double z; for (int i = 0; i < featureIndices.length; i++) { long currentIndex = featureIndices[i]; @@ -80,11 +80,10 @@ private double[] computeGradient( return cumGradientValues; } - private static double dot( - Tuple2 features, Long2DoubleOpenHashMap coefficient) { + private static double dot(SparseLongDoubleVector feature, Long2DoubleOpenHashMap coefficient) { double dot = 0; - for (int i = 0; i < features.f0.length; i++) { - dot += features.f1[i] * coefficient.get(features.f0[i]); + for (int i = 0; i < feature.indices.length; i++) { + dot += feature.values[i] * coefficient.get(feature.indices[i]); } return dot; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java index 5cf868f21..d624887a6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeIndices.java @@ -18,7 +18,8 @@ package org.apache.flink.ml.common.ps.training; -import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; @@ -30,18 +31,19 @@ * An iteration stage that samples a batch of training data and computes the indices needed to * compute gradients. 
*/ -public class ComputeIndices extends ProcessStage> { +public class ComputeIndices extends ProcessStage> { @Override - public void process(MiniBatchMLSession context) throws Exception { + public void process(MiniBatchMLSession context) throws Exception { context.readInNextBatchData(); context.pullIndices = getSortedIndices(context.batchData); } - public static long[] getSortedIndices(List dataPoints) { + public static long[] getSortedIndices(List dataPoints) { LongOpenHashSet indices = new LongOpenHashSet(); - for (LabeledLargePointWithWeight dataPoint : dataPoints) { - long[] notZeros = dataPoint.features.f0; + for (LabeledPointWithWeight dataPoint : dataPoints) { + SparseLongDoubleVector feature = (SparseLongDoubleVector) dataPoint.features; + long[] notZeros = feature.indices; for (long index : notZeros) { indices.add(index); } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java index 3dff96868..3ed22c64f 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -32,7 +32,7 @@ import org.apache.flink.iteration.IterationConfig; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; -import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; import org.apache.flink.ml.common.ps.MirrorWorkerOperator; import org.apache.flink.ml.common.ps.ServerOperator; import org.apache.flink.ml.common.ps.WorkerOperator; @@ -105,7 +105,7 @@ public TrainIterationBody( public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { DataStream variableStream = variableStreams.get(0); - DataStream trainData = dataStreams.get(0); + DataStream trainData = dataStreams.get(0); final OutputTag> modelDataOutputTag = new OutputTag>("MODEL_OUTPUT") {}; @@ -134,7 +134,10 @@ public IterationBodyResult process( Types.INT, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO), new ServerOperator( - numWorkers, modelUpdater, modelDataOutputTag)); + iterationStages, + numWorkers, + modelUpdater, + modelDataOutputTag)); messageToWorker.setParallelism(numServers); DataStream combinedMessageToWorker = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java index 50422088e..de48b817d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java @@ -81,7 +81,7 @@ public LinearRegressionModel fit(Table... 
inputs) { DataStream initModelData = DataStreamUtils.reduce( - trainData.map(x -> x.getFeatures().size()), + trainData.map(x -> (Integer) x.features.size()), (ReduceFunction) (t0, t1) -> { Preconditions.checkState( diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java index 2489d594b..422257c38 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -18,12 +18,13 @@ package org.apache.flink.ml.classification; -import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; -import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; @@ -31,11 +32,14 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegressionWithFtrl; import org.apache.flink.ml.linalg.DenseIntDoubleVector; import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.SparseLongDoubleVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.SparseIntDoubleVectorTypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseLongDoubleVectorTypeInfo; import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.types.BasicType; import org.apache.flink.ml.servable.types.DataTypes; +import org.apache.flink.ml.util.Bits; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -50,6 +54,8 @@ import org.junit.rules.TemporaryFolder; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.SequenceInputStream; import java.util.ArrayList; import java.util.Arrays; @@ -75,16 +81,46 @@ public class LogisticRegressionWithFtrlTest { private static final List trainRows = Arrays.asList( - Row.of(Tuple2.of(new long[] {0, 1}, new double[] {1, 2}), 0., 1.), - Row.of(Tuple2.of(new long[] {0, 2}, new double[] {2, 3}), 0., 2.), - Row.of(Tuple2.of(new long[] {0, 3}, new double[] {3, 4}), 0., 3.), - Row.of(Tuple2.of(new long[] {0, 2}, new double[] {4, 4}), 0., 4.), - Row.of(Tuple2.of(new long[] {0, 1}, new double[] {5, 4}), 0., 5.), - Row.of(Tuple2.of(new long[] {0, 2}, new double[] {11, 3}), 1., 1.), - Row.of(Tuple2.of(new long[] {0, 3}, new double[] {12, 4}), 1., 2.), - Row.of(Tuple2.of(new long[] {0, 1}, new double[] {13, 2}), 1., 3.), - Row.of(Tuple2.of(new long[] {0, 3}, new double[] {14, 4}), 1., 4.), - Row.of(Tuple2.of(new long[] {0, 2}, new double[] {15, 4}), 1., 5.)); + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 1}, new double[] 
{1, 2}), + 0., + 1.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {2, 3}), + 0., + 2.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 3}, new double[] {3, 4}), + 0., + 3.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {4, 4}), + 0., + 4.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 1}, new double[] {5, 4}), + 0., + 5.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {11, 3}), + 1., + 1.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 3}, new double[] {12, 4}), + 1., + 2.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 1}, new double[] {13, 2}), + 1., + 3.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 3}, new double[] {14, 4}), + 1., + 4.), + Row.of( + new SparseLongDoubleVector(4, new long[] {0, 2}, new double[] {15, 4}), + 1., + 5.)); private static final List testRows = Arrays.asList( @@ -116,11 +152,7 @@ public void before() { trainRows, new RowTypeInfo( new TypeInformation[] { - new TupleTypeInfo<>( - PrimitiveArrayTypeInfo - .LONG_PRIMITIVE_ARRAY_TYPE_INFO, - PrimitiveArrayTypeInfo - .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + SparseLongDoubleVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, @@ -395,4 +427,26 @@ private void verifyPredictionResult( } } } + + @Test + public void testGetGenericType() throws IOException { + TypeInformation<Integer> t = getType(128); + ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(4); + DataOutputView d = new DataOutputViewStreamWrapper(byteArrayOutputStream); + t.createSerializer(null).serialize(128, d); + byte[] serialized = byteArrayOutputStream.toByteArray(); + assertEquals(128, Bits.getInt(serialized, 0)); + + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(serialized); + DataInputView inputView = new DataInputViewStreamWrapper(byteArrayInputStream); + int deserializedInt = (Integer) t.createSerializer(null).deserialize(inputView); + assertEquals(128, deserializedInt); + } + + @SuppressWarnings("unchecked") + <V> TypeInformation<V> getType(V v) { + if (v instanceof Integer) { + return (TypeInformation<V>) Types.INT; + } + return null; + } } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java index 43ad621e5..871dbf36c 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java @@ -1,40 +1,40 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ - -package org.apache.flink.ml.common.feature; - -import org.apache.flink.api.java.tuple.Tuple2; - -/** A data point to represent values that use long as index and double as values. */ -public class LabeledLargePointWithWeight { - public Tuple2 features; - - public double label; - - public double weight; - - public LabeledLargePointWithWeight( - Tuple2 features, double label, double weight) { - this.features = features; - this.label = label; - this.weight = weight; - } - - /** Makes it pojo to use flink serializer. */ - public LabeledLargePointWithWeight() {} -} +/// * +// * Licensed to the Apache Software Foundation (ASF) under one +// * or more contributor license agreements. See the NOTICE file +// * distributed with this work for additional information +// * regarding copyright ownership. The ASF licenses this file +// * to you under the Apache License, Version 2.0 (the +// * "License"); you may not use this file except in compliance +// * with the License. You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ +// +// package org.apache.flink.ml.common.feature; +// +// import org.apache.flink.api.java.tuple.Tuple2; +// +/// ** A data point to represent values that use long as index and double as values. */ +// public class LabeledLargePointWithWeight { +// public Tuple2 features; +// +// public double label; +// +// public double weight; +// +// public LabeledLargePointWithWeight( +// Tuple2 features, double label, double weight) { +// this.features = features; +// this.label = label; +// this.weight = weight; +// } +// +// /** Makes it pojo to use flink serializer. */ +// public LabeledLargePointWithWeight() {} +// } diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java index b7a205a60..6a6344ece 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java @@ -18,46 +18,23 @@ package org.apache.flink.ml.common.feature; -import org.apache.flink.ml.linalg.IntDoubleVector; +import org.apache.flink.ml.linalg.Vector; /** Utility class to represent a data point that contains features, label and weight. */ public class LabeledPointWithWeight { - private IntDoubleVector features; + public Vector features; - private double label; + public double label; - private double weight; + public double weight; - public LabeledPointWithWeight(IntDoubleVector features, double label, double weight) { + public LabeledPointWithWeight(Vector features, double label, double weight) { this.features = features; this.label = label; this.weight = weight; } + /** Makes it as pojo. 
*/ public LabeledPointWithWeight() {} - - public IntDoubleVector getFeatures() { - return features; - } - - public void setFeatures(IntDoubleVector features) { - this.features = features; - } - - public double getLabel() { - return label; - } - - public void setLabel(double label) { - this.label = label; - } - - public double getWeight() { - return weight; - } - - public void setWeight(double weight) { - this.weight = weight; - } } From e938e1839cde99eda4960047f49db3d1acd71448 Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Wed, 7 Jun 2023 09:56:42 +0800 Subject: [PATCH 11/18] support allreduce aggregator for double[] --- .../flink/ml/common/ps/ServerAgent.java | 10 +-- .../flink/ml/common/ps/ServerOperator.java | 5 +- ...ModelAsZeroM.java => InitializeModel.java} | 8 +- .../ml/common/ps/message/MessageData.java | 75 ----------------- .../flink/ml/common/ps/message/Meta.java | 83 ------------------- 5 files changed, 10 insertions(+), 171 deletions(-) rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/{InitializeModelAsZeroM.java => InitializeModel.java} (89%) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java index f849890d4..b4ebaafd5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -21,7 +21,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.ml.common.ps.message.AllReduceM; -import org.apache.flink.ml.common.ps.message.InitializeModelAsZeroM; +import org.apache.flink.ml.common.ps.message.InitializeModel; import org.apache.flink.ml.common.ps.message.PullIndexM; import org.apache.flink.ml.common.ps.message.PushKvM; import org.apache.flink.streaming.api.operators.Output; @@ -49,15 +49,13 @@ void setPartitioner(RangePartitioner partitioner) { this.partitioner = partitioner; } - /** Sends a request to servers to initialize the values stored as zeros. */ + /** Sends a request to servers to initialize key range on each server. 
*/ void initializeModelAsZeros() { for (int serverId = 0; serverId < partitioner.numServers; serverId++) { long start = partitioner.ranges[serverId]; long end = partitioner.ranges[serverId + 1]; - InitializeModelAsZeroM initializeModelAsZeroM = - new InitializeModelAsZeroM(workerId, serverId, start, end); - output.collect( - new StreamRecord<>(Tuple2.of(serverId, initializeModelAsZeroM.toBytes()))); + InitializeModel initializeModel = new InitializeModel(workerId, serverId, start, end); + output.collect(new StreamRecord<>(Tuple2.of(serverId, initializeModel.toBytes()))); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index a9250baa4..a0900e39a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -26,7 +26,7 @@ import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.ps.message.AllReduceM; -import org.apache.flink.ml.common.ps.message.InitializeModelAsZeroM; +import org.apache.flink.ml.common.ps.message.InitializeModel; import org.apache.flink.ml.common.ps.message.MessageType; import org.apache.flink.ml.common.ps.message.MessageUtils; import org.apache.flink.ml.common.ps.message.PullIndexM; @@ -135,8 +135,7 @@ public void processElement(StreamRecord> element) throws pendingPulls.add(request); break; case INITIALIZE_MODEL_AS_ZERO: - InitializeModelAsZeroM initializeModelAsZeroM = - InitializeModelAsZeroM.fromBytes(request); + InitializeModel initializeModelAsZeroM = InitializeModel.fromBytes(request); Preconditions.checkState(serverId == initializeModelAsZeroM.serverId); long start = initializeModelAsZeroM.startIndex; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java similarity index 89% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java index 3c2bee6d7..73ff80264 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModelAsZeroM.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java @@ -24,20 +24,20 @@ import static org.apache.flink.ml.common.ps.message.MessageType.INITIALIZE_MODEL_AS_ZERO; /** Message sent by worker to server that initializes the model as zeros with defined range. 
*/ -public class InitializeModelAsZeroM implements Message { +public class InitializeModel implements Message { public final int workerId; public final int serverId; public final long startIndex; public final long endIndex; - public InitializeModelAsZeroM(int workerId, int serverId, long startIndex, long endIndex) { + public InitializeModel(int workerId, int serverId, long startIndex, long endIndex) { this.workerId = workerId; this.serverId = serverId; this.startIndex = startIndex; this.endIndex = endIndex; } - public static InitializeModelAsZeroM fromBytes(byte[] bytes) { + public static InitializeModel fromBytes(byte[] bytes) { int offset = 0; char type = Bits.getChar(bytes, offset); offset += Character.BYTES; @@ -50,7 +50,7 @@ public static InitializeModelAsZeroM fromBytes(byte[] bytes) { long startIndex = Bits.getLong(bytes, offset); offset += Long.BYTES; long endIndex = Bits.getLong(bytes, offset); - return new InitializeModelAsZeroM(workerId, serverId, startIndex, endIndex); + return new InitializeModel(workerId, serverId, startIndex, endIndex); } @Override diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java deleted file mode 100644 index 54b81919f..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageData.java +++ /dev/null @@ -1,75 +0,0 @@ -// package org.apache.flink.ml.common.ps.message; -// -// import org.apache.flink.api.common.typeutils.TypeSerializer; -// import org.apache.flink.api.java.tuple.Tuple2; -// import org.apache.flink.core.memory.DataInputViewStreamWrapper; -// import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -// import org.apache.flink.ml.util.Bits; -// -// import java.io.ByteArrayInputStream; -// import java.io.ByteArrayOutputStream; -// import java.io.IOException; -// import java.nio.ByteBuffer; -// import java.util.ArrayList; -// import java.util.List; -// -/// ** -// * Message body. -// */ -// public class MessageData { -// byte[] bytes; -// ByteBuffer byteBuffer; -// int offset = 0; -// -// public MessageData(Meta meta, int messageSize) { -// byteBuffer = ByteBuffer.allocate(messageSize); -// } -// -// /** -// * Adds data for generics. -// */ -// public void addData(long[] keys, V[] values, TypeSerializer serializer) throws IOException -// { -// ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); -// DataOutputViewStreamWrapper dataOutputViewStreamWrapper = new -// DataOutputViewStreamWrapper(byteArrayOutputStream); -// for (int i = 0; i < values.length; i ++) { -// serializer.serialize(values[i], dataOutputViewStreamWrapper); -// } -// byte[] serializedValues = byteArrayOutputStream.toByteArray(); -// -// } -// -// public void addData(long[] keys, double[] values) { -// offset = MessageUtils.putLongDoubleArray(Tuple2.of(keys, values), bytes, offset); -// } -// -// /** Gets a long-double array from the byte array starting from the given offset. 
*/ -// public static V[] getGenericArray(byte[] bytes, int offset, TypeSerializer -// typeSerializer) -// throws IOException { -// int n = Bits.getInt(bytes, offset); -// Object[] result = new Object[n]; -// ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes, offset, -// bytes.length - offset); -// DataInputViewStreamWrapper dataInputViewStreamWrapper = new -// DataInputViewStreamWrapper(byteArrayInputStream); -// for (int i = 0; i < n; i ++) { -// result[i] = typeSerializer.deserialize(dataInputViewStreamWrapper); -// } -// -// return (V[]) result; -// } -// -// public static byte[] getSerializedBytes(V[] values, TypeSerializer typeSerializer) throws -// IOException { -// ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); -// DataOutputViewStreamWrapper dataOutputViewStreamWrapper = new -// DataOutputViewStreamWrapper(byteArrayOutputStream); -// for (int i = 0; i < values.length; i ++) { -// typeSerializer.serialize(values[i], dataOutputViewStreamWrapper); -// } -// byte[] serializedValues = byteArrayOutputStream.toByteArray(); -// return serializedValues; -// } -// } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java deleted file mode 100644 index 864ced3f0..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Meta.java +++ /dev/null @@ -1,83 +0,0 @@ -// package org.apache.flink.ml.common.ps.message; -// -// import org.apache.flink.ml.util.Bits; -// -/// ** -// * Meta information of a message. -// */ -// public class Meta { -// /** -// * Index of the sender of this message. -// */ -// public int sender; -// /** -// * Index of the receiver of this message. -// */ -// public int receiver; -// /** -// * Whether this is a push message. -// */ -// public boolean push; -// /** -// * Whether this is a pull message. -// */ -// public boolean pull; -// /** -// * The size of data in bytes. -// */ -// public int dataSize; -// -// public Meta(int sender, int receiver, boolean push, boolean pull, int dataSize) { -// this.sender = sender; -// this.receiver = receiver; -// this.push = push; -// this.pull = pull; -// this.dataSize = dataSize; -// } -// -// /** -// * Empty constructor to make it as a pojo. -// */ -// public Meta() {} -// -// /** -// * Restores meta instance from a given byte array starting from the given offset. -// */ -// public Meta fromBytes(byte[] bytes, int offset) { -// Meta meta = new Meta(); -// meta.sender = Bits.getInt(bytes, offset); -// offset += Integer.BYTES; -// meta.receiver = Bits.getInt(bytes, offset); -// offset += Integer.BYTES; -// meta.push = Bits.getChar(bytes, offset) == (char) 1; -// offset += Character.BYTES; -// meta.pull = Bits.getChar(bytes, offset) == (char) 1; -// offset += Character.BYTES; -// meta.dataSize = Bits.getInt(bytes, offset); -// return meta; -// } -// -// /** -// * Writes a meta instance to a given byte array starting from the given offset. -// */ -// public int toBytes(byte[] bytes, int offset) { -// Bits.putInt(bytes, offset, sender); -// offset += Integer.BYTES; -// Bits.putInt(bytes, offset, receiver); -// offset += Integer.BYTES; -// -// Bits.putChar(bytes, offset, push? (char) 1: (char) 0); -// offset += Character.BYTES; -// Bits.putChar(bytes, offset, pull? 
(char) 1: (char) 0); -// offset += Character.BYTES; -// -// Bits.putInt(bytes, offset, dataSize); -// offset += Integer.BYTES; -// return offset; -// } -// -// public static int getSizeInBytes() { -// return Integer.BYTES + Integer.BYTES + Character.BYTES + Character.BYTES + Character.BYTES + -// Integer.BYTES; -// } -// } From 3f45880206ea79428d3282eb5fed181b8b90196a Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Wed, 7 Jun 2023 10:00:29 +0800 Subject: [PATCH 12/18] FTRL should not be aware of numWorkers --- .../logisticregression/LogisticRegressionWithFtrl.java | 8 +------- .../flink/ml/common/ps/training/ComputeGradients.java | 10 +++++++--- .../org/apache/flink/ml/common/ps/updater/FTRL.java | 7 ++----- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java index 6db78b766..581169e4a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -155,13 +155,7 @@ public LogisticRegressionModel fit(Table... inputs) { .setTerminationCriteria( (SerializableFunction, Boolean>) o -> o.iterationId >= getMaxIter()); - FTRL ftrl = - new FTRL( - getAlpha(), - getBeta(), - getReg(), - getElasticNet(), - trainData.getParallelism()); + FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); DataStream> rawModelData = TrainingUtils.train(modelDim, trainData, ftrl, iterationStages, getNumServers()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java index bf20268c5..45f175dea 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ComputeGradients.java @@ -42,14 +42,18 @@ public ComputeGradients(LossFunc lossFunc) { public void process(MiniBatchMLSession session) throws IOException { long[] indices = ComputeIndices.getSortedIndices(session.batchData); double[] modelValues = session.pulledValues; - double[] gradients = computeGradient(session.batchData, Tuple2.of(indices, modelValues)); + double[] gradients = + computeGradient( + session.batchData, Tuple2.of(indices, modelValues), session.numWorkers); session.pushIndices = indices; session.pushValues = gradients; } private double[] computeGradient( - List batchData, Tuple2 modelData) { + List batchData, + Tuple2 modelData, + int numWorkers) { long[] modelIndices = modelData.f0; double[] modelValues = modelData.f1; Long2DoubleOpenHashMap modelInMap = new Long2DoubleOpenHashMap(modelIndices.length); @@ -76,7 +80,7 @@ private double[] computeGradient( for (int i = 0; i < modelIndices.length; i++) { cumGradientValues[i] = cumGradients.get(modelIndices[i]); } - BLAS.scal(1.0 / batchData.size(), Vectors.dense(cumGradientValues)); + BLAS.scal(1.0 / batchData.size() / numWorkers, Vectors.dense(cumGradientValues)); return cumGradientValues; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java index c3fa92b2d..7d6f17f8b 100644 --- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java @@ -43,8 +43,6 @@ public class FTRL implements ModelUpdater { private final double lambda1; private final double lambda2; - private final int numWorkers; - // ------ Model data of FTRL optimizer. ----- private long startIndex; private long endIndex; @@ -56,12 +54,11 @@ public class FTRL implements ModelUpdater { private ListState boundaryState; private ListState modelDataState; - public FTRL(double alpha, double beta, double lambda1, double lambda2, int numWorkers) { + public FTRL(double alpha, double beta, double lambda1, double lambda2) { this.alpha = alpha; this.beta = beta; this.lambda1 = lambda1; this.lambda2 = lambda2; - this.numWorkers = numWorkers; } @Override @@ -79,7 +76,7 @@ public void open(long startFeatureIndex, long endFeatureIndex) { public void handlePush(long[] keys, double[] values) { for (int i = 0; i < keys.length; i++) { int index = (int) (keys[i] - startIndex); - double gi = values[i] / numWorkers; + double gi = values[i]; updateModelOnOneDim(gi, index, weight); } } From ea4e159879048ce735e624d0749ecf36be3082e1 Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Wed, 7 Jun 2023 20:59:46 +0800 Subject: [PATCH 13/18] Reorganize message infra --- .../flink/ml/common/ps/RangePartitioner.java | 131 ----------- ...or.java => ResponseAssemblerOperator.java} | 82 +++---- .../flink/ml/common/ps/ServerAgent.java | 156 ++++++++++--- .../flink/ml/common/ps/ServerOperator.java | 199 ++++++++--------- .../flink/ml/common/ps/WorkerOperator.java | 72 +++--- .../ml/common/ps/message/AllReduceM.java | 116 ---------- .../ml/common/ps/message/InitializeModel.java | 74 ------ .../flink/ml/common/ps/message/Message.java | 210 +++++++++++++++++- .../ml/common/ps/message/MessageType.java | 50 ++--- .../ml/common/ps/message/MessageUtils.java | 136 ------------ .../ml/common/ps/message/PullIndexM.java | 68 ------ .../ml/common/ps/message/PulledValueM.java | 72 ------ .../flink/ml/common/ps/message/PushKvM.java | 74 ------ .../ml/common/ps/training/AllReduceStage.java | 39 ++-- .../ml/common/ps/training/IterationStage.java | 2 +- .../ml/common/ps/training/PushStage.java | 9 +- .../ps/training/SerializableBiFunction.java | 6 - .../ml/common/ps/training/TrainingUtils.java | 4 +- .../flink/ml/common/ps/updater/FTRL.java | 18 +- .../ml/common/ps/updater/ModelUpdater.java | 16 +- .../java/org/apache/flink/ml/util/Bits.java | 93 ++++++++ 21 files changed, 637 insertions(+), 990 deletions(-) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java rename flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/{MirrorWorkerOperator.java => ResponseAssemblerOperator.java} (53%) delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java delete mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java deleted file mode 100644 index 2bfc255e6..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/RangePartitioner.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.ps; - -import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.util.Preconditions; - -import javax.annotation.Nullable; - -import java.util.Arrays; -import java.util.Iterator; - -/** - * Range partitioner for model data. It partitions the model data for each dimension according to - * the dimension id. The model data for each dimension could be a double or several doubles. Note - * that the model data for all dimensions should share the same size. - */ -public class RangePartitioner { - public final long dim; - public final int numServers; - public final long[] ranges; - - public RangePartitioner(long dim, int numServers) { - Preconditions.checkArgument( - dim > 0, - String.format( - "Illegal dimension when using %s: %d", - RangePartitioner.class.getSimpleName(), dim)); - - this.dim = dim; - this.numServers = numServers; - this.ranges = new long[numServers + 1]; - long shardSize = dim / numServers; - - for (int serverId = 0; serverId < numServers; serverId++) { - ranges[serverId] = shardSize * serverId; - } - ranges[numServers] = dim; - } - - /** - * Splits the push/pull request according to the given sorted indices and the corresponding - * values. - * - * @param indices sorted indices of push/pull request. - * @param values the push values if not null. - * @return the split requests for each server. - */ - public Iterator> splitRequest( - long[] indices, @Nullable double[] values) { - return new RequestsIterator(numServers, indices, values, ranges); - } - - private static class RequestsIterator implements Iterator> { - private final int numServers; - private final long[] indices; - private final double[] values; - /** - * Number of values per key. If the model data is a vector, numValuesPerKey is one. If the - * model data is a matrix, numValuesPerKey is the number of columns. 
- */ - private final int numValuesPerKey; - - private final long[] ranges; - - private int serverId = 0; - - private int s = 0; - - public RequestsIterator( - int numServers, long[] indices, @Nullable double[] values, long[] ranges) { - this.numServers = numServers; - this.indices = indices; - this.values = values; - this.ranges = ranges; - if (indices.length != 0 && values != null) { - numValuesPerKey = values.length / indices.length; - Preconditions.checkArgument( - numValuesPerKey * indices.length == values.length, - String.format( - "The size of values [%d] cannot be divided by size of keys [%d].", - values.length, indices.length)); - } else { - numValuesPerKey = 1; - } - } - - @Override - public boolean hasNext() { - return serverId < numServers; - } - - @Override - public Tuple3 next() { - int e = s; - while (e < indices.length && indices[e] < ranges[serverId + 1]) { - e++; - } - - long[] splitIndices = new long[0]; - double[] splitValues = values == null ? null : new double[0]; - if (s < e) { - splitIndices = Arrays.copyOfRange(indices, s, e); - splitValues = - values == null - ? null - : Arrays.copyOfRange( - values, s * numValuesPerKey, e * numValuesPerKey); - } - s = e; - serverId++; - return Tuple3.of(serverId - 1, splitIndices, splitValues); - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ResponseAssemblerOperator.java similarity index 53% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ResponseAssemblerOperator.java index 32a601d0e..29e7c4b3e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ResponseAssemblerOperator.java @@ -21,8 +21,10 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.common.ps.message.PulledValueM; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.ps.message.Message; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -30,28 +32,24 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Preconditions; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.Iterator; -import java.util.List; - /** - * Merges the message from different servers for one pull request. + * Assembles the responses from different servers for one pull request. * *

Note that for each single-thread worker, there are exactly #numServers segments for each - * pull request in the feedback edge. + * pull request in the responses. */ -public class MirrorWorkerOperator extends AbstractStreamOperator<byte[]> +public class ResponseAssemblerOperator extends AbstractStreamOperator<byte[]> implements OneInputStreamOperator<Tuple2<Integer, byte[]>, byte[]> { private final int numServers; + private int workerId; - /** The received messages from servers for the current pull request. */ - private List<PulledValueM> messageReceived; + private int numResponsesReceived = 0; + private ListState<Integer> numResponsesReceivedState; - private ListState<byte[]> messageReceivedState; + private ListState<byte[]> responsesReceived; - public MirrorWorkerOperator(int numServers) { + public ResponseAssemblerOperator(int numServers) { this.numServers = numServers; } @@ -64,59 +62,43 @@ public void open() throws Exception { @Override public void processElement(StreamRecord<Tuple2<Integer, byte[]>> element) throws Exception { Preconditions.checkState(element.getValue().f0 == workerId); - PulledValueM pulledValueM = PulledValueM.fromBytes(element.getValue().f1); - messageReceived.add(pulledValueM); - trySendingPulls(numServers); - } + responsesReceived.add(element.getValue().f1); + numResponsesReceived++; - private void trySendingPulls(int numSegments) { - if (messageReceived.size() == numSegments) { - Comparator<PulledValueM> comparator = Comparator.comparingInt(o -> o.serverId); - messageReceived.sort(comparator); - int size = 0; - for (PulledValueM pulledValueM : messageReceived) { - size += pulledValueM.values.length; - } - double[] answer = new double[size]; - int offset = 0; - for (PulledValueM pulledValueM : messageReceived) { - double[] values = pulledValueM.values; - System.arraycopy(values, 0, answer, offset, values.length); - offset += values.length; - } - PulledValueM pulledValueM = new PulledValueM(-1, workerId, answer); - output.collect(new StreamRecord<>(pulledValueM.toBytes())); - messageReceived.clear(); + if (numResponsesReceived == numServers) { + Message message = Message.assembleMessages(responsesReceived.get().iterator()); + output.collect(new StreamRecord<>(message.bytes)); + responsesReceived.clear(); + numResponsesReceived = 0; } } @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); - messageReceivedState = + responsesReceived = context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "messageReceivedState", + "responsesReceivedState", PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); - messageReceived = new ArrayList<>(); - - Iterator<byte[]> iterator = messageReceivedState.get().iterator(); - if (iterator.hasNext()) { - while (iterator.hasNext()) { - messageReceived.add(PulledValueM.fromBytes(iterator.next())); - } - } + numResponsesReceivedState = + context.getOperatorStateStore() .getListState( new ListStateDescriptor<>("numResponsesReceivedState", Types.INT)); + numResponsesReceived = + OperatorStateUtils.getUniqueElement( numResponsesReceivedState, "numResponsesReceived") .orElse(0); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { super.snapshotState(context); - messageReceivedState.clear(); - if (messageReceived.size() > 0) { - for (PulledValueM valuesPulled : messageReceived) { - messageReceivedState.add(valuesPulled.toBytes()); - } + numResponsesReceivedState.clear(); + if (numResponsesReceived > 0) { + numResponsesReceivedState.add(numResponsesReceived); + } } }
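The assembly step above replaces the sort-and-copy logic of the deleted trySendingPulls. Message.assembleMessages itself is not shown in this patch; conceptually it performs the same segment concatenation, which this standalone sketch illustrates with toy payloads already ordered by server id:

import java.util.Arrays;
import java.util.List;

public class ConcatSketch {
    public static void main(String[] args) {
        // Toy per-server value segments of one pull response (servers 0..2).
        List<double[]> segments =
                Arrays.asList(new double[] {0.1, 0.2}, new double[] {0.3}, new double[] {0.4});
        int total = segments.stream().mapToInt(s -> s.length).sum();
        double[] assembled = new double[total];
        int offset = 0;
        for (double[] segment : segments) {
            System.arraycopy(segment, 0, assembled, offset, segment.length);
            offset += segment.length;
        }
        System.out.println(Arrays.toString(assembled)); // [0.1, 0.2, 0.3, 0.4]
    }
}

diff --git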
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java index b4ebaafd5..f634478ee 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -18,25 +18,29 @@ package org.apache.flink.ml.common.ps; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; -import org.apache.flink.ml.common.ps.message.AllReduceM; -import org.apache.flink.ml.common.ps.message.InitializeModel; -import org.apache.flink.ml.common.ps.message.PullIndexM; -import org.apache.flink.ml.common.ps.message.PushKvM; +import org.apache.flink.ml.common.ps.message.Message; +import org.apache.flink.ml.common.ps.message.MessageType; import org.apache.flink.streaming.api.operators.Output; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Preconditions; +import javax.annotation.Nullable; + +import java.io.IOException; import java.util.Arrays; import java.util.Iterator; -import java.util.function.BiFunction; /** ServerAgent resides on each worker. It serves as an agent for workers to talk with servers. */ public class ServerAgent { /** Index of the worker that this agent resides on. */ private final int workerId; - /** Partitioner of the model data that this ServerAgent maintains. */ - private RangePartitioner partitioner; + /** Number of servers to talk to. */ + private int numServers; + /** Key ranges of each server. */ + private long[] ranges; /** The collector on this worker. */ private final Output>> output; @@ -45,39 +49,52 @@ public class ServerAgent { this.output = output; } - void setPartitioner(RangePartitioner partitioner) { - this.partitioner = partitioner; + void open(int numServers, long maxKey) { + this.numServers = numServers; + this.ranges = new long[numServers + 1]; + long shardSize = (maxKey + 1) / numServers; + + for (int serverId = 0; serverId < numServers; serverId++) { + ranges[serverId] = shardSize * serverId; + } + ranges[numServers] = maxKey + 1; } /** Sends a request to servers to initialize key range on each server. */ - void initializeModelAsZeros() { - for (int serverId = 0; serverId < partitioner.numServers; serverId++) { - long start = partitioner.ranges[serverId]; - long end = partitioner.ranges[serverId + 1]; - InitializeModel initializeModel = new InitializeModel(workerId, serverId, start, end); - output.collect(new StreamRecord<>(Tuple2.of(serverId, initializeModel.toBytes()))); + void initializeModel() { + for (int serverId = 0; serverId < numServers; serverId++) { + long start = ranges[serverId]; + long end = ranges[serverId + 1]; + Message message = + new Message( + serverId, + workerId, + MessageType.INITIALIZE, + new long[] {start, end}, + new double[0]); + output.collect(new StreamRecord<>(Tuple2.of(serverId, message.bytes))); } } /** Pushes a key-value arrays to servers. 
*/ void push(long[] indices, double[] values) { - Iterator> requests = - partitioner.splitRequest(indices, values); + Iterator> requests = sliceRequest(indices, values); while (requests.hasNext()) { Tuple3 request = requests.next(); - PushKvM pushKvM = new PushKvM(workerId, request.f0, Tuple2.of(request.f1, request.f2)); - output.collect(new StreamRecord<>(Tuple2.of(request.f0, pushKvM.toBytes()))); + Message message = + new Message(request.f0, workerId, MessageType.PUSH, request.f1, request.f2); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, message.bytes))); } } /** Pulls the values from servers with the specified indices. */ void pull(long[] indices) { - Iterator> requests = - partitioner.splitRequest(indices, null); + Iterator> requests = sliceRequest(indices, null); while (requests.hasNext()) { Tuple3 request = requests.next(); - PullIndexM pullIndexM = new PullIndexM(request.f0, workerId, request.f1); - output.collect(new StreamRecord<>(Tuple2.of(request.f0, pullIndexM.toBytes()))); + Message message = + new Message(request.f0, workerId, MessageType.PULL, request.f1, new double[0]); + output.collect(new StreamRecord<>(Tuple2.of(request.f0, message.bytes))); } } @@ -87,21 +104,102 @@ void pull(long[] indices) { *

Note that the values pushed by this function are not going to update the model, but just * perform an all reduce operation. */ - void allReducePush(double[] values, BiFunction aggregator) { + void allReducePush(V[] values, TypeSerializer typeSerializer) throws IOException { final int MIN_MESSAGE_SIZE = 1024; - int numServers = partitioner.numServers; int messageSize = Math.max(MIN_MESSAGE_SIZE, values.length / numServers + 1); for (int serverId = 0; serverId < numServers; serverId++) { int s = Math.min(serverId * messageSize, values.length); int e = Math.min(s + messageSize, values.length); - double[] segment; + V[] segment; if (s == e) { - segment = new double[0]; + segment = (V[]) new Object[0]; } else { segment = Arrays.copyOfRange(values, s, e); } - AllReduceM allReduceM = new AllReduceM(serverId, workerId, segment, aggregator); - output.collect(new StreamRecord<>(Tuple2.of(serverId, allReduceM.toBytes()))); + Message message = + new Message( + workerId, + serverId, + MessageType.ALL_REDUCE, + new long[0], + segment, + typeSerializer); + output.collect(new StreamRecord<>(Tuple2.of(serverId, message.bytes))); + } + } + + /** + * Splits the push/pull request according to the given sorted indices and the corresponding + * values. + * + * @param indices sorted indices of push/pull request. + * @param values the push values if not null. + * @return the split requests for each server. + */ + private Iterator> sliceRequest( + long[] indices, @Nullable double[] values) { + return new RequestsIterator(numServers, indices, values, ranges); + } + + private static class RequestsIterator implements Iterator> { + private final int numServers; + private final long[] indices; + private final double[] values; + /** + * Number of values per key. If the model data is a vector, numValuesPerKey is one. If the + * model data is a matrix, numValuesPerKey is the number of columns. + */ + private final int numValuesPerKey; + + private final long[] ranges; + + private int serverId = 0; + + private int s = 0; + + public RequestsIterator( + int numServers, long[] indices, @Nullable double[] values, long[] ranges) { + this.numServers = numServers; + this.indices = indices; + this.values = values; + this.ranges = ranges; + if (indices.length != 0 && values != null) { + numValuesPerKey = values.length / indices.length; + Preconditions.checkArgument( + numValuesPerKey * indices.length == values.length, + String.format( + "The size of values [%d] cannot be divided by size of keys [%d].", + values.length, indices.length)); + } else { + numValuesPerKey = 1; + } + } + + @Override + public boolean hasNext() { + return serverId < numServers; + } + + @Override + public Tuple3 next() { + int e = s; + while (e < indices.length && indices[e] < ranges[serverId + 1]) { + e++; + } + + long[] splitIndices = new long[0]; + double[] splitValues = values == null ? null : new double[0]; + if (s < e) { + splitIndices = Arrays.copyOfRange(indices, s, e); + splitValues = + values == null + ? 
null + : Arrays.copyOfRange( + values, s * numValuesPerKey, e * numValuesPerKey); + } + s = e; + serverId++; + return Tuple3.of(serverId - 1, splitIndices, splitValues); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index a0900e39a..766ab9ce6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.common.ps; +import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; @@ -25,15 +26,14 @@ import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; -import org.apache.flink.ml.common.ps.message.AllReduceM; -import org.apache.flink.ml.common.ps.message.InitializeModel; +import org.apache.flink.ml.common.ps.message.Message; import org.apache.flink.ml.common.ps.message.MessageType; -import org.apache.flink.ml.common.ps.message.MessageUtils; -import org.apache.flink.ml.common.ps.message.PullIndexM; -import org.apache.flink.ml.common.ps.message.PulledValueM; -import org.apache.flink.ml.common.ps.message.PushKvM; +import org.apache.flink.ml.common.ps.training.AllReduceStage; +import org.apache.flink.ml.common.ps.training.IterationStage; import org.apache.flink.ml.common.ps.training.IterationStageList; +import org.apache.flink.ml.common.ps.training.PullStage; import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.util.Bits; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; @@ -57,15 +57,15 @@ import java.util.concurrent.Future; /** - * The server operator maintains the shared parameters. It receives push/pull requests from {@link - * WorkerOperator} and sends the answer request to {@link MirrorWorkerOperator}. It works closely - * with {@link ModelUpdater} in the following way: + * The server operator maintains the shared parameters. It receives push/pull/allreduce requests + * from {@link WorkerOperator} and sends the answer request to {@link ResponseAssemblerOperator}. It + * works closely with {@link ModelUpdater} in the following way: * *

    - *
- *   • The server operator deals with the message from workers and decide when to process the
+ *   • The server operator deals with the message from workers and decides when to process the
 *       received message.
- *   • The server operator calls {@link ModelUpdater#handlePush(long[], double[])} and {@link
- *       ModelUpdater#handlePull(long[])} to process the messages in detail.
+ *   • The server operator calls {@link ModelUpdater#update(long[], double[])} and {@link
+ *       ModelUpdater#get(long[])} to process the messages in detail.
  • The server operator triggers checkpoint for {@link ModelUpdater}. *
  • The server operator outputs the final output parameters by calling {@link * ModelUpdater#getModelSegments()}. @@ -76,19 +76,17 @@ * ModelUpdater}. * *

    TODO: Add support for asynchronous operations on servers. - * - *

    TODO: Add support for maintaining multiple parameters on servers. */ public class ServerOperator extends AbstractStreamOperator> implements OneInputStreamOperator, Tuple2>, IterationListener> { - /** Iteration stage list. */ - private final IterationStageList iterationStageList; + /** The iterationStage list that asks responses from servers. */ + private final List stageList; /** Number of workers to communicate with. */ private final int numWorkers; /** The logic to answer push/pull request from workers. */ private final ModelUpdater modelUpdater; - /** Format of model data: start index, end index, dense double array. */ + /** Format of model data: start key index, end key index, dense double array. */ private final OutputTag> modelOutputTag; /** Index of the server task. */ private int serverId = -1; @@ -101,22 +99,27 @@ public class ServerOperator extends AbstractStreamOperator> futuresInEpoch = new ArrayList<>(); /** The merger for push requests. */ private final PushRequestMerger pushRequestMerger; - /** The merger for all reduce requests. */ - private final AllReduceMerger allReduceMerger; /** The pending pull requests. */ private ListState pendingPulls; + /** The pending allreduce requests. */ + private ListState pendingAllReduces; + public ServerOperator( IterationStageList iterationStageList, int numWorkers, ModelUpdater modelUpdater, OutputTag> modelOutputTag) { - this.iterationStageList = iterationStageList; + this.stageList = new ArrayList<>(); + for (IterationStage stage : iterationStageList.stageList) { + if (stage instanceof PullStage || stage instanceof AllReduceStage) { + stageList.add(stage); + } + } this.numWorkers = numWorkers; this.modelUpdater = modelUpdater; this.modelOutputTag = modelOutputTag; this.pushRequestMerger = new PushRequestMerger(); - this.allReduceMerger = new AllReduceMerger(); } @Override @@ -129,30 +132,26 @@ public void open() throws Exception { @Override public void processElement(StreamRecord> element) throws Exception { byte[] request = element.getValue().f1; - MessageType type = MessageUtils.getMessageType(request); + Message message = new Message(element.getValue().f1); + MessageType type = message.getMessageType(); switch (type) { - case PULL_INDEX: - pendingPulls.add(request); - break; - case INITIALIZE_MODEL_AS_ZERO: - InitializeModel initializeModelAsZeroM = InitializeModel.fromBytes(request); - Preconditions.checkState(serverId == initializeModelAsZeroM.serverId); - - long start = initializeModelAsZeroM.startIndex; - long end = initializeModelAsZeroM.endIndex; - if (initializeModelAsZeroM.workerId == 0) { - modelUpdater.open(start, end); + case INITIALIZE: + long[] indices = message.getKeys(); + Preconditions.checkState(serverId == message.getServerId() && indices.length == 2); + if (message.getWorkerId() == 0) { + modelUpdater.open(indices[0], indices[1]); } break; - case PUSH_KV: + case PUSH: futuresInEpoch.add( singleThreadExecutor.submit( - () -> pushRequestMerger.processPushRequest(request))); + () -> pushRequestMerger.processPushRequest(message))); break; - case ALL_REDUCE_VALUE: - futuresInEpoch.add( - singleThreadExecutor.submit( - () -> allReduceMerger.processAllReduceRequest(request))); + case PULL: + pendingPulls.add(request); + break; + case ALL_REDUCE: + pendingAllReduces.add(request); break; default: throw new UnsupportedOperationException("Unsupported message type: " + type + "."); @@ -175,7 +174,7 @@ public void onEpochWatermarkIncremented( pushRequestMerger.accumulatedKvsForVector.clear(); if (kvs.f0.length > 0) { // 
There are pushes at this epoch. - modelUpdater.handlePush(kvs.f0, kvs.f1); + modelUpdater.update(kvs.f0, kvs.f1); } Iterator pullsIterator = pendingPulls.get().iterator(); @@ -183,34 +182,48 @@ public void onEpochWatermarkIncremented( // This is a pull stage. while (pullsIterator.hasNext()) { byte[] pull = pullsIterator.next(); - futuresInEpoch.add(singleThreadExecutor.submit(() -> processPullRequest(pull))); + futuresInEpoch.add( + singleThreadExecutor.submit(() -> processPullRequest(new Message(pull)))); } } - if (allReduceMerger.reducedResult != null) { - // This is an all reduce stage. - PulledValueM pulledValueM = - new PulledValueM(serverId, -1, allReduceMerger.reducedResult); + Iterator allreduceIterator = pendingAllReduces.get().iterator(); + if (allreduceIterator.hasNext()) { + int stageId = epochWatermark % stageList.size(); + AllReduceStage allReduceStage = (AllReduceStage) stageList.get(stageId); + Message reducedResult = processAllReduceRequest(allReduceStage, allreduceIterator); for (int workerId = 0; workerId < numWorkers; workerId++) { - int finalWorkerId = workerId; - pulledValueM.workerId = finalWorkerId; - futuresInEpoch.add( - singleThreadExecutor.submit( - () -> - output.collect( - new StreamRecord<>( - Tuple2.of( - finalWorkerId, - pulledValueM.toBytes()))))); + reducedResult.setWorkerId(workerId); + output.collect(new StreamRecord<>(Tuple2.of(workerId, reducedResult.bytes))); } } + for (Future future : futuresInEpoch) { future.get(); } pendingPulls.clear(); - allReduceMerger.reducedResult = null; + pendingAllReduces.clear(); futuresInEpoch.clear(); } + private Message processAllReduceRequest(AllReduceStage stage, Iterator requests) + throws Exception { + ReduceFunction reduceFunction = stage.reducer; + V[] reducedResult = null; + while (requests.hasNext()) { + byte[] allreduceRequest = requests.next(); + Message message = new Message(allreduceRequest); + V[] receivedResult = message.getValues(stage.typeSerializer); + if (reducedResult == null) { + reducedResult = receivedResult; + } else { + reducedResult = reduceFunction.reduce(receivedResult, reducedResult); + } + } + + return new Message( + -1, -1, MessageType.ALL_REDUCE, new long[0], reducedResult, stage.typeSerializer); + } + @Override public void onIterationTerminated( Context context, Collector> collector) { @@ -230,9 +243,14 @@ public void initializeState(StateInitializationContext context) throws Exception new ListStateDescriptor<>( "pendingPulls", PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); + pendingAllReduces = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "pendingAllReduces", + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); modelUpdater.initializeState(context); pushRequestMerger.initializeState(context); - allReduceMerger.initializeState(context); } @Override @@ -246,18 +264,16 @@ public void snapshotState(StateSnapshotContext context) throws Exception { futuresInEpoch.clear(); modelUpdater.snapshotState(context); pushRequestMerger.snapshotState(context); - allReduceMerger.snapshotState(context); } - private Object processPullRequest(byte[] bytesData) { - PullIndexM pullIndexM = PullIndexM.fromBytes(bytesData); - Preconditions.checkState(serverId == pullIndexM.serverId); - int workerId = pullIndexM.workerId; - long[] indices = pullIndexM.indices; - double[] pulledValues = modelUpdater.handlePull(indices); - PulledValueM pulledValueM = new PulledValueM(serverId, workerId, pulledValues); + private Object processPullRequest(Message message) { + 
Preconditions.checkState(serverId == message.getServerId()); + int workerId = message.getWorkerId(); + double[] pulledValues = modelUpdater.get(message.getKeys()); + Message pulledMessage = + new Message(serverId, workerId, MessageType.PULL, new long[0], pulledValues); StreamRecord> record = - new StreamRecord<>(Tuple2.of(workerId, pulledValueM.toBytes())); + new StreamRecord<>(Tuple2.of(workerId, pulledMessage.bytes)); output.collect(record); return new Object(); @@ -277,11 +293,9 @@ public PushRequestMerger() { this.accumulatedKvsForMatrix = new HashMap<>(); } - private Object processPushRequest(byte[] pushKv) { - PushKvM pushKvM = PushKvM.fromBytes(pushKv); - Tuple2 pushKvs = pushKvM.kvs; - long[] keys = pushKvs.f0; - double[] values = pushKvs.f1; + private Object processPushRequest(Message message) { + long[] keys = message.getKeys(); + double[] values = message.getValuesInDoubleArray(); if (values.length == keys.length) { for (int i = 0; i < keys.length; i++) { @@ -343,8 +357,7 @@ private void initializeState(StateInitializationContext context) throws Exceptio .orElse(null); if (accumulatedKvsInBytes != null) { - Tuple2 kvs = - MessageUtils.getLongDoubleArray(accumulatedKvsInBytes, 0); + Tuple2 kvs = Bits.getLongDoubleArray(accumulatedKvsInBytes, 0); long[] keys = kvs.f0; double[] values = kvs.f1; int numValuesPerKey = values.length / keys.length; @@ -369,46 +382,10 @@ private void snapshotState(StateSnapshotContext context) throws Exception { Tuple2 kvs = toKvArrays(); accumulatedKvsState.clear(); if (kvs.f0.length > 0) { - byte[] bytes = new byte[MessageUtils.getLongDoubleArraySizeInBytes(kvs)]; - MessageUtils.putLongDoubleArray(kvs, bytes, 0); + byte[] bytes = new byte[Bits.getLongDoubleArraySizeInBytes(kvs)]; + Bits.putLongDoubleArray(kvs, bytes, 0); accumulatedKvsState.add(bytes); } } } - - private static class AllReduceMerger implements Serializable { - private double[] reducedResult; - private ListState reducedResultState; - - private void processAllReduceRequest(byte[] request) { - AllReduceM allReduceM = AllReduceM.fromBytes(request); - double[] receivedValues = allReduceM.values; - if (reducedResult == null) { - reducedResult = receivedValues; - } else { - Preconditions.checkArgument(reducedResult.length == receivedValues.length); - reducedResult = allReduceM.aggregator.apply(receivedValues, reducedResult); - } - } - - private void initializeState(StateInitializationContext context) throws Exception { - reducedResultState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor( - "reducedResultState", - PrimitiveArrayTypeInfo - .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); - reducedResult = - OperatorStateUtils.getUniqueElement(reducedResultState, "reducedResultState") - .orElse(null); - } - - private void snapshotState(StateSnapshotContext context) throws Exception { - reducedResultState.clear(); - if (reducedResult != null) { - reducedResultState.add(reducedResult); - } - } - } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java index da7584989..7be97223c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -26,7 +26,7 @@ import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache; import org.apache.flink.iteration.operator.OperatorStateUtils; -import 
org.apache.flink.ml.common.ps.message.PulledValueM; +import org.apache.flink.ml.common.ps.message.Message; import org.apache.flink.ml.common.ps.training.AllReduceStage; import org.apache.flink.ml.common.ps.training.IterationStage; import org.apache.flink.ml.common.ps.training.IterationStageList; @@ -42,9 +42,8 @@ import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.util.Collector; -import org.apache.flink.util.Preconditions; -import java.util.Arrays; +import java.io.IOException; import java.util.Iterator; /** @@ -56,8 +55,8 @@ *

      *
    • Caches the training data. *
• Initializes the {@link MLSession}.
- *   • Splits the {@link IterationStageList} by {@link PullStage} into multiple sequences and map
- *     it into flink-ml-iterations.
+ *   • Splits the {@link IterationStageList} by {@link PullStage} and {@link AllReduceStage} into
+ *     multiple sequences and maps them into flink-ml-iterations, as sketched after this list.
    • Executes the process function in each {@link ProcessStage}. *
• Executes the push/pull request in {@link PushStage} and {@link PullStage} and talks to
 *       servers, by reading/writing {@link MLSession}.
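That splitting is the core control flow of the worker: it executes local stages eagerly and
parks at the first stage that needs a server round-trip, resuming once the response arrives.
A toy sketch of that loop under stand-in types; none of these are the real flink-ml classes.

    import java.util.Arrays;
    import java.util.List;

    final class StageSplitSketch {
        interface Stage {}

        static final class ProcessStage implements Stage {}

        static final class PushStage implements Stage {}

        static final class PullStage implements Stage {}

        static final class AllReduceStage implements Stage {}

        /** Runs stages until one needs a server response; returns the stage to resume at. */
        static int runUntilBlocking(List<Stage> stages, int next) {
            while (next < stages.size()) {
                Stage s = stages.get(next);
                if (s instanceof PullStage || s instanceof AllReduceStage) {
                    return next; // wait for the servers before advancing
                }
                next++; // ProcessStage / PushStage complete locally
            }
            return 0; // one pass over the stage list finished; wrap around
        }

        public static void main(String[] args) {
            List<Stage> stages =
                    Arrays.asList(
                            new ProcessStage(),
                            new PushStage(),
                            new PullStage(),
                            new ProcessStage());
            System.out.println(runUntilBlocking(stages, 0)); // 2: parked at the PullStage
        }
    }

@@ -70,12 +69,12 @@ public class WorkerOperator
     /** Number of servers that this worker needs to talk to. */
     private final int numServers;

-    /** The agent for each worker to talk with servers. */
-    private ServerAgent serverAgent;
-
     /** The user defined iteration logic. */
     private final IterationStageList iterationStages;

+    /** The agent for each worker to talk with servers. */
+    private ServerAgent serverAgent;
+
     /**
      * Iteration id in terms of {@link IterationStageList}. When we finish processing all stages
      * in stageList, the iteration id increments by one.
      */
@@ -121,10 +120,10 @@ public void onEpochWatermarkIncremented(
             throws Exception {
         if (epochWatermark == 0) {
             modelDim = Bits.getLong(feedback, 0);
-            serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers));
-            serverAgent.initializeModelAsZeros();
+            serverAgent.open(numServers, modelDim - 1);
+            serverAgent.initializeModel();
             iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState));
-            nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages);
+            nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages);
         }
     }
@@ -143,34 +142,30 @@ public void processElement1(StreamRecord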
      streamRecord) throws Exception { public void processElement2(StreamRecord streamRecord) throws Exception { feedback = streamRecord.getValue(); if (modelDim > 0) { - // Decodes the pulled values and puts it in ml session. + Message message = new Message(streamRecord.getValue()); IterationStage stage = iterationStages.stageList.get(nextStageToExecute); if (stage instanceof PullStage) { PullStage pullStage = (PullStage) stage; - PulledValueM valuesPulledMessage = PulledValueM.fromBytes(streamRecord.getValue()); - Preconditions.checkState( - getRuntimeContext().getIndexOfThisSubtask() - == valuesPulledMessage.workerId); - pullStage.valuesConsumer.accept(valuesPulledMessage.values); + pullStage.valuesConsumer.accept(message.getValuesInDoubleArray()); } else if (stage instanceof AllReduceStage) { - AllReduceStage allReduceStage = (AllReduceStage) stage; - PulledValueM pulledValueM = PulledValueM.fromBytes(streamRecord.getValue()); - Preconditions.checkState( - getRuntimeContext().getIndexOfThisSubtask() == pulledValueM.workerId); - System.out.println( - "Worker received allreduce result: " - + Arrays.toString(pulledValueM.values)); - allReduceStage.valuesConsumer.accept(pulledValueM.values); + AllReduceStage allReduceStage = (AllReduceStage) stage; + processAllReduceStage(allReduceStage, message); } else { throw new IllegalStateException( String.format("Illegal stage type: %s", stage.getClass().getSimpleName())); } - nextStageToExecute++; - nextStageToExecute = processTrainingStage(nextStageToExecute, iterationStages); + nextStageToExecute++; + nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages); } } + private void processAllReduceStage(AllReduceStage stage, Message message) + throws IOException { + V[] reducedResult = message.getValues(stage.typeSerializer); + stage.valuesConsumer.accept(reducedResult); + } + @Override public void initializeState(StateInitializationContext context) throws Exception { super.initializeState(context); @@ -178,9 +173,9 @@ public void initializeState(StateInitializationContext context) throws Exception context.getOperatorStateStore() .getListState( new ListStateDescriptor<>( - "feedbackArrayState", + "feedbackState", PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO)); - OperatorStateUtils.getUniqueElement(feedbackState, "feedbackArrayState") + OperatorStateUtils.getUniqueElement(feedbackState, "feedbackState") .ifPresent(x -> feedback = x); trainDataState = @@ -212,7 +207,7 @@ public void initializeState(StateInitializationContext context) throws Exception OperatorStateUtils.getUniqueElement(iterationIdState, "iterationIdState").orElse(0); if (modelDim > 0) { - serverAgent.setPartitioner(new RangePartitioner(modelDim, numServers)); + serverAgent.open(numServers, modelDim - 1); } iterationStages.session.initializeState(context); @@ -239,14 +234,14 @@ public void snapshotState(StateSnapshotContext context) throws Exception { /** * Processes the stages described in the given iterationStages from the given nextStage id. This - * function processes the stages until it meets an {@link PullStage}. + * function processes the stages until it meets a {@link PullStage} or {@link AllReduceStage}. * * @param nextStageToExecute id of the next stage to execute in the given iteration stages. * @param iterationStages iteration stages used to describe the training logic. * @return the id of the next stage to execute. 
*/ @SuppressWarnings("unchecked") - private int processTrainingStage( + private int processIterationStages( int nextStageToExecute, IterationStageList iterationStages) throws Exception { while (true) { if (nextStageToExecute >= iterationStages.stageList.size()) { @@ -259,19 +254,19 @@ private int processTrainingStage( } IterationStage stage = iterationStages.stageList.get(nextStageToExecute); + // We are not incrementing nextStageToExecute for PullStage and AllReduceStage, since we + // will need to receive values from servers. if (stage instanceof PullStage) { - // We are not incrementing nextStageToExecute here, since we will need to pull - // values from servers. PullStage pullStage = ((PullStage) stage); serverAgent.pull(pullStage.keysSupplier.get()); return nextStageToExecute; + } else if (stage instanceof AllReduceStage) { - // We are not incrementing nextStageToExecute here, since we will need to pull - // values from servers. - AllReduceStage allReduceStage = (AllReduceStage) stage; + AllReduceStage allReduceStage = (AllReduceStage) stage; serverAgent.allReducePush( - allReduceStage.valuesSupplier.get(), allReduceStage.valuesAggregator); + allReduceStage.valuesSupplier.get(), allReduceStage.typeSerializer); return nextStageToExecute; + } else if (stage instanceof PushStage) { PushStage pushStage = (PushStage) stage; serverAgent.push(pushStage.keysSupplier.get(), pushStage.valuesSupplier.get()); @@ -280,6 +275,7 @@ private int processTrainingStage( } else if (stage instanceof ProcessStage) { ((ProcessStage) stage).process(iterationStages.session); nextStageToExecute++; + } else { throw new IllegalStateException( "Illegal type of IterationStage: + " + stage.getClass().getSimpleName()); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java deleted file mode 100644 index e5f249800..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/AllReduceM.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.ps.message; - -import org.apache.flink.ml.util.Bits; -import org.apache.flink.util.Preconditions; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.util.function.BiFunction; - -import static org.apache.flink.ml.common.ps.message.MessageType.ALL_REDUCE_VALUE; - -/** The message to apply all-reduce among workers. 
*/ -public class AllReduceM implements Message { - public final int serverId; - public final int workerId; - public final double[] values; - public final BiFunction aggregator; - - public AllReduceM( - int serverId, - int workerId, - double[] values, - BiFunction aggregator) { - this.serverId = serverId; - this.workerId = workerId; - this.values = values; - this.aggregator = aggregator; - } - - public static AllReduceM fromBytes(byte[] bytes) { - int offset = 0; - char type = Bits.getChar(bytes, offset); - offset += Character.BYTES; - Preconditions.checkState(type == ALL_REDUCE_VALUE.type); - - int psId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - int workerId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - double[] values = MessageUtils.getDoubleArray(bytes, offset); - offset += MessageUtils.getDoubleArraySizeInBytes(values); - - BiFunction aggregator = deserializeFunction(bytes, offset); - return new AllReduceM(psId, workerId, values, aggregator); - } - - @Override - public byte[] toBytes() { - byte[] serializedFunctionInBytes = serializeFunction(aggregator); - int numBytes = - Character.BYTES - + Integer.BYTES - + Integer.BYTES - + MessageUtils.getDoubleArraySizeInBytes(values) - + serializedFunctionInBytes.length; - byte[] buffer = new byte[numBytes]; - int offset = 0; - Bits.putChar(buffer, offset, ALL_REDUCE_VALUE.type); - offset += Character.BYTES; - - Bits.putInt(buffer, offset, this.serverId); - offset += Integer.BYTES; - Bits.putInt(buffer, offset, this.workerId); - offset += Integer.BYTES; - MessageUtils.putDoubleArray(values, buffer, offset); - offset += MessageUtils.getDoubleArraySizeInBytes(values); - System.arraycopy( - serializedFunctionInBytes, 0, buffer, offset, serializedFunctionInBytes.length); - - return buffer; - } - - private static byte[] serializeFunction(BiFunction aggregator) { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - try { - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(aggregator); - oos.flush(); - } catch (Throwable e) { - return null; - } - return baos.toByteArray(); - } - - private static BiFunction deserializeFunction( - byte[] bytes, int offset) { - ByteArrayInputStream bais = new ByteArrayInputStream(bytes, offset, bytes.length - offset); - try { - ObjectInputStream ois = new ObjectInputStream(bais); - return (BiFunction) ois.readObject(); - } catch (Exception e) { - System.out.println("wrong deserialization"); - return null; - } - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java deleted file mode 100644 index 73ff80264..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/InitializeModel.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.ps.message; - -import org.apache.flink.ml.util.Bits; -import org.apache.flink.util.Preconditions; - -import static org.apache.flink.ml.common.ps.message.MessageType.INITIALIZE_MODEL_AS_ZERO; - -/** Message sent by worker to server that initializes the model as zeros with defined range. */ -public class InitializeModel implements Message { - public final int workerId; - public final int serverId; - public final long startIndex; - public final long endIndex; - - public InitializeModel(int workerId, int serverId, long startIndex, long endIndex) { - this.workerId = workerId; - this.serverId = serverId; - this.startIndex = startIndex; - this.endIndex = endIndex; - } - - public static InitializeModel fromBytes(byte[] bytes) { - int offset = 0; - char type = Bits.getChar(bytes, offset); - offset += Character.BYTES; - Preconditions.checkState(type == INITIALIZE_MODEL_AS_ZERO.type); - - int workerId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - int serverId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - long startIndex = Bits.getLong(bytes, offset); - offset += Long.BYTES; - long endIndex = Bits.getLong(bytes, offset); - return new InitializeModel(workerId, serverId, startIndex, endIndex); - } - - @Override - public byte[] toBytes() { - int numBytes = Character.BYTES + Integer.BYTES + Integer.BYTES + Long.BYTES + Long.BYTES; - byte[] buffer = new byte[numBytes]; - int offset = 0; - Bits.putChar(buffer, offset, INITIALIZE_MODEL_AS_ZERO.type); - offset += Character.BYTES; - - Bits.putInt(buffer, offset, this.workerId); - offset += Integer.BYTES; - Bits.putInt(buffer, offset, this.serverId); - offset += Integer.BYTES; - Bits.putLong(buffer, offset, this.startIndex); - offset += Long.BYTES; - Bits.putLong(buffer, offset, this.endIndex); - - return buffer; - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java index 5d684889c..aa764568e 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java @@ -18,20 +18,212 @@ package org.apache.flink.ml.common.ps.message; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.util.Bits; +import org.apache.flink.util.Preconditions; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; /** - * The message to be passed between worker node and server node. + * {@link Message} is responsible for encoding all messages exchanged between a worker and a server. + * The message format follows this structure: + * + *

      `workerId serverId messageType keyLength keys valuesLength values` * - *

      NOTE: Every Message subclass should implement a static method with signature {@code static T - * fromBytes(byte[] bytesData)}, where {@code T} refers to the concrete subclass. This static method - * should instantiate a new Message instance based on the data read from the given byte array. + *

where the message fields include the worker ID, server ID, message type, length of the keys,
+ * keys themselves, length of the values, and the values.
  */
-public interface Message {
+public class Message {
+    private static final int WORKER_ID_OFFSET = 0;
+    private static final int SERVER_ID_OFFSET = Integer.BYTES;
+    private static final int MESSAGE_TYPE_OFFSET = Integer.BYTES + SERVER_ID_OFFSET;
+    private static final int KVS_OFFSET = Integer.BYTES + MESSAGE_TYPE_OFFSET;
+
+    public final byte[] bytes;
+
+    public Message(byte[] bytes) {
+        this.bytes = bytes;
+    }
+
+    /** Constructs a message instance from long keys and double values. */
+    public Message(
+            int serverId, int workerId, MessageType messageType, long[] keys, double[] values) {
+        int sizeInBytes = KVS_OFFSET + Bits.getLongDoubleArraySizeInBytes(Tuple2.of(keys, values));
+        bytes = new byte[sizeInBytes];
+        Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
+        Bits.putInt(bytes, SERVER_ID_OFFSET, serverId);
+        Bits.putInt(bytes, MESSAGE_TYPE_OFFSET, messageType.type);
+        Bits.putLongDoubleArray(Tuple2.of(keys, values), bytes, KVS_OFFSET);
+    }
+
+    /** Constructs a message instance from long keys and generic values. */
+    public <V> Message(
+            int serverId,
+            int workerId,
+            MessageType messageType,
+            long[] keys,
+            V[] values,
+            TypeSerializer<V> serializer)
+            throws IOException {
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                new DataOutputViewStreamWrapper(byteArrayOutputStream);
+        dataOutputViewStreamWrapper.writeInt(workerId);
+        dataOutputViewStreamWrapper.writeInt(serverId);
+        dataOutputViewStreamWrapper.writeInt(messageType.type);
+
+        dataOutputViewStreamWrapper.writeInt(keys.length);
+        for (long key : keys) {
+            dataOutputViewStreamWrapper.writeLong(key);
+        }
+        dataOutputViewStreamWrapper.writeInt(values.length);
+        for (V value : values) {
+            serializer.serialize(value, dataOutputViewStreamWrapper);
+        }
+        bytes = byteArrayOutputStream.toByteArray();
+    }
+
+    /** Retrieves the keys. */
+    public long[] getKeys() {
+        return Bits.getLongArray(bytes, KVS_OFFSET);
+    }
+
+    /** Retrieves the values using the given serializer. */
+    public <V> V[] getValues(TypeSerializer<V> serializer) throws IOException {
+        int numIndices = Bits.getInt(bytes, KVS_OFFSET);
+        int offset = KVS_OFFSET + Integer.BYTES + numIndices * Long.BYTES;
+        int numValues = Bits.getInt(bytes, offset);
+        offset += Integer.BYTES;
+
+        // Since the generics got erased, we use reflection to create the array.
+        V[] result = (V[]) Array.newInstance(serializer.createInstance().getClass(), numValues);
+        ByteArrayInputStream byteArrayInputStream =
+                new ByteArrayInputStream(bytes, offset, bytes.length - offset);
+        DataInputViewStreamWrapper dataInputViewStreamWrapper =
+                new DataInputViewStreamWrapper(byteArrayInputStream);
+        for (int i = 0; i < numValues; i++) {
+            result[i] = serializer.deserialize(dataInputViewStreamWrapper);
+        }
+        return result;
+    }
+
+    /** Retrieves the values in double array format. */
+    public double[] getValuesInDoubleArray() {
+        int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES;
+        return Bits.getDoubleArray(bytes, offset);
+    }
+
+    /** Retrieves the worker id. */
+    public int getWorkerId() {
+        return Bits.getInt(bytes, WORKER_ID_OFFSET);
+    }
+
+    /** Retrieves the server id. */
+    public int getServerId() {
+        return Bits.getInt(bytes, SERVER_ID_OFFSET);
+    }
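Taken together, these offsets fix the exact byte budget of a message. A back-of-the-envelope
sizing check for two long keys and two double values (a hypothetical sketch, not part of the
class):

    final class MessageLayoutSketch {
        public static void main(String[] args) {
            int numKeys = 2;
            int numValues = 2;
            int header = 3 * Integer.BYTES; // workerId, serverId, messageType
            int keys = Integer.BYTES + numKeys * Long.BYTES; // keyLength + keys
            int values = Integer.BYTES + numValues * Double.BYTES; // valuesLength + values
            System.out.println(header + keys + values); // 52 bytes in total
        }
    }

+
+    /** Sets the worker id.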
*/ + public void setWorkerId(int workerId) { + Bits.putInt(bytes, WORKER_ID_OFFSET, workerId); + } + + /** Sets the server id. */ + public void setServerId(int serverId) { + Bits.putInt(bytes, SERVER_ID_OFFSET, serverId); + } + + /** Retrieves the message type. */ + public MessageType getMessageType() { + return MessageType.valueOf(Bits.getInt(bytes, MESSAGE_TYPE_OFFSET)); + } + /** - * Serializes the message into a byte array. - * - *

Note that the first two bytes of the result buffer is reserved for {@link MessageType}.
-     */
-    byte[] toBytes() throws IOException;
+     * Assembles the received messages from servers according to the server id. Note that these
+     * messages should come from the same request.
+     */
+    public static Message assembleMessages(Iterator<byte[]> messageIterator) {
+        List<Message> messages = new ArrayList<>();
+        while (messageIterator.hasNext()) {
+            messages.add(new Message(messageIterator.next()));
+        }
+        messages.sort(Comparator.comparingInt(Message::getServerId));
+
+        int numMessages = messages.size();
+        int numKeys = 0, numValues = 0;
+        int numAssembledBytes = 0;
+        int workerId = -1;
+        for (Message message : messages) {
+            Preconditions.checkState(workerId == -1 || workerId == message.getWorkerId());
+            workerId = message.getWorkerId();
+            numKeys += message.getNumKeys();
+            numValues += message.getNumValues();
+            numAssembledBytes += message.bytes.length;
+        }
+        numAssembledBytes -= (numMessages - 1) * (KVS_OFFSET + Integer.BYTES * 2);
+        byte[] assembledBytes = new byte[numAssembledBytes];
+        int keysOffset = KVS_OFFSET;
+        Bits.putInt(assembledBytes, keysOffset, numKeys);
+        keysOffset += Integer.BYTES;
+        int valuesOffset = keysOffset + numKeys * Long.BYTES;
+        Bits.putInt(assembledBytes, valuesOffset, numValues);
+        valuesOffset += Integer.BYTES;
+
+        for (Message message : messages) {
+            Tuple2<Integer, Integer> keysOffsetAndLength = message.getKeysOffsetAndLength();
+            System.arraycopy(
+                    message.bytes,
+                    keysOffsetAndLength.f0,
+                    assembledBytes,
+                    keysOffset,
+                    keysOffsetAndLength.f1);
+            keysOffset += keysOffsetAndLength.f1;
+            Tuple2<Integer, Integer> valuesOffsetAndLength = message.getValuesOffsetAndLength();
+            System.arraycopy(
+                    message.bytes,
+                    valuesOffsetAndLength.f0,
+                    assembledBytes,
+                    valuesOffset,
+                    valuesOffsetAndLength.f1);
+            valuesOffset += valuesOffsetAndLength.f1;
+        }
+
+        Message message = new Message(assembledBytes);
+        message.setServerId(-1);
+        message.setWorkerId(workerId);
+        return message;
+    }
+
+    private Tuple2<Integer, Integer> getKeysOffsetAndLength() {
+        int start = KVS_OFFSET + Integer.BYTES;
+        int numBytes = Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES;
+        return Tuple2.of(start, numBytes);
+    }
+
+    private Tuple2<Integer, Integer> getValuesOffsetAndLength() {
+        int start =
+                Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES
+                        + KVS_OFFSET
+                        + Integer.BYTES
+                        + Integer.BYTES;
+        return Tuple2.of(start, bytes.length - start);
+    }
+
+    private int getNumKeys() {
+        return Bits.getInt(bytes, KVS_OFFSET);
+    }
+
+    private int getNumValues() {
+        return Bits.getInt(bytes, KVS_OFFSET + Integer.BYTES + Long.BYTES * getNumKeys());
+    }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java
index 9df4e599d..de0e4f6fe 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java
@@ -18,43 +18,33 @@ package org.apache.flink.ml.common.ps.message;

-/** Message Type between workers and servers. */
+/** Message type between workers and servers. */
 public enum MessageType {
-    /** Message sent from workers to servers, which initializes the model on servers as zero. */
-    INITIALIZE_MODEL_AS_ZERO((char) 0),
-    /** Message sent from workers to servers, which specifies the indices of model to pull. */
-    PULL_INDEX((char) 1),
-    /**
-     * Message sent from server to workers, which specifies the values of the model pulled from
-     * servers.
-     */
-    PULLED_VALUE((char) 2),
-    /**
-     * Message sent from workers to servers, which specifies the indices and values of the model to
-     * push to servers.
-     */
-    PUSH_KV((char) 3),
-    /** Message to apply all-reduce among workers. */
-    ALL_REDUCE_VALUE((char) 4);
+    /** The initialization request. */
+    INITIALIZE(0),
+    /** The push request. */
+    PUSH(1),
+    /** The pull request. */
+    PULL(2),
+    /** The all-reduce request. */
+    ALL_REDUCE(3);

-    public final char type;
+    public final int type;

-    MessageType(char type) {
+    MessageType(int type) {
         this.type = type;
     }

-    public static MessageType valueOf(char value) {
+    public static MessageType valueOf(int value) {
         switch (value) {
-            case (char) 0:
-                return MessageType.INITIALIZE_MODEL_AS_ZERO;
-            case (char) 1:
-                return MessageType.PULL_INDEX;
-            case ((char) 2):
-                return MessageType.PULLED_VALUE;
-            case ((char) 3):
-                return MessageType.PUSH_KV;
-            case ((char) 4):
-                return MessageType.ALL_REDUCE_VALUE;
+            case 0:
+                return MessageType.INITIALIZE;
+            case 1:
+                return MessageType.PUSH;
+            case 2:
+                return MessageType.PULL;
+            case 3:
+                return MessageType.ALL_REDUCE;
             default:
                 throw new UnsupportedOperationException();
         }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java
deleted file mode 100644
index e79f07b39..000000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java
+++ /dev/null
@@ -1,136 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common.ps.message;
-
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.ml.util.Bits;
-
-/** Utility functions for processing messages. */
-public class MessageUtils {
-
-    public static TypeInformation getKeyType(V key) {
-        if (key instanceof Integer) {
-            return Types.INT;
-        } else if (key instanceof Long) {
-            return Types.LONG;
-        } else {
-            throw new UnsupportedOperationException(
-                    String.format("Unsupported key type: %s.", key.getClass().getSimpleName()));
-        }
-    }
-
-    /** Retrieves the message type from the byte array. */
-    public static MessageType getMessageType(byte[] bytes) {
-        char type = Bits.getChar(bytes, 0);
-        return MessageType.valueOf(type);
-    }
-
-    /** Gets a long array from the byte array starting from the given offset.
*/ - public static long[] getLongArray(byte[] bytes, int offset) { - int size = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - long[] result = new long[size]; - for (int i = 0; i < size; i++) { - result[i] = Bits.getLong(bytes, offset); - offset += Long.BYTES; - } - return result; - } - - /** - * Puts a long array to the byte array starting from the given offset. - * - * @return the next position to write on. - */ - public static int putLongArray(long[] array, byte[] bytes, int offset) { - Bits.putInt(bytes, offset, array.length); - offset += Integer.BYTES; - for (int i = 0; i < array.length; i++) { - Bits.putLong(bytes, offset, array[i]); - offset += Long.BYTES; - } - return offset; - } - - /** Returns the size of a long array in bytes. */ - public static int getLongArraySizeInBytes(long[] array) { - return Integer.BYTES + array.length * Long.BYTES; - } - - /** Gets a double array from the byte array starting from the given offset. */ - public static double[] getDoubleArray(byte[] bytes, int offset) { - int size = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - double[] result = new double[size]; - for (int i = 0; i < size; i++) { - result[i] = Bits.getDouble(bytes, offset); - offset += Long.BYTES; - } - return result; - } - - /** - * Puts a double array to the byte array starting from the given offset. - * - * @return the next position to write on. - */ - public static int putDoubleArray(double[] array, byte[] bytes, int offset) { - Bits.putInt(bytes, offset, array.length); - offset += Integer.BYTES; - for (int i = 0; i < array.length; i++) { - Bits.putDouble(bytes, offset, array[i]); - offset += Double.BYTES; - } - return offset; - } - - /** Returns the size of a double array in bytes. */ - public static int getDoubleArraySizeInBytes(double[] array) { - return Integer.BYTES + array.length * Long.BYTES; - } - - /** Gets a long-double array from the byte array starting from the given offset. */ - public static Tuple2 getLongDoubleArray(byte[] bytes, int offset) { - long[] indices = getLongArray(bytes, offset); - offset += getLongArraySizeInBytes(indices); - double[] values = getDoubleArray(bytes, offset); - return Tuple2.of(indices, values); - } - - /** - * Puts a long-double array to the byte array starting from the given offset. - * - * @return the next position to write on. - */ - public static int putLongDoubleArray( - Tuple2 longDoubleArray, byte[] bytes, int offset) { - offset = putLongArray(longDoubleArray.f0, bytes, offset); - offset = putDoubleArray(longDoubleArray.f1, bytes, offset); - - return offset; - } - - /** Returns the size of a long-double array in bytes. */ - public static int getLongDoubleArraySizeInBytes(Tuple2 longDoubleArray) { - return getLongArraySizeInBytes(longDoubleArray.f0) - + getDoubleArraySizeInBytes(longDoubleArray.f1); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java deleted file mode 100644 index bf6d4caa6..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PullIndexM.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.ps.message; - -import org.apache.flink.ml.util.Bits; -import org.apache.flink.util.Preconditions; - -import static org.apache.flink.ml.common.ps.message.MessageType.PULL_INDEX; - -/** The indices one worker needs to pull from servers. */ -public class PullIndexM implements Message { - public final int serverId; - public final int workerId; - public final long[] indices; - - public PullIndexM(int serverId, int workerId, long[] indices) { - this.serverId = serverId; - this.workerId = workerId; - this.indices = indices; - } - - public static PullIndexM fromBytes(byte[] bytes) { - int offset = 0; - char type = Bits.getChar(bytes, offset); - offset += Character.BYTES; - Preconditions.checkState(type == PULL_INDEX.type); - - int psId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - int workerId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - long[] indices = MessageUtils.getLongArray(bytes, offset); - return new PullIndexM(psId, workerId, indices); - } - - @Override - public byte[] toBytes() { - int numBytes = - Character.BYTES + Integer.BYTES * 2 + MessageUtils.getLongArraySizeInBytes(indices); - byte[] buffer = new byte[numBytes]; - int offset = 0; - - Bits.putChar(buffer, offset, PULL_INDEX.type); - offset += Character.BYTES; - Bits.putInt(buffer, offset, this.serverId); - offset += Integer.BYTES; - Bits.putInt(buffer, offset, this.workerId); - offset += Integer.BYTES; - MessageUtils.putLongArray(this.indices, buffer, offset); - return buffer; - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java deleted file mode 100644 index 5457f1a57..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PulledValueM.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.ps.message; - -import org.apache.flink.ml.util.Bits; -import org.apache.flink.util.Preconditions; - -import static org.apache.flink.ml.common.ps.message.MessageType.PULLED_VALUE; - -/** The values pulled from servers. 
*/ -public class PulledValueM implements Message { - public final int serverId; - public int workerId; - public final double[] values; - - public PulledValueM(int serverId, int workerId, double[] values) { - this.serverId = serverId; - this.workerId = workerId; - this.values = values; - } - - public static PulledValueM fromBytes(byte[] bytes) { - int offset = 0; - char type = Bits.getChar(bytes, offset); - offset += Character.BYTES; - Preconditions.checkState(type == PULLED_VALUE.type); - - int psId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - int workerId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - double[] values = MessageUtils.getDoubleArray(bytes, offset); - return new PulledValueM(psId, workerId, values); - } - - @Override - public byte[] toBytes() { - int numBytes = - Character.BYTES - + Integer.BYTES - + Integer.BYTES - + MessageUtils.getDoubleArraySizeInBytes(values); - byte[] buffer = new byte[numBytes]; - int offset = 0; - Bits.putChar(buffer, offset, PULLED_VALUE.type); - offset += Character.BYTES; - - Bits.putInt(buffer, offset, this.serverId); - offset += Integer.BYTES; - Bits.putInt(buffer, offset, this.workerId); - offset += Integer.BYTES; - MessageUtils.putDoubleArray(values, buffer, offset); - - return buffer; - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java deleted file mode 100644 index b3162cbe9..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/PushKvM.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common.ps.message; - -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.util.Bits; -import org.apache.flink.util.Preconditions; - -import static org.apache.flink.ml.common.ps.message.MessageType.PUSH_KV; - -/** The sparse key-values to push from workers to servers. 
*/ -public class PushKvM implements Message { - public final int serverId; - public final int workerId; - public final Tuple2 kvs; - - public PushKvM(int workerId, int serverId, Tuple2 kvs) { - this.workerId = workerId; - this.serverId = serverId; - this.kvs = kvs; - } - - public static PushKvM fromBytes(byte[] bytes) { - int offset = 0; - char type = Bits.getChar(bytes, offset); - offset += Character.BYTES; - Preconditions.checkState(type == PUSH_KV.type); - - int workerId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - int psId = Bits.getInt(bytes, offset); - offset += Integer.BYTES; - Tuple2 grad = MessageUtils.getLongDoubleArray(bytes, offset); - return new PushKvM(workerId, psId, grad); - } - - @Override - public byte[] toBytes() { - int numBytes = - Character.BYTES - + Integer.BYTES - + Integer.BYTES - + MessageUtils.getLongDoubleArraySizeInBytes(kvs); - byte[] buffer = new byte[numBytes]; - int offset = 0; - - Bits.putChar(buffer, offset, PUSH_KV.type); - offset += Character.BYTES; - - Bits.putInt(buffer, offset, this.workerId); - offset += Integer.BYTES; - Bits.putInt(buffer, offset, this.serverId); - offset += Integer.BYTES; - MessageUtils.putLongDoubleArray(kvs, buffer, offset); - - return buffer; - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java index c9153862e..aaabbe633 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/AllReduceStage.java @@ -18,38 +18,27 @@ package org.apache.flink.ml.common.ps.training; -import org.apache.flink.util.Preconditions; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeutils.TypeSerializer; -import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Supplier; -/** A communication stage that conducts all-reduce on the given double array. */ -public final class AllReduceStage implements IterationStage { - public final Supplier valuesSupplier; - public final Consumer valuesConsumer; - public final BiFunction valuesAggregator; +/** A communication stage that conducts all-reduce on the given array. 
*/ +public final class AllReduceStage implements IterationStage { + public final Supplier valuesSupplier; + public final Consumer valuesConsumer; + public final ReduceFunction reducer; + public final TypeSerializer typeSerializer; public AllReduceStage( - Supplier valuesSupplier, - Consumer valuesConsumer, - BiFunction valuesAggregator) { + Supplier valuesSupplier, + Consumer valuesConsumer, + ReduceFunction reducer, + TypeSerializer typeSerializer) { this.valuesSupplier = valuesSupplier; this.valuesConsumer = valuesConsumer; - this.valuesAggregator = valuesAggregator; - } - - public AllReduceStage(Supplier valuesSupplier, Consumer valuesConsumer) { - this( - valuesSupplier, - valuesConsumer, - (SerializableBiFunction) - (array1, array2) -> { - Preconditions.checkState(array1.length == array2.length); - for (int i = 0; i < array1.length; i++) { - array2[i] += array1[i]; - } - return array2; - }); + this.reducer = reducer; + this.typeSerializer = typeSerializer; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java index 13c5909f3..4db772c25 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStage.java @@ -23,7 +23,7 @@ /** * Iterative machine learning training usually incurs local computation step (e.g., computing * gradients) and global communication step (e.g., all-reduce and parameter servers to aggregate the - * gradients). + * updates from workers). * *

To describe the above iteration training process, we model the training process as a sequence * of iteration stages. An iteration stage could be either local computation or global diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java index 952377935..814aa5b96 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/PushStage.java @@ -20,7 +20,14 @@ import java.util.function.Supplier; -/** A communication stage that push (indices, values) to servers. */ +/** + * A communication stage that pushes (indices, values) to servers. + * + *

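+ * <p>A minimal wiring sketch, mirroring how LogisticRegressionWithFtrl uses this stage (the
+ * {@code session} fields below are illustrative, not part of this class):
+ *
+ * <pre>{@code
+ * // Pushes keys {1, 4}, each key carrying three values.
+ * session.pushIndices = new long[] {1, 4};
+ * session.pushValues = new double[] {1, 2, 3, 4, 5, 6};
+ * PushStage push =
+ *         new PushStage(
+ *                 (SerializableSupplier<long[]>) () -> session.pushIndices,
+ *                 (SerializableSupplier<double[]>) () -> session.pushValues);
+ * }</pre>
+ *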
Note that the length of the values array must be evenly divisible by the length of the keys + * array. Additionally, each value corresponding to a given key must have the same length. For + * instance, considering the keys {1, 4} and values {1,2,3,4,5,6}, the value for key 1 would be + * {1,2,3}, and the value for key 4 would be {4,5,6}. + */ public class PushStage implements IterationStage { public final Supplier<long[]> keysSupplier; public final Supplier<double[]> valuesSupplier; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java deleted file mode 100644 index e191a38a6..000000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/SerializableBiFunction.java +++ /dev/null @@ -1,6 +0,0 @@ -package org.apache.flink.ml.common.ps.training; - -import java.io.Serializable; -import java.util.function.BiFunction; - -public interface SerializableBiFunction<T, U, R> extends BiFunction<T, U, R>, Serializable {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java index 3ed22c64f..57a001ebe 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -33,7 +33,7 @@ import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.common.ps.MirrorWorkerOperator; +import org.apache.flink.ml.common.ps.ResponseAssemblerOperator; import org.apache.flink.ml.common.ps.ServerOperator; import org.apache.flink.ml.common.ps.WorkerOperator; import org.apache.flink.ml.common.ps.updater.ModelUpdater; @@ -150,7 +150,7 @@ public IterationBodyResult process( .transform( "MirrorWorkerOp", PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, - new MirrorWorkerOperator(numServers)) + new ResponseAssemblerOperator(numServers)) .setParallelism(numWorkers); return new IterationBodyResult( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java index 7d6f17f8b..8656d88ca 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java @@ -31,11 +31,11 @@ import java.util.List; /** - * FTRL (Follow-the-regularized-leader) is an optimization algorithm which is widely deployed by - * online learning. + * FTRL (Follow-the-Regularized-Leader) is an optimization algorithm for large-scale linear + * models. It aims to minimize the sum of a loss function over all training + * examples, subject to a regularization constraint. * - *

      See H. Brendan McMahan et al., Ad click * - * prediction: a view from the trenches. + *

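+ * <p>Per key, the FTRL-Proximal update described in the paper above takes roughly the following
+ * form (a sketch only: g is the gradient for one key, z and n are the per-key accumulators, and
+ * the variable names are illustrative rather than the exact fields of this class):
+ *
+ * <pre>{@code
+ * double sigma = (Math.sqrt(n + g * g) - Math.sqrt(n)) / alpha;
+ * z += g - sigma * w;
+ * n += g * g;
+ * w = Math.abs(z) <= lambda1
+ *         ? 0.0
+ *         : -(z - Math.signum(z) * lambda1) / ((beta + Math.sqrt(n)) / alpha + lambda2);
+ * }</pre>
+ *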
      FTRL is well-suited for sparse data and can handle problems with billions of features. */ public class FTRL implements ModelUpdater { private final double alpha; @@ -62,9 +62,9 @@ public FTRL(double alpha, double beta, double lambda1, double lambda2) { } @Override - public void open(long startFeatureIndex, long endFeatureIndex) { - this.startIndex = startFeatureIndex; - this.endIndex = endFeatureIndex; + public void open(long startKeyIndex, long endKeyIndex) { + this.startIndex = startKeyIndex; + this.endIndex = endKeyIndex; int modelShardSize = (int) (endIndex - startIndex); weight = new double[modelShardSize]; sigma = new double[modelShardSize]; @@ -73,7 +73,7 @@ public void open(long startFeatureIndex, long endFeatureIndex) { } @Override - public void handlePush(long[] keys, double[] values) { + public void update(long[] keys, double[] values) { for (int i = 0; i < keys.length; i++) { int index = (int) (keys[i] - startIndex); double gi = values[i]; @@ -97,7 +97,7 @@ private void updateModelOnOneDim(double gi, int index, double[] weight) { } @Override - public double[] handlePull(long[] keys) { + public double[] get(long[] keys) { double[] values = new double[keys.length]; for (int i = 0; i < keys.length; i++) { values[i] = weight[(int) (keys[i] - startIndex)]; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java index feb04e185..79e43ca40 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java @@ -26,7 +26,7 @@ import java.util.Iterator; /** - * A model updater that could be used to handle push/pull request from workers. + * A model updater that could be used to update and retrieve model data. * *

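+ * <p>A trivial in-memory sketch of the update/get contract (illustrative only; keys are global,
+ * so a model shard shifts them by its start index):
+ *
+ * <pre>{@code
+ * // In open(startKeyIndex, endKeyIndex):
+ * double[] model = new double[(int) (endKeyIndex - startKeyIndex)];
+ *
+ * // In update(keys, values):
+ * for (int i = 0; i < keys.length; i++) {
+ *     model[(int) (keys[i] - startKeyIndex)] += values[i];
+ * }
+ *
+ * // In get(keys):
+ * double[] values = new double[keys.length];
+ * for (int i = 0; i < keys.length; i++) {
+ *     values[i] = model[(int) (keys[i] - startKeyIndex)];
+ * }
+ * }</pre>
+ *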
Note that model updater should also ensure that model data is robust to failures, by writing model data to snapshots. @@ -34,18 +34,18 @@ public interface ModelUpdater extends Serializable { /** Initializes the model data. */ - void open(long startFeatureIndex, long endFeatureIndex); + void open(long startKeyIndex, long endKeyIndex); /** Applies the push to update the model data, e.g., using gradient to update model. */ - void handlePush(long[] keys, double[] values); + void update(long[] keys, double[] values); - /** Applies the pull and return the retrieved model data. */ - double[] handlePull(long[] keys); + /** Retrieves the model data of the given keys. */ + double[] get(long[] keys); /** - * Returns model segments with the format of (startFeatureIdx, endFeatureIdx, modelValues). The - * model segments are continuously updated/retrieved by push/pull(i.e., `handlePush` and - * `handlePull`). + * Returns model segments with the format of (startKeyIdx, endKeyIdx, modelValues). The model + * segments are continuously updated/retrieved by push/pull (i.e., {@link + * ModelUpdater#update(long[], double[])} and {@link ModelUpdater#get(long[])}). */ Iterator<Tuple3<Long, Long, double[]>> getModelSegments(); diff --git a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java index f28231704..fff889edc 100644 --- a/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java +++ b/flink-ml-servable-core/src/main/java/org/apache/flink/ml/util/Bits.java @@ -18,6 +18,8 @@ package org.apache.flink.ml.util; +import org.apache.flink.api.java.tuple.Tuple2; + /** * Utility methods for packing/unpacking primitive values in/out of byte arrays using big-endian * byte ordering. Referenced from java.io.Bits. @@ -86,4 +88,95 @@ public static void putChar(byte[] b, int off, char val) { b[off + 1] = (byte) (val); b[off] = (byte) (val >>> 8); } + + /** Gets a long array from the byte array starting from the given offset. */ + public static long[] getLongArray(byte[] bytes, int offset) { + int size = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + long[] result = new long[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getLong(bytes, offset); + offset += Long.BYTES; + } + return result; + } + + /** + * Puts a long array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int putLongArray(long[] array, byte[] bytes, int offset) { + Bits.putInt(bytes, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putLong(bytes, offset, array[i]); + offset += Long.BYTES; + } + return offset; + } + + /** Returns the size of a long array in bytes. */ + public static int getLongArraySizeInBytes(long[] array) { + return Integer.BYTES + array.length * Long.BYTES; + } + + /** Gets a double array from the byte array starting from the given offset. */ + public static double[] getDoubleArray(byte[] bytes, int offset) { + int size = Bits.getInt(bytes, offset); + offset += Integer.BYTES; + double[] result = new double[size]; + for (int i = 0; i < size; i++) { + result[i] = Bits.getDouble(bytes, offset); + offset += Double.BYTES; + } + return result; + } + + /** + * Puts a double array to the byte array starting from the given offset. + * + * @return the next position to write on.
+ */ + public static int putDoubleArray(double[] array, byte[] bytes, int offset) { + Bits.putInt(bytes, offset, array.length); + offset += Integer.BYTES; + for (int i = 0; i < array.length; i++) { + Bits.putDouble(bytes, offset, array[i]); + offset += Double.BYTES; + } + return offset; + } + + /** Returns the size of a double array in bytes. */ + public static int getDoubleArraySizeInBytes(double[] array) { + return Integer.BYTES + array.length * Double.BYTES; + } + + /** Gets a long-double array from the byte array starting from the given offset. */ + public static Tuple2<long[], double[]> getLongDoubleArray(byte[] bytes, int offset) { + long[] indices = getLongArray(bytes, offset); + offset += getLongArraySizeInBytes(indices); + double[] values = getDoubleArray(bytes, offset); + return Tuple2.of(indices, values); + } + + /** + * Puts a long-double array to the byte array starting from the given offset. + * + * @return the next position to write on. + */ + public static int putLongDoubleArray( + Tuple2<long[], double[]> longDoubleArray, byte[] bytes, int offset) { + offset = putLongArray(longDoubleArray.f0, bytes, offset); + offset = putDoubleArray(longDoubleArray.f1, bytes, offset); + + return offset; + } + + /** Returns the size of a long-double array in bytes. */ + public static int getLongDoubleArraySizeInBytes(Tuple2<long[], double[]> longDoubleArray) { + return getLongArraySizeInBytes(longDoubleArray.f0) + + getDoubleArraySizeInBytes(longDoubleArray.f1); + } } From 6b4df2a2350e2f0540de93b98adca93caffacb9c Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Thu, 8 Jun 2023 14:24:15 +0800 Subject: [PATCH 14/18] Support output from worker operator --- .../LogisticRegressionWithFtrl.java | 70 +++++++++------- .../flink/ml/common/ps/ServerAgent.java | 14 +--- .../flink/ml/common/ps/ServerOperator.java | 35 +++++--- .../flink/ml/common/ps/WorkerOperator.java | 6 +- .../flink/ml/common/ps/message/Message.java | 18 ++-- .../ps/training/IterationStageList.java | 4 +- .../ml/common/ps/training/MLSession.java | 25 ++++-- .../ml/common/ps/training/MLSessionImpl.java | 18 ++++ .../ps/training/MiniBatchMLSession.java | 10 +++ .../common/ps/training/ProxySideOutput.java | 20 +++++ .../ml/common/ps/training/TrainingUtils.java | 83 ++++++++++++------- .../flink/ml/common/ps/updater/FTRL.java | 2 +- .../ml/common/ps/updater/ModelUpdater.java | 13 +-- .../LogisticRegressionWithFtrlTest.java | 29 ------- 14 files changed, 213 insertions(+), 134 deletions(-) create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java index 581169e4a..2e9fcd1c7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java @@ -20,8 +20,12 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; import
org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; @@ -108,31 +112,25 @@ public LogisticRegressionModel fit(Table... inputs) { features, label, weight); }); - DataStream modelDim; + DataStream maxKey; if (getModelDim() > 0) { - modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim()); + maxKey = trainData.getExecutionEnvironment().fromElements(getModelDim() - 1); } else { - modelDim = + maxKey = DataStreamUtils.reduce( trainData.map( x -> { Vector feature = x.features; long dim; if (feature instanceof IntDoubleVector) { - dim = - ((IntDoubleVector) feature) - .size() - .intValue(); + dim = ((IntDoubleVector) feature).size(); } else { - dim = - ((LongDoubleVector) feature) - .size() - .longValue(); + dim = ((LongDoubleVector) feature).size(); } return dim; }), (ReduceFunction) Math::max) - .map((MapFunction) value -> value); + .map((MapFunction) value -> value - 1); } MiniBatchMLSession mlSession = @@ -140,25 +138,39 @@ public LogisticRegressionModel fit(Table... inputs) { getGlobalBatchSize(), TypeInformation.of(LabeledPointWithWeight.class)); IterationStageList> iterationStages = - new IterationStageList<>(mlSession); - iterationStages - .addStage(new ComputeIndices()) - .addStage( - new PullStage( - (SerializableSupplier) () -> mlSession.pullIndices, - (SerializableConsumer) x -> mlSession.pulledValues = x)) - .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) - .addStage( - new PushStage( - (SerializableSupplier) () -> mlSession.pushIndices, - (SerializableSupplier) () -> mlSession.pushValues)) - .setTerminationCriteria( - (SerializableFunction, Boolean>) - o -> o.iterationId >= getMaxIter()); + new IterationStageList<>(mlSession) + .addStage(new ComputeIndices()) + .addStage( + new PullStage( + (SerializableSupplier) () -> mlSession.pullIndices, + (SerializableConsumer) + x -> mlSession.pulledValues = x)) + .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE)) + .addStage( + new PushStage( + (SerializableSupplier) () -> mlSession.pushIndices, + (SerializableSupplier) + () -> mlSession.pushValues)) + .setTerminationCriteria( + (SerializableFunction< + MiniBatchMLSession, + Boolean>) + o -> o.iterationId >= getMaxIter()); FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet()); - DataStream> rawModelData = - TrainingUtils.train(modelDim, trainData, ftrl, iterationStages, getNumServers()); + DataStreamList resultList = + TrainingUtils.train( + trainData, + iterationStages, + maxKey, + new TupleTypeInfo<>( + Types.LONG, + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + ftrl, + getNumServers()); + + DataStream> rawModelData = resultList.get(0); final long modelVersion = 0L; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java index f634478ee..99f8afab0 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerAgent.java @@ -61,7 +61,7 @@ void open(int numServers, long maxKey) { } /** Sends a request to servers to initialize key range on each server. */ - void initializeModel() { + void initialize() { for (int serverId = 0; serverId < numServers; serverId++) { long start = ranges[serverId]; long end = ranges[serverId + 1]; @@ -104,18 +104,12 @@ void pull(long[] indices) { *

      Note that the values pushed by this function are not going to update the model, but just * perform an all reduce operation. */ - void allReducePush(V[] values, TypeSerializer typeSerializer) throws IOException { - final int MIN_MESSAGE_SIZE = 1024; - int messageSize = Math.max(MIN_MESSAGE_SIZE, values.length / numServers + 1); + void allReduce(V[] values, TypeSerializer typeSerializer) throws IOException { + int messageSize = values.length / numServers + 1; for (int serverId = 0; serverId < numServers; serverId++) { int s = Math.min(serverId * messageSize, values.length); int e = Math.min(s + messageSize, values.length); - V[] segment; - if (s == e) { - segment = (V[]) new Object[0]; - } else { - segment = Arrays.copyOfRange(values, s, e); - } + V[] segment = Arrays.copyOfRange(values, s, e); Message message = new Message( workerId, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java index 766ab9ce6..1d557adb4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java @@ -23,7 +23,6 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.common.ps.message.Message; @@ -57,9 +56,14 @@ import java.util.concurrent.Future; /** - * The server operator maintains the shared parameters. It receives push/pull/allreduce requests - * from {@link WorkerOperator} and sends the answer request to {@link ResponseAssemblerOperator}. It - * works closely with {@link ModelUpdater} in the following way: + * The server operator maintains the shared parameters. The shared parameters can be modeled as a + * collection of {key:value} pairs. By default, the keys are evenly distributed across servers + * through range partitioning. For example, if there are two servers and the keys are {1,2,3,4,5,6}, + * then server-0 maintains keys {1,2,3} and server-1 maintains keys {4,5,6}. + * + *

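+ * <p>A sketch of how such even key ranges could be computed (illustrative only; the actual
+ * routing is done by the worker-side and server-side agents):
+ *
+ * <pre>{@code
+ * // Splits keys [0, maxKey] into numServers contiguous ranges.
+ * long numKeys = maxKey + 1;
+ * long[] ranges = new long[numServers + 1];
+ * for (int serverId = 0; serverId <= numServers; serverId++) {
+ *     ranges[serverId] = serverId * numKeys / numServers;
+ * }
+ * // Server i owns the keys in [ranges[i], ranges[i + 1]).
+ * }</pre>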
The server receives push/pull/allreduce requests from {@link WorkerOperator} and sends the + * responses to {@link ResponseAssemblerOperator}. It works closely with {@link ModelUpdater} + * in the following way: * *

        *
      • The server operator deals with the message from workers and decides when to process the @@ -76,8 +80,10 @@ * ModelUpdater}. * *

TODO: Add support for asynchronous operations on servers. + * + * @param <MT> output format of model data. */ -public class ServerOperator extends AbstractStreamOperator> +public class ServerOperator<MT> extends AbstractStreamOperator> implements OneInputStreamOperator, Tuple2>, IterationListener> { /** The iterationStage list that asks responses from servers. */ @@ -85,9 +91,9 @@ public class ServerOperator extends AbstractStreamOperator> modelOutputTag; + private final ModelUpdater<MT> modelUpdater; + /** Format of model data. */ + private final OutputTag<MT> modelOutputTag; /** Index of the server task. */ private int serverId = -1; /** @@ -108,8 +114,8 @@ public class ServerOperator iterationStageList, int numWorkers, - ModelUpdater modelUpdater, - OutputTag> modelOutputTag) { + ModelUpdater<MT> modelUpdater, + OutputTag<MT> modelOutputTag) { this.stageList = new ArrayList<>(); for (IterationStage stage : iterationStageList.stageList) { if (stage instanceof PullStage || stage instanceof AllReduceStage) { @@ -227,9 +233,9 @@ private Message processAllReduceRequest(AllReduceStage stage, Iterator> collector) { - Iterator> modelSegments = modelUpdater.getModelSegments(); + Iterator<MT> modelSegments = modelUpdater.getModelSegments(); while (modelSegments.hasNext()) { - Tuple3 modelSegment = modelSegments.next(); + MT modelSegment = modelSegments.next(); output.collect(modelOutputTag, new StreamRecord<>(modelSegment)); } } @@ -281,7 +287,10 @@ private Object processPullRequest(Message message) { /** Utility class to merge the push request from different workers. */ private static class PushRequestMerger implements Serializable { - /** The accumulated kv if the push request is for a vector. */ + /** + * The accumulated kv if the push request is for a vector. If the value is a double, we use + * {@link Long2DoubleOpenHashMap} for better efficiency. + */ private final Long2DoubleOpenHashMap accumulatedKvsForVector; /** The accumulated kv if the push request is for a matrix. 
*/ private final Map accumulatedKvsForMatrix; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java index 7be97223c..54855c961 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/WorkerOperator.java @@ -32,6 +32,7 @@ import org.apache.flink.ml.common.ps.training.IterationStageList; import org.apache.flink.ml.common.ps.training.MLSession; import org.apache.flink.ml.common.ps.training.ProcessStage; +import org.apache.flink.ml.common.ps.training.ProxySideOutput; import org.apache.flink.ml.common.ps.training.PullStage; import org.apache.flink.ml.common.ps.training.PushStage; import org.apache.flink.ml.util.Bits; @@ -112,6 +113,7 @@ public void open() { int workerId = getRuntimeContext().getIndexOfThisSubtask(); this.serverAgent = new ServerAgent(workerId, output); iterationStages.session.setWorldInfo(workerId, numTasks); + iterationStages.session.setOutput(new ProxySideOutput(output)); } @Override @@ -121,7 +123,7 @@ public void onEpochWatermarkIncremented( if (epochWatermark == 0) { modelDim = Bits.getLong(feedback, 0); serverAgent.open(numServers, modelDim - 1); - serverAgent.initializeModel(); + serverAgent.initialize(); iterationStages.session.setInputData(new ResettableTrainDataIterator<>(trainDataState)); nextStageToExecute = processIterationStages(nextStageToExecute, iterationStages); } @@ -263,7 +265,7 @@ private int processIterationStages( } else if (stage instanceof AllReduceStage) { AllReduceStage allReduceStage = (AllReduceStage) stage; - serverAgent.allReducePush( + serverAgent.allReduce( allReduceStage.valuesSupplier.get(), allReduceStage.typeSerializer); return nextStageToExecute; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java index aa764568e..2c1b98a33 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java @@ -49,8 +49,10 @@ public class Message { private static final int MESSAGE_TYPE_OFFSET = Integer.BYTES + SERVER_ID_OFFSET; private static final int KVS_OFFSET = Integer.BYTES + MESSAGE_TYPE_OFFSET; + /** The storage of message in bytes. */ public final byte[] bytes; + /** Constructs a message instance from the bytes. */ public Message(byte[] bytes) { this.bytes = bytes; } @@ -117,7 +119,12 @@ public V[] getValues(TypeSerializer serializer) throws IOException { return result; } - /** Retrieves the values in double array format. */ + /** + * Retrieves the values in double array format. + * + *

        Note that getting double array in this function using {@link Bits#getDoubleArray(byte[], + * int)} is faster than {@link Message#getValues} by up to 2.3X. + */ public double[] getValuesInDoubleArray() { int offset = KVS_OFFSET + Bits.getInt(bytes, KVS_OFFSET) * Long.BYTES + Integer.BYTES; return Bits.getDoubleArray(bytes, offset); @@ -128,15 +135,14 @@ public int getWorkerId() { return Bits.getInt(bytes, WORKER_ID_OFFSET); } - /** Retrieves the server id. */ - public int getServerId() { - return Bits.getInt(bytes, SERVER_ID_OFFSET); - } - /** Sets the worker id. */ public void setWorkerId(int workerId) { Bits.putInt(bytes, WORKER_ID_OFFSET, workerId); } + /** Retrieves the server id. */ + public int getServerId() { + return Bits.getInt(bytes, SERVER_ID_OFFSET); + } /** Sets the server id. */ public void setServerId(int serverId) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java index 1cffbcaa4..9c430d17d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java @@ -40,8 +40,10 @@ public IterationStageList(T session) { } /** Sets the criteria of termination. */ - public void setTerminationCriteria(SerializableFunction shouldTerminate) { + public IterationStageList setTerminationCriteria( + SerializableFunction shouldTerminate) { this.shouldTerminate = shouldTerminate; + return this; } /** Adds an iteration stage into the stage list. */ diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java index f19f035af..21a65d2c8 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSession.java @@ -18,19 +18,23 @@ package org.apache.flink.ml.common.ps.training; +import org.apache.flink.ml.common.ps.WorkerOperator; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.util.OutputTag; import java.io.Serializable; +import java.util.List; /** - * Stores the session information that is alive during the training process. Note that the session - * information will be updated by each {@link IterationStage}. + * Stores the session information that is alive during the training process on {@link + * WorkerOperator}. Note that the session information will be updated by each {@link + * IterationStage}. * *

Note that subclasses should take care of the snapshot of object stored in {@link MLSession} if - the object satisfies that: the write-process is followed by an {@link PullStage}, which is later - again read by other stages. + the object satisfies that: the write-process is followed by a {@link PullStage} or an {@link + AllReduceStage}, and is later read again by other stages. */ public interface MLSession extends Serializable { /** Sets the current iteration ID. */ @@ -42,7 +46,18 @@ default void setWorldInfo(int workerId, int numWorkers) {} /** Sets the training data. */ default void setInputData(ResettableIterator inputData) {} - /** Recover from state. */ + /** Sets the collector with which users can output records to downstream tasks. */ + default void setOutput(ProxySideOutput collector) {} + + /** + * Retrieves the output tags from the {@link MLSession} which can be used to output records from + * the worker operator. + */ + default List<OutputTag<?>> getOutputTags() { + return null; + } + + /** Recovers from state. */ default void initializeState(StateInitializationContext context) throws Exception {} /** Snapshots to state. */ diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java index 13cc70e08..196fbd215 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MLSessionImpl.java @@ -19,6 +19,9 @@ package org.apache.flink.ml.common.ps.training; import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.util.OutputTag; + +import java.util.List; /** * The default implementation of {@link MLSession}. @@ -35,6 +38,21 @@ public class MLSessionImpl

        implements MLSession { /** The input data. */ public ResettableIterator
        inputData; + public List> outputTags; + + /** Constructs an instance with side outputs. */ + public MLSessionImpl(List> outputTags) { + this.outputTags = outputTags; + } + + /** Constructs an instance without side outputs. */ + public MLSessionImpl() {} + + @Override + public List> getOutputTags() { + return outputTags; + } + @Override public void setIterationId(int iterationId) { this.iterationId = iterationId; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java index afeaa4cb5..de5d5da45 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/MiniBatchMLSession.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.util.OutputTag; import org.apache.commons.collections.IteratorUtils; @@ -59,6 +60,15 @@ public MiniBatchMLSession(int globalBatchSize, TypeInformation
        typeInformati this.typeInformation = typeInformation; } + public MiniBatchMLSession( + int globalBatchSize, + TypeInformation
        typeInformation, + List> outputTags) { + super(outputTags); + this.globalBatchSize = globalBatchSize; + this.typeInformation = typeInformation; + } + @Override public void setWorldInfo(int workerId, int numWorkers) { super.setWorldInfo(workerId, numWorkers); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java new file mode 100644 index 000000000..b843894d2 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java @@ -0,0 +1,20 @@ +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.streaming.api.operators.Output; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +/** A collector that can only output using {@link OutputTag}. */ +public final class ProxySideOutput { + private final Output output; + + public ProxySideOutput(Output output) { + this.output = output; + } + + public void output(OutputTag outputTag, StreamRecord record) { + Preconditions.checkNotNull(outputTag); + output.collect(outputTag, record); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java index 57a001ebe..cfb327c4a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingUtils.java @@ -21,10 +21,10 @@ import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; @@ -42,60 +42,70 @@ import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.util.OutputTag; +import java.util.ArrayList; +import java.util.List; + /** Utility function to describe iterative training process. */ public final class TrainingUtils { - /** * Executes the training logic described in {@link IterationStageList} and returns the fitted - * model data. + * model data as well as the outputs from worker operator. The outputs from worker operator are + * specified via {@link MLSession#getOutputTags()}. * - * @param modelDim dimension of the input model. - * @param trainData the training data. - * @param iterationStages the iterative training logic. + * @param inputData the input data. + * @param iterationStages the iterative processing logic. + * @param maxKey max value of the key. For example, the maxKey should be the max feature index + * in LogisticRegression. + * @param modelDataType output type information of model data. * @param modelUpdater the logic to update model on servers. * @param numServers number of servers. - * @return the fitted model data. + * @return the fitted model data as well as the outputs from worker operator. The orders are + * {modelData, sideOutputs from workers}. 
Note that the outputs from workers share the same + * order as the {@link MLSession#getOutputTags()}. + * @param <DT>
type information of input data. + * @param <MT> type information of the output model data. */ - public static DataStream> train( - DataStream modelDim, - DataStream trainData, - ModelUpdater modelUpdater, + public static <DT, MT> DataStreamList train( + DataStream<DT>
        inputData, IterationStageList iterationStages, + DataStream maxKey, + TypeInformation modelDataType, + ModelUpdater modelUpdater, int numServers) { - // TODO: Support incremental training for multiple models. + // TODO: Support incremental training. DataStream variableStream = - modelDim.broadcast() + maxKey.broadcast() .map( (MapFunction) value -> { byte[] buffer = new byte[Long.BYTES]; - Bits.putLong(buffer, 0, value); + Bits.putLong(buffer, 0, value + 1); return buffer; }); - DataStreamList resultList = - Iterations.iterateBoundedStreamsUntilTermination( - DataStreamList.of(variableStream), - ReplayableDataStreamList.notReplay( - trainData.rebalance().map(x -> x, trainData.getType())), - IterationConfig.newBuilder().build(), - new TrainIterationBody(modelUpdater, iterationStages, numServers)); - - return resultList.get(0); + return Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(variableStream), + ReplayableDataStreamList.notReplay(inputData), + IterationConfig.newBuilder().build(), + new TrainIterationBody<>(modelUpdater, modelDataType, iterationStages, numServers)); } /** The iteration implementation for training process. */ - private static class TrainIterationBody implements IterationBody { - private final ModelUpdater modelUpdater; + private static class TrainIterationBody implements IterationBody { + private final ModelUpdater modelUpdater; + + private final TypeInformation modelType; private final IterationStageList iterationStages; private final int numServers; public TrainIterationBody( - ModelUpdater modelUpdater, + ModelUpdater modelUpdater, + TypeInformation modelType, IterationStageList iterationStages, int numServers) { this.iterationStages = iterationStages; + this.modelType = modelType; this.modelUpdater = modelUpdater; this.numServers = numServers; } @@ -106,8 +116,7 @@ public IterationBodyResult process( DataStreamList variableStreams, DataStreamList dataStreams) { DataStream variableStream = variableStreams.get(0); DataStream trainData = dataStreams.get(0); - final OutputTag> modelDataOutputTag = - new OutputTag>("MODEL_OUTPUT") {}; + final OutputTag modelDataOutputTag = new OutputTag<>("MODEL_OUTPUT", modelType); SingleOutputStreamOperator> messageToServer = trainData @@ -133,7 +142,7 @@ public IterationBodyResult process( new TupleTypeInfo<>( Types.INT, PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO), - new ServerOperator( + new ServerOperator<>( iterationStages, numWorkers, modelUpdater, @@ -153,10 +162,20 @@ public IterationBodyResult process( new ResponseAssemblerOperator(numServers)) .setParallelism(numWorkers); + DataStream model = messageToWorker.getSideOutput(modelDataOutputTag); + + List> result = new ArrayList<>(); + result.add(model); + + List> sideOutputTags = iterationStages.session.getOutputTags(); + if (sideOutputTags != null) { + for (OutputTag outputTag : sideOutputTags) { + result.add(messageToServer.getSideOutput(outputTag)); + } + } + return new IterationBodyResult( - DataStreamList.of(combinedMessageToWorker), - DataStreamList.of(messageToWorker.getSideOutput(modelDataOutputTag)), - null); + DataStreamList.of(combinedMessageToWorker), new DataStreamList(result), null); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java index 8656d88ca..2f403e4b3 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java +++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/FTRL.java @@ -37,7 +37,7 @@ * *

FTRL is well-suited for sparse data and can handle problems with billions of features. */ -public class FTRL implements ModelUpdater { +public class FTRL implements ModelUpdater<Tuple3<Long, Long, double[]>> { private final double alpha; private final double beta; private final double lambda1; diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java index 79e43ca40..0d7ac3ed4 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/updater/ModelUpdater.java @@ -18,7 +18,6 @@ package org.apache.flink.ml.common.ps.updater; -import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; @@ -30,8 +29,10 @@ * *

Note that model updater should also ensure that model data is robust to failures, by writing model data to snapshots. + * + * @param <MT> data type of model. */ -public interface ModelUpdater extends Serializable { +public interface ModelUpdater<MT> extends Serializable { /** Initializes the model data. */ void open(long startKeyIndex, long endKeyIndex); @@ -43,11 +44,11 @@ public interface ModelUpdater extends Serializable { double[] get(long[] keys); /** - * Returns model segments with the format of (startKeyIdx, endKeyIdx, modelValues). The model - * segments are continuously updated/retrieved by push/pull(i.e., {@link - * ModelUpdater#update(long[], double[])} and {@link ModelUpdater#get(long[])}). + * Returns model segments. The model segments are continuously updated/retrieved by + * push/pull (i.e., {@link ModelUpdater#update(long[], double[])} and {@link + * ModelUpdater#get(long[])}). */ - Iterator> getModelSegments(); + Iterator<MT> getModelSegments(); /** Recovers the model data from state. */ void initializeState(StateInitializationContext context) throws Exception; diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java index 422257c38..218947496 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java @@ -21,10 +21,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; @@ -39,7 +35,6 @@ import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.types.BasicType; import org.apache.flink.ml.servable.types.DataTypes; -import org.apache.flink.ml.util.Bits; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -54,8 +49,6 @@ import org.junit.rules.TemporaryFolder; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.io.SequenceInputStream; import java.util.ArrayList; import java.util.Arrays; @@ -427,26 +420,4 @@ private void verifyPredictionResult( } } } - - @Test - public void testGetGenericType() throws IOException { - TypeInformation t = getType(128); - ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(4); - DataOutputView d = new DataOutputViewStreamWrapper(byteArrayOutputStream); - t.createSerializer(null).serialize(128, d); - byte[] serialized = byteArrayOutputStream.toByteArray(); - System.out.println(Bits.getInt(serialized, 0)); - - ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(serialized); - DataInputView inputView = new DataInputViewStreamWrapper(byteArrayInputStream); - int deserializedInt = (Integer)
t.createSerializer(null).deserialize(inputView); - System.out.println(deserializedInt); - } - - TypeInformation getType(V v) { - if (v instanceof Integer) { - return Types.INT; - } - return null; - } } From dce8b9be9a56bd72f793bdd2fc84a7136c47ceca Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Thu, 8 Jun 2023 14:28:38 +0800 Subject: [PATCH 15/18] Add test for trainingUtils.java --- .../common/ps/training/TrainingUtilsTest.java | 360 ++++++++++++++++++ 1 file changed, 360 insertions(+) create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java new file mode 100644 index 000000000..f9e47f283 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.ps.training; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.ml.common.ps.updater.ModelUpdater; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseIntDoubleVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseIntDoubleVectorTypeInfo; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.util.ResettableIterator; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.test.util.TestBaseUtils; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.function.SerializableSupplier; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +/** Tests {@link TrainingUtils}. 
*/ +public class TrainingUtilsTest { + private static final int NUM_WORKERS = 2; + private static final int NUM_SERVERS = 2; + private static final int MAX_ITER = 3; + private static final int NUM_COLUMNS_PER_KEY = 2; + + private DataStream maxKey; + private DataStream inputData; + + @Before + public void before() { + StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment(); + env.setParallelism(NUM_WORKERS); + maxKey = env.fromElements(3L); + inputData = + env.fromCollection( + Arrays.asList( + Vectors.dense(1, 1, 1, 1), + Vectors.dense(2, 2, 2, 2), + Vectors.dense(3, 3, 3, 3), + Vectors.dense(4, 4, 4, 4))) + .map(x -> x, DenseIntDoubleVectorTypeInfo.INSTANCE); + } + + @Test + public void test() throws Exception { + ExecutionConfig config = maxKey.getExecutionEnvironment().getConfig(); + + TypeSerializer pojoDemoTypeSerializer = + Types.POJO(MockPojo.class).createSerializer(config); + + MockSession mockSession = + new MockSession( + DenseIntDoubleVectorTypeInfo.INSTANCE, + Collections.singletonList( + new OutputTag<>("AllReduceOutputTag", Types.POJO(MockPojo.class)))); + + IterationStageList stageList = + new IterationStageList<>(mockSession) + .addStage(new MockComputePullIndicesStage()) + .addStage( + new PullStage( + (SerializableSupplier) + () -> mockSession.pullIndices, + (SerializableConsumer) + x -> mockSession.pulledValues = x)) + .addStage( + new AllReduceStage<>( + (SerializableSupplier) + () -> mockSession.toAllReduce, + (SerializableConsumer) + x -> mockSession.toAllReduce = x, + (ReduceFunction) TrainingUtilsTest::sumPojo, + pojoDemoTypeSerializer)) + .addStage(new MockComputePushValuesStage(NUM_COLUMNS_PER_KEY)) + .addStage( + new PushStage( + (SerializableSupplier) + () -> mockSession.pushIndices, + (SerializableSupplier) + () -> mockSession.pushValues)) + .setTerminationCriteria(context -> context.iterationId >= MAX_ITER); + + DataStreamList resultList = + TrainingUtils.train( + inputData, + stageList, + maxKey, + new TupleTypeInfo<>( + Types.LONG, + Types.LONG, + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO), + new MockModelUpdater(NUM_COLUMNS_PER_KEY), + NUM_SERVERS); + + // Verifies the model data. + DataStream> modelSegments = resultList.get(0); + List> collectedModelPieces = + IteratorUtils.toList(modelSegments.executeAndCollect()); + Assert.assertEquals(NUM_SERVERS, collectedModelPieces.size()); + collectedModelPieces.sort(Comparator.comparing(o -> o.f0)); + + double[] result = new double[4 * NUM_COLUMNS_PER_KEY]; + double[] expectedResult = new double[4 * NUM_COLUMNS_PER_KEY]; + Arrays.fill(expectedResult, 0, NUM_COLUMNS_PER_KEY, 35.0); + Arrays.fill(expectedResult, 3 * NUM_COLUMNS_PER_KEY, 4 * NUM_COLUMNS_PER_KEY, 35.0); + for (Tuple3 modelPiece : collectedModelPieces) { + int startIndex = (int) (long) modelPiece.f0 * NUM_COLUMNS_PER_KEY; + double[] pieceCoeff = modelPiece.f2; + System.arraycopy(pieceCoeff, 0, result, startIndex, pieceCoeff.length); + } + Assert.assertArrayEquals(expectedResult, result, 1e-7); + + // Verifies the all reduce result from worker output. 
+ DataStream allReduceResult = resultList.get(1); + allReduceResult.getTransformation().setParallelism(1); + List collectedPojo = IteratorUtils.toList(allReduceResult.executeAndCollect()); + List expectedPojo = + Arrays.asList(new MockPojo(1, 0), new MockPojo(2, 0), new MockPojo(4, 0)); + TestBaseUtils.compareResultCollections( + expectedPojo, + collectedPojo, + new Comparator() { + @Override + public int compare(MockPojo o1, MockPojo o2) { + return Integer.compare(o1.i, o2.i); + } + }); + } + + private static MockPojo[] sumPojo(MockPojo[] d1, MockPojo[] d2) { + Preconditions.checkArgument(d1.length == d2.length); + for (int i = 0; i < d1.length; i++) { + d2[i].i += d1[i].i; + d2[i].j += d1[i].j; + } + return d2; + } + + private static class MockSession extends MiniBatchMLSession { + + public MockPojo[] toAllReduce; + private ProxySideOutput output; + + @Override + public void setOutput(ProxySideOutput collector) { + this.output = collector; + } + + public MockSession( + TypeInformation typeInformation, + List> outputTags) { + super(0, typeInformation, outputTags); + } + } + + /** Pulls the 0-th and 3-th dimension of the model from servers. */ + private static class MockComputePullIndicesStage extends ProcessStage { + + @Override + public void process(MockSession context) { + context.pullIndices = new long[] {0, 3}; + if (context.toAllReduce == null) { + context.toAllReduce = new MockPojo[1]; + context.toAllReduce[0] = new MockPojo(1, 0); + } + if (context.workerId == 0) { + context.output.output( + (OutputTag) context.getOutputTags().get(0), + new StreamRecord(context.toAllReduce[0])); + } + } + } + + /** + * Adds the 0-th and 3-th dimension of all training data to the model and pushes it to servers. + */ + private static class MockComputePushValuesStage extends ProcessStage { + private final int numCols; + + public MockComputePushValuesStage(int numCols) { + this.numCols = numCols; + } + + @Override + public void process(MockSession context) throws Exception { + long[] indices = context.pullIndices; + double[] values = context.pulledValues; + ResettableIterator data = context.inputData; + while (data.hasNext()) { + double[] d = data.next().values; + for (int i = 0; i < indices.length; i++) { + double v = d[(int) indices[i]]; + for (int j = 0; j < numCols; j++) { + values[i * numCols + j] += v; + } + } + } + data.reset(); + + BLAS.scal(1.0 / context.numWorkers, new DenseIntDoubleVector(values)); + + context.pushIndices = indices; + context.pushValues = values; + } + } + + /** The logic on servers. 
*/ + private static class MockModelUpdater implements ModelUpdater> { + private final int numDoublesPerKey; + private long startIndex; + private long endIndex; + private double[] model; + + private ListState boundaryState; + private ListState modelDataState; + + public MockModelUpdater(int numDoublesPerKey) { + this.numDoublesPerKey = numDoublesPerKey; + } + + @Override + public void open(long startKeyIndex, long endKeyIndex) { + this.startIndex = startKeyIndex; + this.endIndex = endKeyIndex; + this.model = new double[(int) (endKeyIndex - startKeyIndex) * numDoublesPerKey]; + } + + @Override + public void update(long[] keys, double[] values) { + Preconditions.checkState(keys.length * numDoublesPerKey == values.length); + for (int i = 0; i < keys.length; i++) { + int index = (int) (keys[i] - startIndex); + for (int j = 0; j < numDoublesPerKey; j++) { + model[index * numDoublesPerKey + j] += values[i * numDoublesPerKey + j]; + } + } + } + + @Override + public double[] get(long[] keys) { + double[] values = new double[keys.length * numDoublesPerKey]; + for (int i = 0; i < keys.length; i++) { + int index = (int) (keys[i] - startIndex); + for (int j = 0; j < numDoublesPerKey; j++) { + values[i * numDoublesPerKey + j] += model[index * numDoublesPerKey + j]; + } + } + return values; + } + + @Override + public Iterator> getModelSegments() { + return Collections.singleton(Tuple3.of(startIndex, endIndex, model)).iterator(); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + boundaryState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("BoundaryState", Types.LONG)); + + Iterator iterator = boundaryState.get().iterator(); + if (iterator.hasNext()) { + startIndex = iterator.next(); + endIndex = iterator.next(); + } + + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "modelDataState", + PrimitiveArrayTypeInfo + .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); + Iterator modelData = modelDataState.get().iterator(); + if (modelData.hasNext()) { + model = modelData.next(); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + if (model != null) { + boundaryState.clear(); + boundaryState.add(startIndex); + boundaryState.add(endIndex); + + modelDataState.clear(); + modelDataState.add(model); + } + } + } + + public static class MockPojo { + public int i; + public int j; + + public MockPojo(int i, int j) { + this.i = i; + this.j = j; + } + + public MockPojo() {} + + @Override + public String toString() { + return i + "-" + j; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof MockPojo) { + MockPojo other = (MockPojo) obj; + return i == other.i && j == other.j; + } + return false; + } + } +} From 4e77b2ca3a0ec2bd8cbbec61fbcdf0c714ab0fe7 Mon Sep 17 00:00:00 2001 From: "congzhou.zzp" Date: Thu, 8 Jun 2023 14:36:29 +0800 Subject: [PATCH 16/18] Rename LogisticRegressionModelData as LogisticRegressionModelDataSegment --- .../LogisticRegression.java | 4 +- .../LogisticRegressionModel.java | 8 ++-- .../LogisticRegressionModelDataUtil.java | 38 ++++++++++--------- .../LogisticRegressionWithFtrl.java | 4 +- .../OnlineLogisticRegression.java | 15 ++++---- .../OnlineLogisticRegressionModel.java | 8 ++-- .../flink/ml/common/ps/message/Message.java | 1 + .../common/ps/training/ProxySideOutput.java | 18 +++++++++ .../LogisticRegressionTest.java | 8 ++-- .../LogisticRegressionWithFtrlTest.java | 6 +-- 
.../OnlineLogisticRegressionTest.java | 22 +++++------ .../common/ps/training/TrainingUtilsTest.java | 1 + ...> LogisticRegressionModelDataSegment.java} | 26 ++++++------- .../LogisticRegressionModelServable.java | 12 +++--- 14 files changed, 97 insertions(+), 74 deletions(-) rename flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/{LogisticRegressionModelData.java => LogisticRegressionModelDataSegment.java} (82%) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 22ae3a81d..8d5c6ce2a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -114,8 +114,8 @@ public LogisticRegressionModel fit(Table... inputs) { DataStream rawModelData = optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE); - DataStream modelData = - rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0L)); + DataStream modelData = + rawModelData.map(vector -> new LogisticRegressionModelDataSegment(vector, 0L)); LogisticRegressionModel model = new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); ParamUtils.updateExistingParams(model, paramMap); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index 2ab502299..8081ec263 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -67,7 +67,7 @@ public Table[] transform(Table... 
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
         DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
         final String broadcastModelKey = "broadcastModelKey";
-        DataStream<LogisticRegressionModelData> modelDataStream =
+        DataStream<LogisticRegressionModelDataSegment> modelDataStream =
                 LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable);
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
         RowTypeInfo outputTypeInfo =
@@ -148,14 +148,14 @@ public PredictLabelFunction(String broadcastModelKey, Map<Param<?>, Object> params) {
     @Override
     public Row map(Row dataPoint) {
         if (servable == null) {
-            List<LogisticRegressionModelData> modelData =
+            List<LogisticRegressionModelDataSegment> modelData =
                     getRuntimeContext().getBroadcastVariable(broadcastModelKey);
 
             if (modelData.size() == 1) {
                 servable = new LogisticRegressionModelServable(modelData.get(0));
             } else {
-                LogisticRegressionModelData mergedModel =
-                        LogisticRegressionModelData.mergeSegments(modelData);
+                LogisticRegressionModelDataSegment mergedModel =
+                        LogisticRegressionModelDataSegment.mergeSegments(modelData);
                 servable = new LogisticRegressionModelServable(mergedModel);
             }
             ParamUtils.updateExistingParams(servable, params);

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java
index 562842536..6a10f98c3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java
@@ -45,8 +45,8 @@ public class LogisticRegressionModelDataUtil {
 
     /**
-     * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly
-     * generated coefficient.
+     * Generates a Table containing a {@link LogisticRegressionModelDataSegment} instance with
+     * randomly generated coefficient.
      *
      * @param tEnv The environment where to create the table.
     * @param dim The size of generated coefficient.
@@ -59,7 +59,7 @@ public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim) {
     }
 
     private static class RandomModelDataGenerator
-            implements MapFunction<Integer, LogisticRegressionModelData> {
+            implements MapFunction<Integer, LogisticRegressionModelDataSegment> {
         private final int dim;
         private final int seed;
 
@@ -69,13 +69,13 @@ public RandomModelDataGenerator(int dim, int seed) {
         }
 
         @Override
-        public LogisticRegressionModelData map(Integer integer) throws Exception {
+        public LogisticRegressionModelDataSegment map(Integer integer) throws Exception {
             DenseIntDoubleVector vector = new DenseIntDoubleVector(dim);
             Random random = new Random(seed);
             for (int j = 0; j < dim; j++) {
                 vector.values[j] = random.nextDouble();
             }
-            return new LogisticRegressionModelData(vector, 0L);
+            return new LogisticRegressionModelDataSegment(vector, 0L);
         }
     }
 
@@ -85,13 +85,14 @@ public LogisticRegressionModelData map(Integer integer) throws Exception {
      * @param modelData The table model data.
     * @return The data stream model data.
      */
-    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+    public static DataStream<LogisticRegressionModelDataSegment> getModelDataStream(
+            Table modelData) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
         return tEnv.toDataStream(modelData)
                 .map(
                         x ->
-                                new LogisticRegressionModelData(
+                                new LogisticRegressionModelDataSegment(
                                         x.getFieldAs(0),
                                         x.getFieldAs(1),
                                         x.getFieldAs(2),
@@ -111,8 +112,8 @@ public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) {
         return tEnv.toDataStream(modelDataTable)
                 .map(
                         x -> {
-                            LogisticRegressionModelData modelData =
-                                    new LogisticRegressionModelData(
+                            LogisticRegressionModelDataSegment modelData =
+                                    new LogisticRegressionModelDataSegment(
                                             x.getFieldAs(0),
                                             x.getFieldAs(1),
                                             x.getFieldAs(2),
@@ -125,26 +126,27 @@ public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) {
     }
 
     /** Data encoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. */
-    public static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
+    public static class ModelDataEncoder implements Encoder<LogisticRegressionModelDataSegment> {
         @Override
-        public void encode(LogisticRegressionModelData modelData, OutputStream outputStream)
+        public void encode(LogisticRegressionModelDataSegment modelData, OutputStream outputStream)
                 throws IOException {
             modelData.encode(outputStream);
         }
     }
 
     /** Data decoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. */
-    public static class ModelDataDecoder extends SimpleStreamFormat<LogisticRegressionModelData> {
+    public static class ModelDataDecoder
+            extends SimpleStreamFormat<LogisticRegressionModelDataSegment> {
         @Override
-        public Reader<LogisticRegressionModelData> createReader(
+        public Reader<LogisticRegressionModelDataSegment> createReader(
                 Configuration configuration, FSDataInputStream inputStream) {
-            return new Reader<LogisticRegressionModelData>() {
+            return new Reader<LogisticRegressionModelDataSegment>() {
                 @Override
-                public LogisticRegressionModelData read() throws IOException {
+                public LogisticRegressionModelDataSegment read() throws IOException {
                     try {
-                        return LogisticRegressionModelData.decode(inputStream);
+                        return LogisticRegressionModelDataSegment.decode(inputStream);
                     } catch (EOFException e) {
                         return null;
                     }
                 }
 
                 @Override
                 public void close() throws IOException {}
             };
         }
 
         @Override
-        public TypeInformation<LogisticRegressionModelData> getProducedType() {
-            return TypeInformation.of(LogisticRegressionModelData.class);
+        public TypeInformation<LogisticRegressionModelDataSegment> getProducedType() {
+            return TypeInformation.of(LogisticRegressionModelDataSegment.class);
         }
     }
 }

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java
index 2e9fcd1c7..5299185ce 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java
@@ -174,10 +174,10 @@ public LogisticRegressionModel fit(Table... inputs) {
         final long modelVersion = 0L;
 
-        DataStream<LogisticRegressionModelData> modelData =
+        DataStream<LogisticRegressionModelDataSegment> modelData =
                 rawModelData.map(
                         tuple3 ->
-                                new LogisticRegressionModelData(
+                                new LogisticRegressionModelDataSegment(
                                         Vectors.dense(tuple3.f2),
                                         tuple3.f0,
                                         tuple3.f1,

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
index 75a7ce722..57d9b23a8 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
@@ -88,7 +88,7 @@ public OnlineLogisticRegressionModel fit(Table... inputs) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
 
-        DataStream<LogisticRegressionModelData> modelDataStream =
+        DataStream<LogisticRegressionModelDataSegment> modelDataStream =
                 LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable);
 
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
@@ -116,7 +116,7 @@ public OnlineLogisticRegressionModel fit(Table... inputs) {
 
         DataStream<DenseIntDoubleVector> initModelData =
                 modelDataStream.map(
-                        (MapFunction<LogisticRegressionModelData, DenseIntDoubleVector>)
+                        (MapFunction<LogisticRegressionModelDataSegment, DenseIntDoubleVector>)
                                 value -> value.coefficient);
 
         initModelData.getTransformation().setParallelism(1);
@@ -125,7 +125,7 @@ public OnlineLogisticRegressionModel fit(Table... inputs) {
                 new FtrlIterationBody(
                         getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
 
-        DataStream<LogisticRegressionModelData> onlineModelData =
+        DataStream<LogisticRegressionModelDataSegment> onlineModelData =
                 Iterations.iterateUnboundedStreams(
                                 DataStreamList.of(initModelData), DataStreamList.of(points), body)
                         .get(0);
@@ -225,7 +225,7 @@ public IterationBodyResult process(
                                 new UpdateModel(alpha, beta, l1, l2))
                         .setParallelism(1);
 
-        DataStream<LogisticRegressionModelData> outputModelData =
+        DataStream<LogisticRegressionModelDataSegment> outputModelData =
                 feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
         return new IterationBodyResult(
                 DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
@@ -233,14 +233,15 @@ public IterationBodyResult process(
     }
 
     private static class CreateLrModelData
-            implements MapFunction<DenseIntDoubleVector, LogisticRegressionModelData>,
+            implements MapFunction<DenseIntDoubleVector, LogisticRegressionModelDataSegment>,
                     CheckpointedFunction {
         private Long modelVersion = 1L;
         private transient ListState<Long> modelVersionState;
 
         @Override
-        public LogisticRegressionModelData map(DenseIntDoubleVector denseVector) throws Exception {
-            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        public LogisticRegressionModelDataSegment map(DenseIntDoubleVector denseVector)
+                throws Exception {
+            return new LogisticRegressionModelDataSegment(denseVector, modelVersion++);
         }
 
         @Override

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
index 81f0c9979..a0e6a096f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
@@ -99,7 +99,7 @@ public Table[] transform(Table... inputs) {
 
     /** A utility operator used for prediction. */
     private static class PredictLabelOperator extends AbstractStreamOperator<Row>
-            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelDataSegment, Row> {
         private final RowTypeInfo inputTypeInfo;
 
         private final Map<Param<?>, Object> params;
@@ -139,9 +139,9 @@ public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
     }
 
     @Override
-    public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+    public void processElement2(StreamRecord<LogisticRegressionModelDataSegment> streamRecord)
             throws Exception {
-        LogisticRegressionModelData modelData = streamRecord.getValue();
+        LogisticRegressionModelDataSegment modelData = streamRecord.getValue();
         coefficient = modelData.coefficient;
         modelDataVersion = modelData.modelVersion;
         for (Row dataPoint : bufferedPointsState.get()) {
@@ -159,7 +159,7 @@ public void processElement(StreamRecord<Row> streamRecord) throws Exception {
         if (servable == null) {
             servable =
                     new LogisticRegressionModelServable(
-                            new LogisticRegressionModelData(coefficient, 0L));
+                            new LogisticRegressionModelDataSegment(coefficient, 0L));
             ParamUtils.updateExistingParams(servable, params);
         }
         IntDoubleVector features =

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java
index 2c1b98a33..67997a7f4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/Message.java
@@ -139,6 +139,7 @@ public int getWorkerId() {
     public void setWorkerId(int workerId) {
         Bits.putInt(bytes, WORKER_ID_OFFSET, workerId);
     }
+
     /** Retrieves the server id. */
     public int getServerId() {
         return Bits.getInt(bytes, SERVER_ID_OFFSET);

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java
index b843894d2..6c8036204 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/ProxySideOutput.java
@@ -1,3 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.flink.ml.common.ps.training;
 
 import org.apache.flink.streaming.api.operators.Output;

diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index b70c173ea..739d41973 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -23,7 +23,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
-import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable;
 import org.apache.flink.ml.linalg.DenseIntDoubleVector;
@@ -339,7 +339,7 @@ public void testSaveLoadAndPredict() throws Exception {
     public void testGetModelData() throws Exception {
         LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
-        List<LogisticRegressionModelData> modelData =
+        List<LogisticRegressionModelDataSegment> modelData =
                 IteratorUtils.toList(
                         LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
@@ -352,7 +352,7 @@ public void testGetModelDataFromSparseInput() throws Exception {
         LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
         LogisticRegressionModel model = logisticRegression.fit(binomialSparseDataTable);
-        List<LogisticRegressionModelData> modelData =
+        List<LogisticRegressionModelDataSegment> modelData =
                 IteratorUtils.toList(
                         LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
@@ -476,7 +476,7 @@ private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient)
                         .setReg(reg)
                         .setElasticNet(elasticNet)
                         .fit(binomialDataTable);
-        List<LogisticRegressionModelData> modelData =
+        List<LogisticRegressionModelDataSegment> modelData =
                 IteratorUtils.toList(
                         LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());

diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java
index 218947496..02e3579b3 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionWithFtrlTest.java
@@ -22,7 +22,7 @@ import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
-import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionWithFtrl;
@@ -256,7 +256,7 @@ public void testGetModelData() throws Exception {
         LogisticRegressionWithFtrl logisticRegressionWithFtrl =
                 new LogisticRegressionWithFtrl().setMaxIter(MAX_ITER).setNumServers(NUM_SERVERS);
         LogisticRegressionModel model = logisticRegressionWithFtrl.fit(trainTable);
-        List<LogisticRegressionModelData> modelData =
+        List<LogisticRegressionModelDataSegment> modelData =
                 IteratorUtils.toList(
                         LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
@@ -266,7 +266,7 @@ public void testGetModelData() throws Exception {
         modelData.sort(Comparator.comparingLong(o -> o.startIndex));
 
         double[] collectedCoefficient = new double[4];
-        for (LogisticRegressionModelData modelSegment : modelData) {
+        for (LogisticRegressionModelDataSegment modelSegment : modelData) {
             int startIndex = (int) modelSegment.startIndex;
             double[] segment = modelSegment.coefficient.values;
             System.arraycopy(segment, 0, collectedCoefficient, startIndex, segment.length);

diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
index 035d8c791..9c8e37605 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
@@ -29,7 +29,7 @@ import org.apache.flink.metrics.Gauge;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
-import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataSegment;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
 import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
@@ -153,7 +153,7 @@ public class OnlineLogisticRegressionTest extends TestLogger {
     private InMemorySourceFunction<Row> trainSparseSource;
     private InMemorySourceFunction<Row> predictSparseSource;
     private InMemorySinkFunction<Row> outputSink;
-    private InMemorySinkFunction<LogisticRegressionModelData> modelDataSink;
+    private InMemorySinkFunction<LogisticRegressionModelDataSegment> modelDataSink;
 
     private static InMemoryReporter reporter;
     private static MiniCluster miniCluster;
@@ -671,10 +671,10 @@ public void testGetModelData() throws Exception {
         submitJob(env.getStreamGraph().getJobGraph());
         trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
-        LogisticRegressionModelData actualModelData = modelDataSink.poll();
+        LogisticRegressionModelDataSegment actualModelData = modelDataSink.poll();
 
-        LogisticRegressionModelData expectedModelData =
-                new LogisticRegressionModelData(
+        LogisticRegressionModelDataSegment expectedModelData =
+                new LogisticRegressionModelDataSegment(
                         new DenseIntDoubleVector(
                                 new double[] {0.2994527071464283, -0.1412541067743284}),
                         1L);
@@ -685,12 +685,12 @@ public void testGetModelData() throws Exception {
 
     @Test
     public void testSetModelData() throws Exception {
-        LogisticRegressionModelData modelData1 =
-                new LogisticRegressionModelData(
+        LogisticRegressionModelDataSegment modelData1 =
+                new LogisticRegressionModelDataSegment(
                         new DenseIntDoubleVector(new double[] {0.085, -0.22}), 1L);
-        LogisticRegressionModelData modelData2 =
-                new LogisticRegressionModelData(
+        LogisticRegressionModelDataSegment modelData2 =
+                new LogisticRegressionModelDataSegment(
                         new DenseIntDoubleVector(new double[] {0.075, -0.28}), 2L);
 
         final List<Row> expectedRawInfo1 =
@@ -706,13 +706,13 @@ public void testSetModelData() throws Exception {
                         new DenseIntDoubleVector(
                                 new double[] {0.8779865510655934, 0.12201344893440658}));
 
-        InMemorySourceFunction<LogisticRegressionModelData> modelDataSource =
+        InMemorySourceFunction<LogisticRegressionModelDataSegment> modelDataSource =
                 new InMemorySourceFunction<>();
         Table modelDataTable =
                 tEnv.fromDataStream(
                         env.addSource(
                                 modelDataSource,
-                                TypeInformation.of(LogisticRegressionModelData.class)));
+                                TypeInformation.of(LogisticRegressionModelDataSegment.class)));
 
         OnlineLogisticRegressionModel onlineModel =
                 new OnlineLogisticRegressionModel()

diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java
index f9e47f283..a2ca98236 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/TrainingUtilsTest.java
@@ -332,6 +332,7 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
         }
     }
 
+    /** Mock pojo class to test all reduce. */
     public static class MockPojo {
         public int i;
         public int j;

diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataSegment.java
similarity index 82%
rename from flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
rename to flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataSegment.java
index ade944646..127e00371 100644
--- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
+++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataSegment.java
@@ -30,8 +30,8 @@ import java.io.OutputStream;
 import java.util.List;
 
-/** Model data of {@link LogisticRegressionModelServable}. */
-public class LogisticRegressionModelData {
+/** Segment model data of {@link LogisticRegressionModelServable}. */
+public class LogisticRegressionModelDataSegment {
 
     public DenseIntDoubleVector coefficient;
 
@@ -41,13 +41,13 @@ public class LogisticRegressionModelData {
 
     public long modelVersion;
 
-    public LogisticRegressionModelData() {}
+    public LogisticRegressionModelDataSegment() {}
 
-    public LogisticRegressionModelData(DenseIntDoubleVector coefficient, long modelVersion) {
+    public LogisticRegressionModelDataSegment(DenseIntDoubleVector coefficient, long modelVersion) {
         this(coefficient, 0L, coefficient.size(), modelVersion);
     }
 
-    public LogisticRegressionModelData(
+    public LogisticRegressionModelDataSegment(
             DenseIntDoubleVector coefficient, long startIndex, long endIndex, long modelVersion) {
         this.coefficient = coefficient;
         this.startIndex = startIndex;
@@ -78,7 +78,7 @@ public void encode(OutputStream outputStream) throws IOException {
      *
      * @param inputStream The stream to read from.
     * @return The model data instance.
      */
-    static LogisticRegressionModelData decode(InputStream inputStream) throws IOException {
+    static LogisticRegressionModelDataSegment decode(InputStream inputStream) throws IOException {
         DataInputViewStreamWrapper dataInputViewStreamWrapper =
                 new DataInputViewStreamWrapper(inputStream);
@@ -88,23 +88,23 @@ static LogisticRegressionModelData decode(InputStream inputStream) throws IOException {
         long endIndex = dataInputViewStreamWrapper.readLong();
         long modelVersion = dataInputViewStreamWrapper.readLong();
 
-        return new LogisticRegressionModelData(coefficient, startIndex, endIndex, modelVersion);
+        return new LogisticRegressionModelDataSegment(
+                coefficient, startIndex, endIndex, modelVersion);
     }
 
     @VisibleForTesting
-    public static LogisticRegressionModelData mergeSegments(
-            List<LogisticRegressionModelData> segments) {
+    public static LogisticRegressionModelDataSegment mergeSegments(
+            List<LogisticRegressionModelDataSegment> segments) {
         long dim = 0;
-        for (LogisticRegressionModelData segment : segments) {
+        for (LogisticRegressionModelDataSegment segment : segments) {
             dim = Math.max(dim, segment.endIndex);
         }
-        // TODO: Add distributed inference for very large models.
         Preconditions.checkState(
                 dim < Integer.MAX_VALUE,
                 "The dimension of logistic regression model is larger than INT.MAX. Please consider using distributed inference.");
         int intDim = (int) dim;
         DenseIntDoubleVector mergedCoefficient = new DenseIntDoubleVector(intDim);
-        for (LogisticRegressionModelData segment : segments) {
+        for (LogisticRegressionModelDataSegment segment : segments) {
             int startIndex = (int) segment.startIndex;
             int endIndex = (int) segment.endIndex;
             System.arraycopy(
@@ -114,7 +114,7 @@ public static LogisticRegressionModelData mergeSegments(
                     startIndex,
                     endIndex - startIndex);
         }
-        return new LogisticRegressionModelData(
+        return new LogisticRegressionModelDataSegment(
                 mergedCoefficient, 0, mergedCoefficient.size(), segments.get(0).modelVersion);
     }
 }

diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
index 468392215..655ea13b4 100644
--- a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
+++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
@@ -47,13 +47,13 @@ public class LogisticRegressionModelServable
 
     private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
-    private LogisticRegressionModelData modelData;
+    private LogisticRegressionModelDataSegment modelData;
 
     public LogisticRegressionModelServable() {
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
-    LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
+    LogisticRegressionModelServable(LogisticRegressionModelDataSegment modelData) {
         this();
         this.modelData = modelData;
     }
@@ -81,11 +81,11 @@ public DataFrame transform(DataFrame input) {
 
     public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
            throws IOException {
         Preconditions.checkArgument(modelDataInputs.length == 1);
 
-        List<LogisticRegressionModelData> modelSegments = new ArrayList<>();
+        List<LogisticRegressionModelDataSegment> modelSegments = new ArrayList<>();
         while (true) {
             try {
-                LogisticRegressionModelData segment =
-                        LogisticRegressionModelData.decode(modelDataInputs[0]);
+                LogisticRegressionModelDataSegment segment =
+                        LogisticRegressionModelDataSegment.decode(modelDataInputs[0]);
                 modelSegments.add(segment);
             } catch (IOException e) {
                 // Reached the end of model stream.
                 break;
             }
         }
 
-        modelData = LogisticRegressionModelData.mergeSegments(modelSegments);
+        modelData = LogisticRegressionModelDataSegment.mergeSegments(modelSegments);
 
         return this;
     }

From d85382809a213837f26c1054daf76405c559fa44 Mon Sep 17 00:00:00 2001
From: "congzhou.zzp"
Date: Thu, 8 Jun 2023 15:13:10 +0800
Subject: [PATCH 17/18] Add a benchmark for PS serialization I/O (to be
 deleted in the real PR)

---
 .../flink/ml/common/ps/training/PSBench.java  | 93 +++++++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java

diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java
new file mode 100644
index 000000000..a3dd5c105
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/ps/training/PSBench.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.util.Bits;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+
+/** Benchmark for PS-related serialization.
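+ *
+ * <p>Compares writing a double[100000] into a pre-allocated byte[] via
+ * {@link Bits#putDoubleArray} against per-element serialization through
+ * {@link DataOutputViewStreamWrapper}. The inline comments below record the
+ * rough timings observed for 1000 rounds on the author's machine.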
+ */
+public class PSBench {
+    @Test
+    public void benchBits() {
+        double[] result = new double[100000];
+        int warmUp = 500;
+        int numTries = 1000;
+        for (int i = 0; i < result.length; i++) {
+            result[i] = i;
+        }
+        byte[] bytes = new byte[Bits.getDoubleArraySizeInBytes(result)];
+
+        for (int i = 0; i < warmUp; i++) {
+            Bits.putDoubleArray(result, bytes, 0);
+        }
+
+        long start = System.currentTimeMillis();
+        for (int i = 0; i < numTries; i++) {
+            Bits.putDoubleArray(result, bytes, 0);
+        }
+        long end = System.currentTimeMillis();
+        System.out.println(end - start); // ~600ms
+    }
+
+    @Test
+    public void benchTypeSerializer() throws IOException {
+        double[] result = new double[100000];
+        int warmUp = 500;
+        int numTries = 1000;
+        for (int i = 0; i < result.length; i++) {
+            result[i] = i;
+        }
+        byte[] bytes = new byte[Bits.getDoubleArraySizeInBytes(result)];
+
+        for (int i = 0; i < warmUp; i++) {
+            bytes = serializeDoubleArray(result);
+        }
+
+        long start = System.currentTimeMillis();
+        for (int i = 0; i < numTries; i++) {
+            bytes = serializeDoubleArray(result);
+        }
+        long end = System.currentTimeMillis();
+        System.out.println(end - start); // ~2000ms
+        Assert.assertEquals(Bits.getDoubleArraySizeInBytes(result), bytes.length);
+    }
+
+    private byte[] serializeDoubleArray(double[] result) throws IOException {
+        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                new DataOutputViewStreamWrapper(byteArrayOutputStream);
+        dataOutputViewStreamWrapper.writeInt(result.length);
+
+        for (double value : result) {
+            dataOutputViewStreamWrapper.writeDouble(value);
+        }
+        return byteArrayOutputStream.toByteArray();
+    }
+}

From b83939ca5670e07642ca8ea253b4c53973ec74c5 Mon Sep 17 00:00:00 2001
From: "congzhou.zzp"
Date: Fri, 9 Jun 2023 10:37:10 +0800
Subject: [PATCH 18/18] Add Optimizer param abstraction (HasOptimizer,
 Optimizer, SGD)

---
 .../flink/ml/common/param/HasOptimizer.java   | 40 +++++++++++++++++++
 .../flink/ml/common/param/Optimizer.java      |  4 +++
 .../org/apache/flink/ml/common/param/SGD.java | 29 ++++++++++++++
 3 files changed, 73 insertions(+)
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java
 create mode 100644 flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java
new file mode 100644
index 000000000..8165db869
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOptimizer.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared optimizer param.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface HasOptimizer<T> extends WithParams<T> {
+    Param<Optimizer> OPTIMIZER =
+            new Param<>("optimizer", Optimizer.class, "The optimizer", new SGD(0.1, 100), null);
+
+    default Optimizer getOptimizerParam() {
+        return get(OPTIMIZER);
+    }
+
+    default T setOptimizer(Optimizer optimizer) {
+        return set(OPTIMIZER, optimizer);
+    }
+}

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java
new file mode 100644
index 000000000..970c34d2e
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/Optimizer.java
@@ -0,0 +1,4 @@
+package org.apache.flink.ml.common.param;
+
+/** Marker interface for optimizers that can be passed as a stage parameter. */
+public interface Optimizer {}

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java
new file mode 100644
index 000000000..e2b47996d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/SGD.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.ml.common.param;
+
+/** Configuration of the stochastic gradient descent (SGD) optimizer. */
+public class SGD implements Optimizer {
+    public final double stepSize;
+    public final int globalBatchSize;
+
+    public SGD(double stepSize, int globalBatchSize) {
+        this.stepSize = stepSize;
+        this.globalBatchSize = globalBatchSize;
+    }
+}
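
For context, a stage would consume the new optimizer param from PATCH 18 roughly as follows. This is a minimal sketch and not part of the patch series: "SketchStage" is hypothetical, and it assumes that WithParams#get/#set are backed by the map returned from getParamMap() and that ParamUtils.initializeMapWithDefaultValues picks up the static OPTIMIZER field, as it does for the other stages above (e.g., LogisticRegressionModelServable).

    package org.apache.flink.ml.common.param;

    import org.apache.flink.ml.param.Param;
    import org.apache.flink.ml.util.ParamUtils;

    import java.util.HashMap;
    import java.util.Map;

    /** Hypothetical stage; only HasOptimizer, Optimizer and SGD come from this patch. */
    public class SketchStage implements HasOptimizer<SketchStage> {

        private final Map<Param<?>, Object> paramMap = new HashMap<>();

        public SketchStage() {
            // Populates paramMap with defaults, so OPTIMIZER starts as SGD(0.1, 100).
            ParamUtils.initializeMapWithDefaultValues(paramMap, this);
        }

        @Override
        public Map<Param<?>, Object> getParamMap() {
            return paramMap;
        }

        public static void main(String[] args) {
            SketchStage stage = new SketchStage();

            // The default declared on the OPTIMIZER param itself.
            SGD defaults = (SGD) stage.getOptimizerParam();
            System.out.println(defaults.stepSize + " / " + defaults.globalBatchSize); // 0.1 / 100

            // Overriding the optimizer for a specific training run.
            stage.setOptimizer(new SGD(0.01, 256));
            SGD chosen = (SGD) stage.getOptimizerParam();
            System.out.println(chosen.stepSize + " / " + chosen.globalBatchSize); // 0.01 / 256
        }
    }

One design note on the patch itself: because OPTIMIZER types its value as the Optimizer marker interface, callers that need SGD-specific fields (stepSize, globalBatchSize) must downcast, which is why the sketch casts the result of getOptimizerParam().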