潘志宝
2024-09-30 536b8e85745a42806214ee708297642f493d3cd8
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.IAILMDK;
6 import com.iail.model.IAILModel;
7 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
8 import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity;
9 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
10 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
11 import com.iailab.module.model.mcs.pre.service.MmModelResultstrService;
12 import com.iailab.module.model.mdk.common.enums.TypeA;
13 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
14 import com.iailab.module.model.mdk.predict.PredictModelHandler;
15 import com.iailab.module.model.mdk.sample.SampleConstructor;
16 import com.iailab.module.model.mdk.sample.dto.SampleData;
17 import com.iailab.module.model.mdk.vo.PredictResultVO;
18 import lombok.extern.slf4j.Slf4j;
19 import org.springframework.beans.factory.annotation.Autowired;
20 import org.springframework.stereotype.Component;
21
22 import java.util.ArrayList;
23 import java.util.Date;
24 import java.util.HashMap;
25 import java.util.List;
26
27 /**
28  * @author PanZhibao
29  * @Description
30  * @createTime 2024年09月01日
31  */
32 @Slf4j
33 @Component
34 public class PredictModelHandlerImpl implements PredictModelHandler {
35
36     @Autowired
37     private MmModelArithSettingsService mmModelArithSettingsService;
38
39     @Autowired
40     private MmModelResultstrService mmModelResultstrService;
41
42     @Autowired
43     private SampleConstructor sampleConstructor;
44
45     @Override
46     public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException {
47         PredictResultVO result = new PredictResultVO();
48         if (predictModel == null) {
49             throw new ModelInvokeException("modelEntity is null");
50         }
51         String modelId = predictModel.getId();
52
53         try {
54             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
55             String modelPath = predictModel.getModelpath();
56             if (modelPath == null) {
57                 System.out.println("模型路径不存在,modelId=" + modelId);
58                 return null;
59             }
60             IAILModel newModelBean = composeNewModelBean(predictModel);
61             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
62             if (settings == null) {
63                 log.error("模型setting不存在,modelId=" + modelId);
64                 return null;
65             }
66             int portLength = sampleDataList.size();
67             Object[] param2Values = new Object[portLength + 2];
68             for (int i = 0; i < portLength; i++) {
69                 param2Values[i]=sampleDataList.get(i).getMatrix();
70             }
71             param2Values[portLength] = newModelBean.getDataMap().get("models");
72             param2Values[portLength+1] = settings;
73
74             log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################");
75             JSONObject jsonObjNewModelBean = new JSONObject();
76             jsonObjNewModelBean.put("newModelBean", newModelBean);
77             log.info(String.valueOf(jsonObjNewModelBean));
78             JSONObject jsonObjParam2Values = new JSONObject();
79             jsonObjParam2Values.put("param2Values", param2Values);
80             log.info(String.valueOf(jsonObjParam2Values));
81
82             //IAILMDK.run
83             HashMap<String, Object> modelResult = IAILMDK.run(newModelBean, param2Values);
84             //打印结果
85             JSONObject jsonObjResult = new JSONObject();
86             jsonObjResult.put("result", result);
87             log.info(String.valueOf(jsonObjResult));
88
89             MmModelResultstrEntity modelResultstr = mmModelResultstrService.getInfo(predictModel.getResultstrid());
90             log.info("模型计算完成:modelId=" + modelId + result);
91             double[][] temp = (double[][]) modelResult.get(modelResultstr.getResultstr());
92             result.setPredictMatrix(temp);
93             result.setPredictTime(predictTime);
94         } catch (Exception ex) {
95             log.error("IAILModel对象构造失败,modelId=" + modelId);
96             log.error(ex.getMessage());
97             log.error("调用发生异常,异常信息为:{}" , ex);
98             ex.printStackTrace();
99
100         }
101         return result;
102     }
103
104     /**
105      * 构造IAILMDK.run()方法的newModelBean参数
106      *
107      * @param predictModel
108      * @return
109      */
110     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
111         IAILModel newModelBean = new IAILModel();
112         newModelBean.setClassName(predictModel.getClassname().trim());
113         newModelBean.setMethodName(predictModel.getMethodname().trim());
114         //构造参数类型
115         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
116         Class<?>[] paramsArray = new Class[paArStr.length];
117         for (int i = 0; i < paArStr.length; i++) {
118             if ("[[D".equals(paArStr[i])) {
119                 paramsArray[i] = double[][].class;
120             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
121                 paramsArray[i] = HashMap.class;
122             }
123         }
124         newModelBean.setParamsArray(paramsArray);
125         HashMap<String, Object> dataMap = new HashMap<>();
126         HashMap<String, String> models = new HashMap<>(1);
127         models.put("paramFile", predictModel.getModelpath());
128         dataMap.put("models", models);
129         newModelBean.setDataMap(dataMap);
130         return newModelBean;
131     }
132
133     /**
134      * 根据模型id获取参数map
135      *
136      * @param modelId
137      * @return
138      */
139     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
140         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
141         HashMap<String, Object> result = new HashMap<>();
142         for (MmModelArithSettingsEntity entry : list) {
143             String valueType = entry.getValuetype().trim(); //去除两端空格
144             if ("int".equals(valueType)) {
145                 int value = Integer.parseInt(entry.getValue());
146                 result.put(entry.getKey(), value);
147             } else if ("double".equals(valueType)) {
148                 double value = Double.parseDouble(entry.getValue());
149                 result.put(entry.getKey(), value);
150             } else if ("string".equals(valueType)) {
151                 String value = entry.getValue();
152                 result.put(entry.getKey(), value);
153             } else if ("decimalArray".equals(valueType)) {
154                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
155                 double[] value = new double[valueArray.size()];
156                 for (int i = 0; i < valueArray.size(); i++) {
157                     value[i] = Double.parseDouble(valueArray.get(i).toString());
158                 }
159                 result.put(entry.getKey(), value);
160             } else if ("decimal".equals(valueType)) {
161                 double value = Double.parseDouble(entry.getValue());
162                 result.put(entry.getKey(), value);
163             }
164         }
165         return result;
166     }
167 }