-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathInteractionDemo.java
More file actions
71 lines (57 loc) · 2.94 KB
/
InteractionDemo.java
File metadata and controls
71 lines (57 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package com.topsec.ti.patronus;
import org.apache.spark.ml.feature.Interaction;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.Arrays;
import java.util.List;
import static javafx.beans.binding.Bindings.select;
/**
 * Small runnable demo of Spark ML's {@link Interaction} feature transformer.
 *
 * <p>Builds a six-row, seven-integer-column DataFrame, packs column pairs
 * ({@code id2,id3}), ({@code id4,id5}) and ({@code id6,id7}) into the vector
 * columns {@code vec1}, {@code vec2} and {@code vec3} with
 * {@link VectorAssembler}, then feeds {@code id1} together with the three
 * vectors to {@link Interaction} and prints the result.
 *
 * <p>Created by hhy on 2017/09/07.
 */
public class InteractionDemo {

    /**
     * Runs the demo on a local Spark master and prints the interacted dataset.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local").appName("").getOrCreate();
        try {
            List<Row> data = Arrays.asList(
                    RowFactory.create(1, 1, 2, 3, 8, 4, 5),
                    RowFactory.create(2, 4, 3, 8, 7, 9, 8),
                    RowFactory.create(3, 6, 1, 9, 2, 3, 6),
                    RowFactory.create(4, 10, 8, 6, 9, 4, 5),
                    RowFactory.create(5, 9, 2, 7, 10, 7, 3),
                    RowFactory.create(6, 1, 1, 4, 2, 8, 4)
            );

            StructType schema = new StructType(new StructField[]{
                    new StructField("id1", DataTypes.IntegerType, false, Metadata.empty()),
                    new StructField("id2", DataTypes.IntegerType, false, Metadata.empty()),
                    new StructField("id3", DataTypes.IntegerType, false, Metadata.empty()),
                    new StructField("id4", DataTypes.IntegerType, false, Metadata.empty()),
                    new StructField("id5", DataTypes.IntegerType, false, Metadata.empty()),
                    new StructField("id6", DataTypes.IntegerType, false, Metadata.empty()),
                    new StructField("id7", DataTypes.IntegerType, false, Metadata.empty())
            });

            Dataset<Row> df = spark.createDataFrame(data, schema);

            // Keep every source column while assembling: narrowing the dataset
            // between steps would drop id6/id7 before the third assembler reads them.
            Dataset<Row> assembled1 = assemble(df, "vec1", "id2", "id3");
            Dataset<Row> assembled2 = assemble(assembled1, "vec2", "id4", "id5");
            Dataset<Row> assembled3 = assemble(assembled2, "vec3", "id6", "id7")
                    .select("id1", "vec1", "vec2", "vec3");

            Interaction interaction = new Interaction()
                    .setInputCols(new String[]{"id1", "vec1", "vec2", "vec3"})
                    .setOutputCol("interactedCol");

            // The interaction must run on the dataset that actually contains vec3.
            Dataset<Row> interacted = interaction.transform(assembled3);

            interacted.show(false);
        } finally {
            // Release the local Spark context even if a transform fails.
            spark.stop();
        }
    }

    /**
     * Appends a vector column built from the given input columns.
     *
     * @param input     dataset that contains every column named in {@code inputCols}
     * @param outputCol name of the vector column to add
     * @param inputCols names of the columns to pack into the vector, in order
     * @return {@code input} with {@code outputCol} appended
     */
    private static Dataset<Row> assemble(Dataset<Row> input, String outputCol, String... inputCols) {
        return new VectorAssembler()
                .setInputCols(inputCols)
                .setOutputCol(outputCol)
                .transform(input);
    }
}