【PMML】Java调用Python算法模型

2021年09月12日    Author:Guofei

文章归类: 0x11_算法平台    文章编号: 173


版权声明:本文作者是郭飞。转载随意,但需要标明原文链接,并通知本人
原文链接:https://www.guofei.site/2021/09/12/pmml.html

Edit

描述

Python 训练算法模型,保存为 PMML 文件。然后 Java 调用 PMML 文件做预测,这样线上就变成纯 JAVA 项目了。

相关项目

  • Python 端的项目:https://github.com/jpmml/sklearn2pmml
  • Java 端项目:https://github.com/jpmml/jpmml-evaluator (底层的 PMML 模型类库:https://github.com/jpmml/jpmml-model)
  • 对应的库:https://repo1.maven.org/maven2/org/jpmml/

安装

  • Python
    pip install scikit-learn==0.24.2
    pip install sklearn2pmml==0.74.4
    
  • JAVA
    ```xml
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.5.9</version>
    </dependency>
    ```







## 使用



python 训练模型
```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml import sklearn2pmml

# Load the iris dataset directly as a feature matrix and a label vector.
features, labels = load_iris(return_X_y=True)

# Wrap a decision-tree classifier in a PMML-aware pipeline and train it.
classifier = tree.DecisionTreeClassifier()
pipeline = PMMLPipeline([("classifier", classifier)])
pipeline.fit(features, labels)

# Export the fitted pipeline to "iris.pmml" so it can be scored outside Python.
sklearn2pmml(pipeline, "iris.pmml", with_repr=True)

```

java 调用模型做预测

```java

package com.alipay.tianqian;

import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.*;
import java.util.*;


public class TestPmml {

    /**
     * Loads the PMML model stored at {@code src/iris.pmml}, scores two hard-coded samples and
     * prints each input map together with the value of the {@code y} target.
     */
    public static void main(String[] args) throws JAXBException, SAXException, IOException {
        TestPmml demo = new TestPmml();

        File pmmlFile = new File("src/iris.pmml");
        Evaluator evaluator = new LoadingModelEvaluatorBuilder().load(pmmlFile).build();

        List<Map<String, Object>> samples = Arrays.asList(
                demo.getRawMap(5.1, 3.5, 1.4, 0.2),
                demo.getRawMap(4.9, 3, 1.5, 4));

        for (int i = 0; i < samples.size(); i++) {
            Map<String, Object> sample = samples.get(i);
            Object label = demo.predict(evaluator, sample).get("y");
            System.out.println("X=" + sample + " -> y=" + label);
        }
    }

    /**
     * Packs four raw feature values into a map keyed by {@code x1}..{@code x4}.
     */
    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> features = new HashMap<>();
        features.put("x1", a);
        features.put("x2", b);
        features.put("x3", c);
        features.put("x4", d);
        return features;
    }

    /**
     * Runs the model on one raw sample and returns the result.
     *
     * @param evaluator the loaded model evaluator
     * @param data      raw feature values keyed by input field name
     * @return target values keyed by target field name
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
        return evaluate(evaluator, getFieldMap(evaluator, data));
    }

    /**
     * Converts raw input into PMML-format input.
     *
     * <p>Each input field declared by the evaluator is looked up in {@code input} by name and
     * passed through {@link InputField#prepare(Object)}; a name absent from {@code input} is
     * passed on as {@code null}.
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName name = inputField.getName();
            arguments.put(name, inputField.prepare(input.get(name.getValue())));
        }
        return arguments;
    }

    /**
     * Evaluates the prepared input and collects the value of every target field.
     *
     * <p>A result that is a {@link Computable} is replaced by its {@code getResult()} value.
     */
    private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
        Map<FieldName, ?> rawResults = evaluator.evaluate(input);
        Map<String, Object> decoded = new LinkedHashMap<>();
        for (TargetField target : evaluator.getTargetFields()) {
            FieldName name = target.getName();
            Object result = rawResults.get(name);
            decoded.put(
                    name.getValue(),
                    result instanceof Computable ? ((Computable) result).getResult() : result);
        }
        return decoded;
    }
}

```

旧版本

版本

  • PMML 文件:4.3版本
  • python 版本
    pip install sklearn2pmml==0.56.2
    pip install scikit-learn==0.22.2
    
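  • java 版本:
    ```xml
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator</artifactId>
        <version>1.4.1</version>
    </dependency>
    <dependency>
        <groupId>org.jpmml</groupId>
        <artifactId>pmml-evaluator-extension</artifactId>
        <version>1.4.1</version>
    </dependency>
    ```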



### 旧版本的使用

Python 训练模型,步骤同上


Java调模型,做预测
```java
package com.alipay.tianqian.service;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;


public class TestPmml2 {

    /**
     * Loads the PMML model stored at {@code src/iris.pmml}, scores two hard-coded samples and
     * prints each input map together with the value of the {@code y} target.
     *
     * @throws IOException if the PMML file cannot be opened, read or closed
     */
    public static void main(String[] args) throws IOException, JAXBException, SAXException {
        String fp = "src/iris.pmml";
        TestPmml2 obj = new TestPmml2();
        Evaluator model = obj.loadPmml(fp);
        List<Map<String, Object>> inputs = new ArrayList<>();
        inputs.add(obj.getRawMap(5.1, 3.5, 1.4, 0.2));
        inputs.add(obj.getRawMap(4.9, 3, 1.5, 0.9));
        for (Map<String, Object> input : inputs) {
            Map<String, Object> output = obj.predict(model, input);
            System.out.println("X=" + input + " -> y=" + output.get("y"));
        }
    }

    /**
     * Reads a PMML document from disk and builds an evaluator for it.
     *
     * @param fp path of the PMML file
     * @return an evaluator created by {@link ModelEvaluatorFactory}
     * @throws IOException if the file cannot be opened or closed
     */
    private Evaluator loadPmml(String fp) throws IOException, JAXBException, SAXException {
        PMML pmml;
        // try-with-resources closes the stream even when unmarshalling fails.
        try (InputStream is = new FileInputStream(fp)) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        }
        ModelEvaluatorFactory factory = ModelEvaluatorFactory.newInstance();
        return factory.newModelEvaluator(pmml);
    }

    /**
     * Packs four raw feature values into a map keyed by {@code x1}..{@code x4}.
     */
    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> data = new HashMap<>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        return data;
    }

    /**
     * Runs the model on one raw sample and returns the result.
     *
     * @param evaluator the loaded model evaluator
     * @param data      raw feature values keyed by input field name
     * @return target values keyed by target field name
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
        Map<FieldName, FieldValue> input = getFieldMap(evaluator, data);
        return evaluate(evaluator, input);
    }

    /**
     * Converts raw input into PMML-format input.
     *
     * <p>Each input field declared by the evaluator is looked up in {@code input} by name and
     * passed through {@link InputField#prepare(Object)}; a name absent from {@code input} is
     * passed on as {@code null}.
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> map = new LinkedHashMap<>();
        for (InputField field : inputFields) {
            FieldName fieldName = field.getName();
            Object rawValue = input.get(fieldName.getValue());
            map.put(fieldName, field.prepare(rawValue));
        }
        return map;
    }

    /**
     * Evaluates the prepared input and collects the value of every target field.
     *
     * <p>A result that is a {@link Computable} is replaced by its {@code getResult()} value.
     */
    private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
        Map<FieldName, ?> results = evaluator.evaluate(input);
        List<TargetField> targetFields = evaluator.getTargetFields();
        Map<String, Object> output = new LinkedHashMap<>();
        for (TargetField field : targetFields) {
            FieldName fieldName = field.getName();
            Object value = results.get(fieldName);
            if (value instanceof Computable) {
                value = ((Computable) value).getResult();
            }
            output.put(fieldName.getValue(), value);
        }
        return output;
    }
}

```

您的支持将鼓励我继续创作!
WeChatQR AliPayQR qr_wechat