houzhongjian
2024-11-07 a874b928e16320839315b9abcdf2cece1229a424
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.IAILMDK;
6 import com.iail.model.IAILModel;
7 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
8 import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity;
9 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
10 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
11 import com.iailab.module.model.mcs.pre.service.MmModelResultstrService;
12 import com.iailab.module.model.mdk.common.enums.TypeA;
13 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
14 import com.iailab.module.model.mdk.predict.PredictModelHandler;
15 import com.iailab.module.model.mdk.sample.SampleConstructor;
16 import com.iailab.module.model.mdk.sample.dto.SampleData;
17 import com.iailab.module.model.mdk.vo.PredictResultVO;
1a2b62 18 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 19 import lombok.extern.slf4j.Slf4j;
20 import org.springframework.beans.factory.annotation.Autowired;
21 import org.springframework.stereotype.Component;
22
23 import java.util.ArrayList;
24 import java.util.Date;
25 import java.util.HashMap;
26 import java.util.List;
27
28 /**
29  * @author PanZhibao
30  * @Description
31  * @createTime 2024年09月01日
32  */
33 @Slf4j
34 @Component
35 public class PredictModelHandlerImpl implements PredictModelHandler {
36
37     @Autowired
38     private MmModelArithSettingsService mmModelArithSettingsService;
39
40     @Autowired
41     private MmModelResultstrService mmModelResultstrService;
42
43     @Autowired
44     private SampleConstructor sampleConstructor;
45
46     @Override
47     public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException {
48         PredictResultVO result = new PredictResultVO();
49         if (predictModel == null) {
50             throw new ModelInvokeException("modelEntity is null");
51         }
52         String modelId = predictModel.getId();
53
54         try {
55             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
56             String modelPath = predictModel.getModelpath();
57             if (modelPath == null) {
58                 System.out.println("模型路径不存在,modelId=" + modelId);
59                 return null;
60             }
61             IAILModel newModelBean = composeNewModelBean(predictModel);
62             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
63             if (settings == null) {
64                 log.error("模型setting不存在,modelId=" + modelId);
65                 return null;
66             }
67             int portLength = sampleDataList.size();
68             Object[] param2Values = new Object[portLength + 2];
69             for (int i = 0; i < portLength; i++) {
70                 param2Values[i]=sampleDataList.get(i).getMatrix();
71             }
72             param2Values[portLength] = newModelBean.getDataMap().get("models");
73             param2Values[portLength+1] = settings;
74
75             log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################");
76             JSONObject jsonObjNewModelBean = new JSONObject();
77             jsonObjNewModelBean.put("newModelBean", newModelBean);
78             log.info(String.valueOf(jsonObjNewModelBean));
79             JSONObject jsonObjParam2Values = new JSONObject();
80             jsonObjParam2Values.put("param2Values", param2Values);
81             log.info(String.valueOf(jsonObjParam2Values));
82
83             //IAILMDK.run
1a2b62 84 //            HashMap<String, Object> modelResult = IAILMDK.run(newModelBean, param2Values);
D 85             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
7fd198 86             //打印结果
87             JSONObject jsonObjResult = new JSONObject();
1a2b62 88             jsonObjResult.put("result", modelResult);
7fd198 89             log.info(String.valueOf(jsonObjResult));
90
91             MmModelResultstrEntity modelResultstr = mmModelResultstrService.getInfo(predictModel.getResultstrid());
92             log.info("模型计算完成:modelId=" + modelId + result);
1a2b62 93             if (modelResult.containsKey(modelResultstr.getResultstr())) {
D 94                 Double[][] temp = (Double[][]) modelResult.get(modelResultstr.getResultstr());
95                 double[][] temp1 = new double[temp.length][temp[0].length];
96                 for (int i = 0; i < temp.length; i++) {
97                     for (int j = 0; j < temp[i].length; j++) {
98                         temp1[i][j] = temp[i][j].doubleValue();
99                     }
100                 }
101                 result.setPredictMatrix(temp1);
102             }
103             result.setModelResult(modelResult);
7fd198 104             result.setPredictTime(predictTime);
105         } catch (Exception ex) {
106             log.error("IAILModel对象构造失败,modelId=" + modelId);
107             log.error(ex.getMessage());
108             log.error("调用发生异常,异常信息为:{}" , ex);
109             ex.printStackTrace();
110
111         }
112         return result;
113     }
114
115     /**
116      * 构造IAILMDK.run()方法的newModelBean参数
117      *
118      * @param predictModel
119      * @return
120      */
121     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
122         IAILModel newModelBean = new IAILModel();
123         newModelBean.setClassName(predictModel.getClassname().trim());
124         newModelBean.setMethodName(predictModel.getMethodname().trim());
125         //构造参数类型
126         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
127         Class<?>[] paramsArray = new Class[paArStr.length];
128         for (int i = 0; i < paArStr.length; i++) {
129             if ("[[D".equals(paArStr[i])) {
130                 paramsArray[i] = double[][].class;
131             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
132                 paramsArray[i] = HashMap.class;
133             }
134         }
135         newModelBean.setParamsArray(paramsArray);
136         HashMap<String, Object> dataMap = new HashMap<>();
137         HashMap<String, String> models = new HashMap<>(1);
138         models.put("paramFile", predictModel.getModelpath());
139         dataMap.put("models", models);
140         newModelBean.setDataMap(dataMap);
141         return newModelBean;
142     }
143
144     /**
145      * 根据模型id获取参数map
146      *
147      * @param modelId
148      * @return
149      */
150     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
151         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
152         HashMap<String, Object> result = new HashMap<>();
153         for (MmModelArithSettingsEntity entry : list) {
154             String valueType = entry.getValuetype().trim(); //去除两端空格
155             if ("int".equals(valueType)) {
156                 int value = Integer.parseInt(entry.getValue());
157                 result.put(entry.getKey(), value);
158             } else if ("double".equals(valueType)) {
159                 double value = Double.parseDouble(entry.getValue());
160                 result.put(entry.getKey(), value);
161             } else if ("string".equals(valueType)) {
162                 String value = entry.getValue();
163                 result.put(entry.getKey(), value);
164             } else if ("decimalArray".equals(valueType)) {
165                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
166                 double[] value = new double[valueArray.size()];
167                 for (int i = 0; i < valueArray.size(); i++) {
168                     value[i] = Double.parseDouble(valueArray.get(i).toString());
169                 }
170                 result.put(entry.getKey(), value);
171             } else if ("decimal".equals(valueType)) {
172                 double value = Double.parseDouble(entry.getValue());
173                 result.put(entry.getKey(), value);
174             }
175         }
176         return result;
177     }
178 }