diff --git a/components/Java/pmml_predictor/Makefile b/components/Java/pmml_predictor/Makefile new file mode 100644 index 0000000..fed5754 --- /dev/null +++ b/components/Java/pmml_predictor/Makefile @@ -0,0 +1,9 @@ + +.PHONY: verify run + +verify: + python -m json.tool < component.json > /dev/null + + +run: + ./run.sh diff --git a/components/Java/pmml_predictor/README.txt b/components/Java/pmml_predictor/README.txt new file mode 100755 index 0000000..17435b6 --- /dev/null +++ b/components/Java/pmml_predictor/README.txt @@ -0,0 +1,11 @@ +#!/bin/bash + +# --- copy mlcomp jar, and (jpmml generated MOJO jar file) ---- +# a) copy over ../../../../../../reflex-common/mlcomp/target/mlcomp.jar . +# b) provide the model to be used for inference along with data-set for inference + +java -cp ./mlcomp.jar:./target/pmml_predictor/pmml_predictor.jar \ + org.mlpiper.mlhub.components.pmml_predictor.PmmlPredictor \ + --input-model \ + --samples-file \ + --output-file ./Results.csv diff --git a/components/Java/pmml_predictor/component.json b/components/Java/pmml_predictor/component.json new file mode 100644 index 0000000..ba64834 --- /dev/null +++ b/components/Java/pmml_predictor/component.json @@ -0,0 +1,72 @@ +{ + "version": 1, + "engineType": "Generic", + "language": "Java", + "userStandalone": false, + "name": "pmml_predictor", + "label": "PMML Predictor", + "description": "Given a PMML model, perform predictions by reading data from a csv file and save predictions to file", + "program": "pmml_predictor.jar", + "componentClass": "org.mlpiper.mlhub.components.pmml_predictor.PmmlPredictor", + "modelBehavior": "ModelConsumer", + "useMLOps": true, + "inputInfo": [ + { + "label": "samples-file", + "description": "Samples Input file, csv format. 
First line should contain features names", + "defaultComponent": "", + "type": "str", + "group": "data" + } + ], + "outputInfo": [ + { + "label": "predictions-file", + "description": "Output file containing predictions", + "defaultComponent": "", + "type": "str", + "group": "data" + } + ], + "group": "Algorithms", + "arguments": [ + { + "key": "input_model", + "label": "Model input file", + "description": "File to use for loading the model", + "type": "str", + "optional": true, + "tag": "input_model_path" + }, + { + "key": "samples_file", + "label": "Prediction samples file", + "description": "Samples Input file, csv format. First line should contain features names", + "type": "str", + "optional": true + }, + { + "key": "output_file", + "label": "Predictions output file", + "description": "File to save predictions in, if a directory is provided then the file is created inside", + "type": "str", + "optional": true + }, + { + "key": "convert_unknown_categorical_levels_to_na", + "label": "Convert Unknown Categorical Levels To Na", + "description": "Convert Unknown Categorical Levels To Na", + "type": "boolean", + "default": 1, + "optional": true + }, + { + "key": "convert_invalid_numbers_to_na", + "label": "Convert Invalid Numbers To Na", + "description": "Convert Invalid Numbers To Na", + "type": "boolean", + "default": 1, + "optional": true + } + ] +} diff --git a/components/Java/pmml_predictor/pom.xml b/components/Java/pmml_predictor/pom.xml new file mode 100644 index 0000000..46b810d --- /dev/null +++ b/components/Java/pmml_predictor/pom.xml @@ -0,0 +1,217 @@ + + 4.0.0 + org.mlpiper.mlhub.components.pmml_predictor + pmml_predictor + jar + 1.0 + pmml_predictor + http://maven.apache.org + + + ${project.build.directory}/${project.artifactId} + + + + + log4j + log4j + 1.2.17 + + + + net.sourceforge.argparse4j + argparse4j + 0.8.1 + + + + com.parallelm.mlcomp + mlcomp + 1.0 + + + + org.apache.commons + commons-csv + 1.5 + + + junit + junit + 4.12 + test + + + org.jpmml 
+ pmml-evaluator + 1.3.5 + + + org.jpmml + pmml-model + 1.3.7 + + + org.jpmml + pmml-agent + + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + javax.activation + javax.activation-api + 1.2.0 + + + + + + + production + + true + + + + + ${project.basedir}/src/test/resources + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + org.apache.maven.plugins + maven-assembly-plugin + 2.4.1 + + ${project.artifactId} + ${componentDirectory} + + + + org.mlpiper.mlhub.components.pmml_predictor.PmmlPredictor + + + + + jar-with-dependencies + + false + + + + package + + single + + + + + + + org.codehaus.mojo + exec-maven-plugin + 1.6.0 + + + + verify-component-json + compile + + exec + + + ${project.basedir} + make + verify + + + + copy-component-json + package + + exec + + + ${project.basedir} + cp + component.json ${componentDirectory} + + + + create-init-dot-py + package + + exec + + + ${project.basedir} + touch + ${componentDirectory}/__init__.py + + + + gen-component-tar + package + + exec + + + ${project.build.directory} + tar + cvf ${project.name}.tar ${project.name} + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.12 + + + + + + + + + test + + test + false + true + false + + + + diff --git a/components/Java/pmml_predictor/run.sh b/components/Java/pmml_predictor/run.sh new file mode 100755 index 0000000..dfa6721 --- /dev/null +++ b/components/Java/pmml_predictor/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +MODEL=$1 +DATA_FILE=$2 + +java -cp ./target/pmml_predictor/pmml_predictor.jar \ + org.mlpiper.mlhub.components.pmml_predictor.PmmlPredictor \ + --convert-invalid-numbers-to-na true \ + --convert-unknown-categorical-levels-to-na true \ + --input-model $MODEL \ + --samples-file $DATA_FILE \ + --output-file /tmp/predictions.csv + diff --git 
a/components/Java/pmml_predictor/src/main/java/org/mlpiper/mlhub/components/pmml_predictor/PmmlPredictor.java b/components/Java/pmml_predictor/src/main/java/org/mlpiper/mlhub/components/pmml_predictor/PmmlPredictor.java new file mode 100644 index 0000000..df68145 --- /dev/null +++ b/components/Java/pmml_predictor/src/main/java/org/mlpiper/mlhub/components/pmml_predictor/PmmlPredictor.java @@ -0,0 +1,303 @@ +package org.mlpiper.mlhub.components.pmml_predictor; + +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVPrinter; +import org.apache.commons.csv.CSVRecord; + +import org.apache.log4j.ConsoleAppender; +import org.apache.log4j.Logger; +import org.apache.log4j.Level; + +import com.parallelm.mlcomp.MCenterComponent; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; + +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.*; + +import org.apache.log4j.PatternLayout; + +import org.dmg.pmml.FieldName; +import org.jpmml.evaluator.Evaluator; +import org.jpmml.evaluator.ModelEvaluatorFactory; +import org.jpmml.evaluator.*; +import org.jpmml.model.JAXBUtil; +import org.jpmml.model.ImportFilter; +import org.xml.sax.InputSource; + +import javax.xml.transform.sax.SAXSource; + +class CSVSampleReader { + CSVParser csvParser; + Iterator csvRecordIterator; + Map headerMap; + Evaluator jpmmlEvaluator = null; + + public CSVSampleReader(Path csvSampleFilePath, Evaluator evaluator) throws Exception { + + // TODO: add header as an option + // Open the samples file exactly once; this reader is handed to the CSV parser below + Reader reader = Files.newBufferedReader(csvSampleFilePath); + csvParser = new CSVParser(reader, CSVFormat.DEFAULT + .withFirstRecordAsHeader() + 
.withIgnoreHeaderCase() + .withTrim()); + csvRecordIterator = csvParser.iterator(); + headerMap = csvParser.getHeaderMap(); + jpmmlEvaluator = evaluator; + } + + public Map getHeader() { + return headerMap; + } + + public Map nextSample() throws Exception { + Map row = new HashMap(); + + if (!csvRecordIterator.hasNext()) { + return null; + } + + CSVRecord csvRecord = csvRecordIterator.next(); + + List inputField = jpmmlEvaluator.getInputFields(); + + for (InputField field : inputField) { + Integer idx = headerMap.get(field.getName().getValue()); + // A model input with no matching CSV column would otherwise fail on unboxing a null index + if (idx == null) { + throw new Exception(String.format("Samples file has no column for model input field [%s]", field.getName().getValue())); + } + FieldValue fv = field.prepare(csvRecord.get(idx)); + row.put(field.getName(), fv); + } + return row; + } +} + +abstract class PredictionWriter { + abstract public void writeHeader(ArrayList header) throws Exception; + abstract public void writePrediction(ArrayList record) throws Exception; + abstract public void close() throws Exception; +} + +class CSVPredictionWriter extends PredictionWriter { + private Path predictionFilePath; + private CSVPrinter csvPrinter; + + public CSVPredictionWriter(Path predictionFilePath) throws Exception { + this.predictionFilePath = predictionFilePath; + csvPrinter = new CSVPrinter(new FileWriter(predictionFilePath.toString()), CSVFormat.DEFAULT); + } + + @Override + public void writeHeader(ArrayList header) throws Exception { + csvPrinter.printRecord(header); + } + @Override + public void writePrediction(ArrayList record) throws Exception { + csvPrinter.printRecord(record); + } + + @Override + public void close() throws Exception { + csvPrinter.close(); + } +} + +public class PmmlPredictor extends MCenterComponent +{ + private Path modelFilePath; + private Path inputSamplesFilePath; + private Path outputPredictionsFilePath; + private static Logger logger = Logger.getLogger(PmmlPredictor.class); + + private Evaluator evaluator; + + private final String tmpDir = "/tmp"; + + private void checkArgs(List parentDataObjects) throws Exception { + + String inputSamplesFileStr; + // From params + 
String modelPathStr = (String) params.get("input_model"); + System.out.println("param - input_model: " + modelPathStr); + String outputPredictionsFileStr = (String) params.getOrDefault("output_file", null); + System.out.println("param - output_file: " + outputPredictionsFileStr); + + if (parentDataObjects.size() != 0) { + // From component input + inputSamplesFileStr = (String) parentDataObjects.get(0); + System.out.println("Connected component input parentDataObjects - 0: " + inputSamplesFileStr); + } else { + // get sample-file path from param + inputSamplesFileStr = (String) params.getOrDefault("samples_file", null); + System.out.println("param - samples_file: " + inputSamplesFileStr); + } + + // Paths.get(null) throws a bare NullPointerException, so report a missing path explicitly + if (inputSamplesFileStr == null) { + throw new Exception("No input samples file was provided"); + } + inputSamplesFilePath = Paths.get(inputSamplesFileStr); + if (!inputSamplesFilePath.toFile().exists()) { + throw new Exception(String.format("Input samples file [%s] does not exist", inputSamplesFilePath)); + } + + if (modelPathStr == null) { + throw new Exception("No input model file was provided"); + } + modelFilePath = Paths.get(modelPathStr); + if (!modelFilePath.toFile().exists()) { + throw new Exception(String.format("Model file [%s] does not exist", modelFilePath)); + } + + // When the output path is an existing directory (or is not given), a uniquely named file is created inside it + String outputFile = "pmml_predictions_" + UUID.randomUUID().toString() + ".out"; + if (outputPredictionsFileStr != null) { + outputPredictionsFilePath = Paths.get(outputPredictionsFileStr); + if (outputPredictionsFilePath.toFile().exists()) { + if (outputPredictionsFilePath.toFile().isDirectory()) { + outputPredictionsFilePath = Paths.get(outputPredictionsFileStr.toString(), outputFile); + } + } + } else { + outputPredictionsFilePath = Paths.get(tmpDir, outputFile); + System.out.println(String.format("No output file/dir was given - using [%s]", outputPredictionsFilePath)); + } + + String desc = ""; + desc += "Model: %s\n"; + desc += "Samples file: %s\n"; + desc += "Predictions file %s\n"; + + System.out.println(String.format(desc, modelFilePath, inputSamplesFileStr, outputPredictionsFilePath)); + + } + + public void loadModel(List parentDataObjects) throws Exception { + 
checkArgs(parentDataObjects); + ModelEvaluatorFactory evaluatorInstance = ModelEvaluatorFactory.newInstance(); + List modelList = Files.readAllLines(modelFilePath); + String modelString = String.join("\n", modelList); + SAXSource src = + JAXBUtil.createFilteredSource( + new InputSource(new StringReader(modelString)), new ImportFilter()); + evaluator = evaluatorInstance.newModelEvaluator(JAXBUtil.unmarshalPMML(src)); + } + + private static ArrayList fixHeader(Map header) { + ArrayList header2 = new ArrayList(); + header2.ensureCapacity(header.size()); + for (int idx=0 ; idx < header.size() ; idx++) { + header2.add("aa"); + } + for (Map.Entry entry: header.entrySet()) { + header2.set(entry.getValue(), entry.getKey()); + } + return header2; + } + + @Override + public List materialize(List parentDataObjects) throws Exception { + logger.info("PmmlPredictor - materialize"); + // loadModel() validates the arguments itself, so no separate checkArgs() call is needed here + loadModel(parentDataObjects); + + CSVSampleReader sampleReader = new CSVSampleReader(inputSamplesFilePath, evaluator); + PredictionWriter predictionWriter = new CSVPredictionWriter(outputPredictionsFilePath); + try { + predict(sampleReader, predictionWriter, true); + } finally { + // Release the output file even if prediction fails part-way + predictionWriter.close(); + } + List outputs = new ArrayList<>(); + outputs.add(outputPredictionsFilePath.toString()); + return outputs; + } + + private void predict(CSVSampleReader sampleReader, PredictionWriter predictionWriter, + boolean writeHeader) throws Exception { + + Map sample; + int sampleIndex = 0; + logger.info("PMML predictor - Starting predict loop"); + + while((sample = sampleReader.nextSample()) != null) { + Map output = evaluator.evaluate(sample); + + ArrayList resultRecord = new ArrayList<>(); + ArrayList headerRecord = new ArrayList<>(); + headerRecord.add("index"); + resultRecord.add(sampleIndex); + + for(FieldName fn: sample.keySet()) { + headerRecord.add(fn.getValue()); + resultRecord.add(sample.get(fn).asString()); + } + + for(FieldName key: output.keySet()) { + headerRecord.add(key); + 
resultRecord.add(output.get(key)); + if (sampleIndex == 0) { + logger.info("name: " + key.toString() + " value: " + output.get(key)); + } + } + + // The header row is emitted once, before the first prediction record + if (sampleIndex == 0 && writeHeader) { + predictionWriter.writeHeader(headerRecord); + } + + predictionWriter.writePrediction(resultRecord); + sampleIndex++; + } + if (mlops != null) { + mlops.setStat("pipelinestat.count", sampleIndex); + } + } + + public static void main(String[] args ) throws Exception { + ConsoleAppender console = new ConsoleAppender(); //create appender + //configure the appender + String PATTERN = "%d [%p|%c|%C{1}] %m%n"; + console.setLayout(new PatternLayout(PATTERN)); + console.setThreshold(Level.DEBUG); + console.activateOptions(); + //add appender to any Logger (here is root) + Logger.getRootLogger().addAppender(console); + Logger.getRootLogger().setLevel(Level.INFO); + + PmmlPredictor middleComponent = new PmmlPredictor(); + + ArgumentParser parser = ArgumentParsers.newFor("PmmlPredictor").build() + .defaultHelp(true) + .description("Run predictions on a CSV samples file using a PMML model."); + + parser.addArgument("--input-model") + .help("Path to input model to consume"); + + parser.addArgument("--samples-file") + .help("Path to samples to predict"); + + parser.addArgument("--output-file") + .help("Path to record the predictions made"); + + parser.addArgument("--convert-unknown-categorical-levels-to-na") + .type(Boolean.class) + .setDefault(false) + .help("Set the convert_unknown_categorical_levels_to_na property of the PMML model predictor"); + + parser.addArgument("--convert-invalid-numbers-to-na") + .type(Boolean.class) + .setDefault(false) + .help("Set the convert-invalid-numbers-to-na property of the PMML model predictor"); + + Namespace options = null; + try { + options = parser.parseArgs(args); + } catch (ArgumentParserException e) { + parser.handleError(e); + System.exit(1); + } + + middleComponent.configure(options.getAttrs()); + List parentObjs = new ArrayList(); + 
parentObjs.add(options.get("samples_file")); + List outputs = middleComponent.materialize(parentObjs); + for (Object obj: outputs) { + System.out.println("Output: " + obj.toString()); + } + } +} diff --git a/components/Java/pmml_predictor/src/test/main/java/org/mlpiper/mlhub/components/pmml_predictor/PmmlPredictorTest.java b/components/Java/pmml_predictor/src/test/main/java/org/mlpiper/mlhub/components/pmml_predictor/PmmlPredictorTest.java new file mode 100644 index 0000000..c406198 --- /dev/null +++ b/components/Java/pmml_predictor/src/test/main/java/org/mlpiper/mlhub/components/pmml_predictor/PmmlPredictorTest.java @@ -0,0 +1,204 @@ +package org.mlpiper.mlhub.components.pmml_predictor; + +import java.io.*; +import java.net.URL; +import java.nio.file.Paths; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.*; +import org.junit.rules.TestName; + +import static java.lang.System.out; + +/** + * Unit tests for Pmml Predictor + */ + +public class PmmlPredictorTest { + @Rule public TestName name = new TestName(); + @Test + public void testCmdlineFailPMMLModelImport() throws Exception{ + URL resource = PmmlPredictorTest.class.getResource("/testSVM2.txt"); + File samples_file = Paths.get(resource.toURI()).toFile(); + String[] arguments = new String[] {"--input-model=" + null, + "--samples-file=" + samples_file, + "--output-file=/tmp/predictions.csv", + "--convert-unknown-categorical-levels-to-na=True", + "--convert-invalid-numbers-to-na=True"}; + try { + PmmlPredictor.main(arguments); + Assert.assertEquals(1,2); + } catch (Exception e) { + if (e.getCause() != null) { + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } + } + } + + @Test + public void testCmdlineFailPMMLInputSample() throws Exception{ + URL resource = PmmlPredictorTest.class.getResource("/modelForRf"); + File model_file = Paths.get(resource.toURI()).toFile(); + String[] 
arguments = new String[] {"--input-model=" + model_file, + "--samples-file=" + null, + "--output-file=/tmp/predictions.csv", + "--convert-unknown-categorical-levels-to-na=True", + "--convert-invalid-numbers-to-na=True"}; + try { + PmmlPredictor.main(arguments); + Assert.assertEquals(1,2); + } catch (Exception e) { + if (e.getCause() != null) { + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } + } + } + + @Test + public void testCmdlinePMMLModelImport() throws Exception{ + // Resource names are case-sensitive; the fixture file is named modelForRf + URL resource = PmmlPredictorTest.class.getResource("/modelForRf"); + File model_file = Paths.get(resource.toURI()).toFile(); + resource = PmmlPredictorTest.class.getResource("/testSVM2.txt"); + File samples_file = Paths.get(resource.toURI()).toFile(); + String[] arguments = new String[] {"--input-model=" + model_file, + "--samples-file=" + samples_file, + "--output-file=/tmp/predictions.csv", + "--convert-unknown-categorical-levels-to-na=True", + "--convert-invalid-numbers-to-na=True"}; + try { + PmmlPredictor.main(arguments); + } catch (ParseException e) { + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } + } + + @Test + public void testPMMLModelLoad() throws Exception{ + PmmlPredictor predComp = new PmmlPredictor(); + List parentObjs = new ArrayList(); + + URL resource = PmmlPredictorTest.class.getResource("/modelForRf"); + File model_file = Paths.get(resource.toURI()).toFile(); + + out.println(model_file.getAbsolutePath()); + + resource = PmmlPredictorTest.class.getResource("/testSVM2.txt"); + File samples_file = Paths.get(resource.toURI()).toFile(); + + Map params = new HashMap<>(); + params.put("input_model", model_file.getAbsolutePath()); + params.put("output_file", "/tmp/predictions.csv"); + params.put("convert_unknown_categorical_levels_to_na", true); + params.put("convert_invalid_numbers_to_na", true); + + parentObjs.add(samples_file.getAbsolutePath()); + predComp.configure(params); + + try { + predComp.loadModel(parentObjs); + } catch 
(Exception e) { + out.println(e.getMessage()); + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } + } + + @Test + public void testFailedPMMLModelLoad() throws Exception{ + PmmlPredictor predComp = new PmmlPredictor(); + URL resource = PmmlPredictorTest.class.getResource("/testSVM2.txt"); + File samples_file = Paths.get(resource.toURI()).toFile(); + List parentObjs = new ArrayList(); + + Map params = new HashMap<>(); + params.put("input_model", "BAD_PATH"); + params.put("output_file", "/tmp/predictions.csv"); + params.put("convert_unknown_categorical_levels_to_na", true); + params.put("convert_invalid_numbers_to_na", true); + + predComp.configure(params); + parentObjs.add(samples_file.getAbsolutePath()); + + try { + predComp.loadModel(parentObjs); + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } catch (Exception e) { + // expected exception cause : "null" + if(e.getCause() != null) { + Assert.fail("Fail"); + } + } + } + + @Test + public void testPMMLPredict() throws Exception{ + PmmlPredictor predComp = new PmmlPredictor(); + List parentObjs = new ArrayList(); + + URL resource = PmmlPredictorTest.class.getResource("/modelForRf"); + File model_file = Paths.get(resource.toURI()).toFile(); + + resource = PmmlPredictorTest.class.getResource("/testSVM2.txt"); + File samples_file = Paths.get(resource.toURI()).toFile(); + + Map params = new HashMap<>(); + params.put("input_model", model_file.getAbsolutePath()); + params.put("output_file", "/tmp/predictions.csv"); + params.put("convert_unknown_categorical_levels_to_na", true); + params.put("convert_invalid_numbers_to_na", true); + + parentObjs.add(samples_file.getAbsolutePath()); + predComp.configure(params); + + try { + parentObjs.add(samples_file.getAbsolutePath()); + predComp.materialize(parentObjs); + } catch (Exception e) { + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } + } + + @Test + public void testPMMLPredictResults() throws Exception{ 
+ String output_file = "/tmp/predictions.csv"; + + try { + URL resource = PmmlPredictorTest.class.getResource("/predictions.csv"); + + String diff_results = Paths.get(resource.toURI()).toString(); + + // For now we compare if both the file contents are identical + // this check will need to get smarter, to allow for tolerance/thresholds + + // try-with-resources closes both readers even when the comparison fails + try (BufferedReader s1 = new BufferedReader(new FileReader(diff_results)); + BufferedReader s2 = new BufferedReader(new FileReader(output_file))) { + int lineNo = 0; + while (true) { + String s1_line = s1.readLine(); + String s2_line = s2.readLine(); + lineNo++; + if (s1_line == null || s2_line == null) { + // Both files must end on the same line + if (s1_line != null || s2_line != null) { + Assert.fail("File diff failed: files differ in length at line " + lineNo); + } + break; + } + if (!s1_line.equals(s2_line)) { + Assert.fail("File diff failed at line " + lineNo); + } + } + } + } catch (Exception e) { + out.println(e.getMessage()); + out.println(name.getMethodName() + " FAILED"); + Assert.fail("Fail"); + } + } +} diff --git a/components/Java/pmml_predictor/src/test/resources/modelForRf b/components/Java/pmml_predictor/src/test/resources/modelForRf new file mode 100644 index 0000000..9b836ac --- /dev/null +++ b/components/Java/pmml_predictor/src/test/resources/modelForRf @@ -0,0 +1,322 @@ + + +
+ + 2017-11-06T18:28:11Z +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + 0 + + + 1 + 1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
diff --git a/components/Java/pmml_predictor/src/test/resources/predictions.csv b/components/Java/pmml_predictor/src/test/resources/predictions.csv new file mode 100644 index 0000000..5c659b8 --- /dev/null +++ b/components/Java/pmml_predictor/src/test/resources/predictions.csv @@ -0,0 +1,11 @@ +index,c4,c11,c5,c6,c13,c24,c7,c23,c12,c0,pmml(prediction),prediction,probability(0),probability(1) +0,0.9612726066144277,0.12297559578341737,0.007106799402097768,0.050828312615106515,0.16766480457512672,0.061588364992562886,0.20895518489642073,0.0,8.733321092584783E-5,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +1,0.9885139155467177,0.037672213559525526,0.0022118100179998154,0.050828312615106515,0.0503993379969236,0.02778366689251417,0.20895518489642073,0.0,9.92422851430089E-6,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +2,0.9944094227037059,0.027967800438195047,0.004471965610163561,0.050828312615106515,0.03687104886371097,0.022812075983475958,0.20895518489642073,0.0,9.92422851430089E-6,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +3,0.9948160094041879,0.009004422666677132,0.018629161244501178,0.050953814621563565,0.01071807532470385,0.013409514925230988,0.20895518489642073,0.10101009998979697,2.3818148434322133E-5,"ProbabilityDistribution{result=0, probability_entries=[0=0.9812768147216934, 1=0.018723185278306662]}",0,0.0,0.9812768147216934,0.018723185278306662 +4,0.9914616691252118,0.0027045085747970186,0.005841756878141589,0.05107931662802063,0.0020463898055107653,0.011448439619960053,0.20895518489642073,0.0,9.92422851430089E-6,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 
1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +5,0.9999999898353326,0.002895415062429749,0.0030417423744806206,0.05120481863447768,0.002063474116273226,0.010907818860128607,0.20895518489642073,0.0,8.038625096583721E-5,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +6,0.9847529885672598,0.003977218492348556,0.01996672212970325,0.05120481863447768,0.002568548467177978,0.01937754409748794,0.20895518489642073,0.11111110998877667,0.005768954035363107,"ProbabilityDistribution{result=0, probability_entries=[0=0.9961538461538462, 1=0.0038461538461538464]}",0,0.0,0.9961538461538462,0.0038461538461538464 +7,0.9834315817906935,0.012981641159025689,0.005801468899671647,0.05158132465384884,0.016051797144744126,0.017098456580551452,0.20895518489642073,0.0,6.153021678866551E-5,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +8,0.9817035883136452,0.013013458906964476,0.003078001555103568,0.05158132465384884,0.016481079280629963,0.018232700135491938,0.20895518489642073,0.0,2.8780262691472577E-5,"ProbabilityDistribution{result=0, probability_entries=[0=0.9841539763197852, 1=0.01584602368021483]}",0,0.0,0.9841539763197852,0.01584602368021483 +9,0.17859320818669236,0.3406726271806079,0.7639003597685701,0.09061244866199235,0.43691012532935725,0.347672150606898,0.30410442087604084,0.0,2.12378490206039E-4,"ProbabilityDistribution{result=0, probability_entries=[0=0.9537037037037037, 1=0.046296296296296294]}",0,0.0,0.9537037037037037,0.046296296296296294 diff --git a/components/Java/pmml_predictor/src/test/resources/testSVM2.txt b/components/Java/pmml_predictor/src/test/resources/testSVM2.txt new file mode 100644 index 0000000..10cc994 --- /dev/null +++ b/components/Java/pmml_predictor/src/test/resources/testSVM2.txt @@ -0,0 +1,11 @@ 
+c0,c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13,c14,c15,c16,c17,c18,c19,c20,c21,c22,c23,c24 +0,5.020201969492909394e-02,3.894080390330157543e-02,5.526301246575666681e-01,9.612726066144277048e-01,7.106799402097767858e-03,5.082831261510651483e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,5.077430819747274148e-04,1.229755957834173657e-01,8.733321092584783155e-05,1.676648045751267246e-01,0,0,0,0,0,0,0,0,0,0,6.158836499256288566e-02 +0,2.616161589735741325e-02,2.647974665424506976e-02,0,9.885139155467177474e-01,2.211810017999815432e-03,5.082831261510651483e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,5.077430819747274284e-05,3.767221355952552614e-02,9.924228514300889486e-06,5.039933799692360161e-02,0,0,0,0,0,0,0,0,0,0,2.778366689251416874e-02 +1,1.979797959800020518e-02,3.271027527877332086e-02,0,9.944094227037059142e-01,4.471965610163560791e-03,5.082831261510651483e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,5.077430819747274284e-05,2.796780043819504721e-02,9.924228514300889486e-06,3.687104886371096735e-02,0,0,0,0,0,0,0,0,0,0,2.281207598347595802e-02 +1,1.939393919804101654e-02,3.271027527877332086e-02,0,9.948160094041879065e-01,1.862916124450117797e-02,5.095381462156356456e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,1.523229245924182353e-04,9.004422666677131676e-03,2.381814843432213341e-05,1.071807532470384940e-02,0,0,0,0,0,0,0,0,0,1.010100999897969726e-01,1.340951492523098816e-02 +1,2.151515129782675256e-02,2.024921802971681867e-02,5.526301246575666681e-01,9.914616691252118308e-01,5.841756878141588991e-03,5.107931662802062817e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,5.077430819747274284e-05,2.704508574797018575e-03,9.924228514300889486e-06,2.046389805510765315e-03,0,0,0,0,0,0,0,0,0,0,1.144843961996005327e-02 
+0,1.434343419855116895e-02,1.869158587358475329e-02,2.105257617743111487e-01,9.999999898353325589e-01,3.041742374480620615e-03,5.120481863447767790e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,4.569687737772547059e-04,2.895415062429749012e-03,8.038625096583720874e-05,2.063474116273225894e-03,0,0,0,0,0,0,0,0,0,0,1.090781886012860691e-02 +0,2.898989869707173372e-02,3.894080390330157543e-02,0,9.847529885672597905e-01,1.996672212970325086e-02,5.120481863447767790e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,3.452652957428146378e-03,3.977218492348556268e-03,5.768954035363107156e-03,2.568548467177978085e-03,0,0,0,0,0,0,0,0,0,1.111111099887766712e-01,1.937754409748794135e-02 +0,2.949494919702071519e-02,2.024921802971681867e-02,5.526301246575666681e-01,9.834315817906934543e-01,5.801468899671646806e-03,5.158132465384884097e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,3.554201573823091796e-04,1.298164115902568881e-02,6.153021678866551244e-05,1.605179714474412600e-02,0,0,0,0,0,0,0,0,0,0,1.709845658055145212e-02 +0,3.323232289664320577e-02,1.869158587358475329e-02,0,9.817035883136452368e-01,3.078001555103568061e-03,5.158132465384884097e-02,2.089551848964207315e-01,9.999900000999989436e-01,0,1.523229245924182353e-04,1.301345890696447606e-02,2.878026269147257731e-05,1.648107928062996289e-02,0,0,0,0,0,0,0,0,0,0,1.823270013549193810e-02 +0,7.945454465197429039e-01,5.654204726759388677e-01,0,1.785932081866923593e-01,7.639003597685700697e-01,9.061244866199234804e-02,3.041044208760408374e-01,9.999900000999989436e-01,0,9.647118557519821207e-04,3.406726271806079231e-01,2.123784902060390015e-04,4.369101253293572462e-01,0,0,0,0,0,0,0,0,0,0,3.476721506068979894e-01 diff --git a/components/Python/RestModelServing/restful_h2o_serving/h2o_model_serving-1.0-jar-with-dependencies.jar b/components/Python/RestModelServing/restful_h2o_serving/h2o_model_serving-1.0-jar-with-dependencies.jar deleted file mode 100644 index 5e3843e..0000000 Binary files 
a/components/Python/RestModelServing/restful_h2o_serving/h2o_model_serving-1.0-jar-with-dependencies.jar and /dev/null differ diff --git a/components/Python/RestModelServing/restful_h2o_serving/__init__.py b/components/Python/RestModelServing/restful_pmml_serving_source/__init__.py similarity index 100% rename from components/Python/RestModelServing/restful_h2o_serving/__init__.py rename to components/Python/RestModelServing/restful_pmml_serving_source/__init__.py diff --git a/components/Python/RestModelServing/restful_pmml_serving_source/client_test.sh b/components/Python/RestModelServing/restful_pmml_serving_source/client_test.sh new file mode 100755 index 0000000..715cdab --- /dev/null +++ b/components/Python/RestModelServing/restful_pmml_serving_source/client_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash +if [ -z $REFLEX_HOME ]; +then PYTHONPATH="../:$HOME/workspace/mlpiper/mlcomp/dist/mlcomp-py2.egg:$HOME/workspace/mlpiper/mlops/dist/mlops-py2.egg"; +else + if [ -d $REFLEX_HOME ]; + then + PYTHONPATH="../:$REFLEX_HOME/sub/mlpiper/mlcomp/dist/mlcomp-py2.egg:$REFLEX_HOME/sub/mlpiper/mlops/dist/mlops-py2.egg"; + else + echo "REFLEX_HOME is invalid $REFLEX_HOME" + exit 1 + fi +fi + +script_name=$(basename ${BASH_SOURCE[0]}) +script_dir=$(realpath $(dirname ${BASH_SOURCE[0]})) + + +PYTHONPATH=$PYTHONPATH python pmml_restful_serving.py 8888 $script_dir/model/modelForRf diff --git a/components/Python/RestModelServing/restful_h2o_serving/component.json b/components/Python/RestModelServing/restful_pmml_serving_source/component.json similarity index 90% rename from components/Python/RestModelServing/restful_h2o_serving/component.json rename to components/Python/RestModelServing/restful_pmml_serving_source/component.json index 0821f50..1d467e5 100644 --- a/components/Python/RestModelServing/restful_h2o_serving/component.json +++ b/components/Python/RestModelServing/restful_pmml_serving_source/component.json @@ -3,11 +3,11 @@ "engineType": "RestModelServing", "userStandalone": 
false, "language": "Python", - "name": "restful_h2o_serving", - "label": "H2O RESTful serving", - "program": "h2o_restful_serving.py", + "name": "restful_pmml_serving", + "label": "PMML RESTful serving", + "program": "pmml_restful_serving.py", "modelBehavior": "ModelConsumer", - "componentClass": "H2oRESTfulServing", + "componentClass": "org.mlpiper.mlhub.components.restful.PmmlModelServing", "group": "Connectors", "useMLOps": true, "inputInfo": [], diff --git a/components/Python/RestModelServing/restful_pmml_serving_source/create_component.sh b/components/Python/RestModelServing/restful_pmml_serving_source/create_component.sh new file mode 100755 index 0000000..d723f5c --- /dev/null +++ b/components/Python/RestModelServing/restful_pmml_serving_source/create_component.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -e + +script_name=$(basename ${BASH_SOURCE[0]}) +script_dir=$(realpath $(dirname ${BASH_SOURCE[0]})) +target_component_dir=${script_dir%_*} + +cd $script_dir +mvn clean install + +rm -rf $target_component_dir +mkdir -p $target_component_dir + +touch $target_component_dir/__init__.py +cp $script_dir/component.json $target_component_dir/ +cp $script_dir/target/*jar-with-dependencies.jar $target_component_dir/ +cp $script_dir/*.py $target_component_dir/ + +printf "\e[92m\nComponent was generated in:\e[1;95m $target_component_dir\e[0m\n\n" diff --git a/components/Python/RestModelServing/restful_pmml_serving_source/curl_test.sh b/components/Python/RestModelServing/restful_pmml_serving_source/curl_test.sh new file mode 100755 index 0000000..0a17ba6 --- /dev/null +++ b/components/Python/RestModelServing/restful_pmml_serving_source/curl_test.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +curl -s -G http://localhost:8888/predict -d 
'{"c11":1.0,"c10":1.0,"c13":1.0,"c12":1.0,"c15":1.0,"c14":1.0,"c17":1.0,"c16":1.0,"c19":1.0,"c18":1.0,"c20":1.0,"c22":1.0,"c21":1.0,"c24":1.0,"c23":1.0,"c0":1.0,"c1":1.0,"c2":1.0,"c3":1.0,"c4":1.0,"c5":1.0,"c6":1.0,"c7":1.0,"c8":1.0,"c9":1.0}' + diff --git a/components/Python/RestModelServing/restful_pmml_serving_source/mlcomp.jar b/components/Python/RestModelServing/restful_pmml_serving_source/mlcomp.jar new file mode 100644 index 0000000..e97bf57 Binary files /dev/null and b/components/Python/RestModelServing/restful_pmml_serving_source/mlcomp.jar differ diff --git a/components/Python/RestModelServing/restful_pmml_serving_source/model/modelForRf b/components/Python/RestModelServing/restful_pmml_serving_source/model/modelForRf new file mode 100644 index 0000000..9b836ac --- /dev/null +++ b/components/Python/RestModelServing/restful_pmml_serving_source/model/modelForRf @@ -0,0 +1,322 @@ + + +
+ + 2017-11-06T18:28:11Z +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 + 0 + + + 1 + 1 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
diff --git a/components/Python/RestModelServing/restful_h2o_serving/h2o_restful_serving.py b/components/Python/RestModelServing/restful_pmml_serving_source/pmml_restful_serving.py similarity index 79% rename from components/Python/RestModelServing/restful_h2o_serving/h2o_restful_serving.py rename to components/Python/RestModelServing/restful_pmml_serving_source/pmml_restful_serving.py index b70f4ee..a678d1c 100644 --- a/components/Python/RestModelServing/restful_h2o_serving/h2o_restful_serving.py +++ b/components/Python/RestModelServing/restful_pmml_serving_source/pmml_restful_serving.py @@ -1,32 +1,29 @@ -from contextlib import closing import glob import logging import os import signal import socket -import time import subprocess - -from py4j.java_gateway import JavaGateway, get_field -from py4j.java_gateway import GatewayParameters, CallbackServerParameters - +import time +from contextlib import closing from parallelm.common.mlcomp_exception import MLCompException from parallelm.components.restful.flask_route import FlaskRoute from parallelm.components.restful_component import RESTfulComponent -from parallelm.pipeline.components_desc import ComponentsDesc from parallelm.pipeline.component_dir_helper import ComponentDirHelper +from parallelm.pipeline.components_desc import ComponentsDesc +from py4j.java_gateway import GatewayParameters, CallbackServerParameters +from py4j.java_gateway import JavaGateway, get_field -class H2oRESTfulServing(RESTfulComponent): - +class PmmlRESTfulServing(RESTfulComponent): JAVA_COMPONENT_ENTRY_POINT_CLASS = "com.parallelm.mlcomp.ComponentEntryPoint" - JAVA_COMPONENT_CLASS_NAME = "com.parallelm.components.restful.H2oModelServing" + JAVA_COMPONENT_CLASS_NAME = "org.mlpiper.mlhub.components.restful.PmmlModelServing" class Java: implements = ["org.mlpiper.mlops.MLOps"] def __init__(self, engine): - super(H2oRESTfulServing, self).__init__(engine) + super(PmmlRESTfulServing, self).__init__(engine) self._verbose = 
self._logger.isEnabledFor(logging.DEBUG) self._gateway = None @@ -39,22 +36,26 @@ def post_fork_callback(self): self._prefix_msg = "wid: {}, ".format(self.get_wid()) self._launch_custom_java_gateway() - def _launch_custom_java_gateway(self): - self._run_java_server_entry_point() + def _launch_custom_java_gateway(self, local_mode=False): + self._run_java_server_entry_point(local_mode) self._setup_py4j_client_connection() - def _run_java_server_entry_point(self): + def _run_java_server_entry_point(self, local_mode=False): comp_realpath = os.path.realpath(__file__) comp_filename = os.path.basename(comp_realpath) comp_dirname = os.path.basename(os.path.dirname(comp_realpath)) - comp_module_name = "{}.{}.{}".format( - ComponentsDesc.CODE_COMPONETS_MODULE_NAME, - comp_dirname, - os.path.splitext(comp_filename)[0]) - if self._verbose: - self._logger.debug("comp_module_name: {}".format(comp_module_name)) - comp_helper = ComponentDirHelper(comp_module_name, comp_filename) - comp_dir = comp_helper.extract_component_out_of_egg() + if local_mode is False: + comp_module_name = "{}.{}.{}".format( + ComponentsDesc.CODE_COMPONETS_MODULE_NAME, + comp_dirname, + os.path.splitext(comp_filename)[0]) + if self._verbose: + self._logger.info("comp_module_name: {}".format(comp_module_name)) + comp_helper = ComponentDirHelper(comp_module_name, comp_filename) + comp_dir = comp_helper.extract_component_out_of_egg() + else: + comp_dir = os.path.dirname(comp_realpath) + if self._verbose: self._logger.debug(self._prefix_msg + "comp_dir: {}".format(comp_dir)) @@ -64,9 +65,9 @@ def _run_java_server_entry_point(self): if self._verbose: self._logger.info(self._prefix_msg + "java_cp: {}".format(java_cp)) - self._java_port = H2oRESTfulServing.find_free_port() - cmd = ["java", "-cp", java_cp, H2oRESTfulServing.JAVA_COMPONENT_ENTRY_POINT_CLASS, "--class-name", - H2oRESTfulServing.JAVA_COMPONENT_CLASS_NAME, "--port", str(self._java_port)] + self._java_port = PmmlRESTfulServing.find_free_port() + cmd = 
["java", "-cp", java_cp, PmmlRESTfulServing.JAVA_COMPONENT_ENTRY_POINT_CLASS, "--class-name", + PmmlRESTfulServing.JAVA_COMPONENT_CLASS_NAME, "--port", str(self._java_port)] if self._verbose: self._logger.debug(self._prefix_msg + "java gateway cmd: " + " ".join(cmd)) @@ -121,6 +122,10 @@ def load_model_callback(self, model_path, stream, version): if self._verbose: self._logger.debug(self._prefix_msg + "load model callback, path: {}".format(model_path)) + if self._component_via_py4j is None: + self.set_wid(0) + self._launch_custom_java_gateway(local_mode=True) + if self._component_via_py4j: result = self._component_via_py4j.loadModel(model_path) if self._verbose: @@ -138,17 +143,19 @@ def setStat(self, stat_name, stat_value): @FlaskRoute('/predict', raw=True) def predict(self, query_string, body_data): if self._verbose: - self._logger.debug(self._prefix_msg + "predict, query_string: {}, body_data: {}".format(query_string, body_data)) + self._logger.debug( + self._prefix_msg + "predict, query_string: {}, body_data: {}".format(query_string, body_data)) if self._model_loaded: result = self._component_via_py4j.predict(query_string, body_data) returned_code = get_field(result, "returned_code") json = get_field(result, "json") if self._verbose: - self._logger.debug(self._prefix_msg + "got response ... code: {}, json: {}".format(returned_code, str(json))) - return(returned_code, str(json)) + self._logger.debug( + self._prefix_msg + "got response ... 
code: {}, json: {}".format(returned_code, str(json))) + return (returned_code, str(json)) else: - return 404, '{"error": "H2O model was not loaded yet!"}' + return 404, '{"error": "Pmml model was not loaded yet!"}' def cleanup_callback(self): """ @@ -195,4 +202,4 @@ def find_free_port(): comp_module_name = os.path.splitext(os.path.basename(__file__))[0] ComponentsDesc.CODE_COMPONETS_MODULE_NAME = comp_module_name - H2oRESTfulServing.run(port=args.port, model_path=args.input_model) + PmmlRESTfulServing.run(port=args.port, model_path=args.input_model) diff --git a/components/Python/RestModelServing/restful_pmml_serving_source/pom.xml b/components/Python/RestModelServing/restful_pmml_serving_source/pom.xml new file mode 100644 index 0000000..62990a1 --- /dev/null +++ b/components/Python/RestModelServing/restful_pmml_serving_source/pom.xml @@ -0,0 +1,127 @@ + + 4.0.0 + org.mlpiper.mlhub.components.restful + pmml_model_serving + jar + 1.0 + restful_pmml_serving + http://maven.apache.org + + + + + junit + junit + 3.8.1 + test + + + com.parallelm.mlcomp + mlcomp + 1.0 + compile + + + com.google.code.gson + gson + 2.8.5 + compile + + + net.sourceforge.argparse4j + argparse4j + 0.8.1 + compile + + + org.apache.maven.plugins + maven-install-plugin + 3.0.0-M1 + + + org.jpmml + pmml-evaluator + 1.3.5 + + + org.jpmml + pmml-model + 1.3.7 + + + org.jpmml + pmml-agent + + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + javax.activation + javax.activation-api + 1.2.0 + + + + + + production + + true + + + + + org.apache.maven.plugins + maven-compiler-plugin + + 1.8 + 1.8 + + + + org.apache.maven.plugins + maven-assembly-plugin + 2.4.1 + + + jar-with-dependencies + + false + + + + org.mlpiper.mlhub.components.restful.PmmlModelServing + + + + + + + make-assembly + package + + single + + + + + + + + + diff --git 
a/components/Python/RestModelServing/restful_pmml_serving_source/src/main/java/org/mlpiper/mlhub/components/restful/PmmlModelServing.java b/components/Python/RestModelServing/restful_pmml_serving_source/src/main/java/org/mlpiper/mlhub/components/restful/PmmlModelServing.java new file mode 100644 index 0000000..d111e9a --- /dev/null +++ b/components/Python/RestModelServing/restful_pmml_serving_source/src/main/java/org/mlpiper/mlhub/components/restful/PmmlModelServing.java @@ -0,0 +1,134 @@ +package org.mlpiper.mlhub.components.restful; + +import com.google.gson.Gson; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.parallelm.mlcomp.MCenterRestfulComponent; +import org.dmg.pmml.FieldName; +import org.jpmml.evaluator.Evaluator; +import org.jpmml.evaluator.FieldValue; +import org.jpmml.evaluator.InputField; +import org.jpmml.evaluator.ModelEvaluatorFactory; +import org.jpmml.model.ImportFilter; +import org.jpmml.model.JAXBUtil; +import org.xml.sax.InputSource; + +import javax.xml.transform.sax.SAXSource; + + +public class PmmlModelServing extends MCenterRestfulComponent { + private int wid = 0; + private boolean isVerbose = false; + private Evaluator evaluator; + + public void setEnvAttributes(int wid, boolean isVerbose) { + this.wid = wid; + this.isVerbose = isVerbose; + } + + public void loadModel(String modelPath) throws IOException { + try { + ModelEvaluatorFactory evaluatorInstance = ModelEvaluatorFactory.newInstance(); + List modelList = Files.readAllLines(Paths.get(modelPath)); + String modelString = String.join("\n", modelList); + SAXSource src = + JAXBUtil.createFilteredSource( + new InputSource(new StringReader(modelString)), new ImportFilter()); + evaluator = evaluatorInstance.newModelEvaluator(JAXBUtil.unmarshalPMML(src)); + } catch (Exception e) { + throw new IOException("Invalid model at " + modelPath); + } + if (isVerbose) { + 
System.out.println(String.format("(java side) wid: %d, model loaded successfully, " + + "model type %s", wid, evaluator.getMiningFunction().toString())); + } + } + + public MCenterRestfulComponent.Result predict(String query_string, String body_data) throws Exception { + + Gson gson = new Gson(); + Map predictionVector = new HashMap(); + String jsonData = body_data != null && !body_data.isEmpty() ? body_data : query_string; + predictionVector = (Map) gson.fromJson(jsonData, predictionVector.getClass()); + + + List inputField = evaluator.getInputFields(); + Map row = new HashMap(); + + for (InputField field : inputField) { + try { + FieldValue fv = field.prepare(predictionVector.get(field.getName().getValue())); + row.put(field.getName(), fv); + } catch (Exception e) { + throw new InvalidObjectException("Json has missing or invalid value for " + + field.getName().getValue()); + } + } + + try { + Map output = evaluator.evaluate(row); + MCenterRestfulComponent.Result result = new MCenterRestfulComponent.Result(); + + Map resultAttrs = new HashMap(); + for(FieldName key: output.keySet()) { + resultAttrs.put(key.getValue(), output.get(key)); + } + + result.returned_code = 200; + result.json = gson.toJson(resultAttrs); + return result; + } catch (Exception e) { + throw new Exception("Evaluation failed with " + e.getMessage()); + } + } + + public static void main(String[] args) throws Exception { + + PmmlModelServing modelServing = new PmmlModelServing(); + + modelServing.loadModel("./model/modelForRf"); + + Map data = new HashMap<>(); + data.put("c0", 1.0); + data.put("c1", 1.0); + data.put("c2", 1.0); + data.put("c3", 1.0); + data.put("c4", 1.0); + data.put("c5", 1.0); + data.put("c6", 1.0); + data.put("c7", 1.0); + data.put("c8", 1.0); + data.put("c9", 1.0); + data.put("c10", 1.0); + data.put("c11", 1.0); + data.put("c12", 1.0); + data.put("c13", 1.0); + data.put("c14", 1.0); + data.put("c15", 1.0); + data.put("c16", 1.0); + data.put("c17", 1.0); + data.put("c18", 
1.0); + data.put("c19", 1.0); + data.put("c20", 1.0); + data.put("c21", 1.0); + data.put("c22", 1.0); + data.put("c23", 1.0); + data.put("c24", 1.0); + + Gson gson = new Gson(); + String json = gson.toJson(data); + System.out.println(json); + + MCenterRestfulComponent.Result result = modelServing.predict(null, json); + + System.out.println("Retrurned code: " + result.returned_code); + System.out.println("Retrurned json: " + result.json); + + System.out.println(); + } +}