houzhongjian
8 天以前 3058865fa4dfa634a92b4ebd826d8b1264dc90a3
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
6957a3 3 import com.alibaba.fastjson.JSON;
7fd198 4 import com.alibaba.fastjson.JSONArray;
5 import com.iail.model.IAILModel;
b3674c 6 import com.iailab.module.model.mdk.vo.StAdjustDeviationDTO;
401492 7 import com.iailab.module.model.enums.CommonConstant;
373ab1 8 import com.iailab.module.model.common.enums.OutResultType;
1178da 9 import com.iailab.module.model.common.exception.ModelResultErrorException;
69bd5e 10 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
7fd198 11 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
12 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
69bd5e 13 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
7fd198 14 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
15 import com.iailab.module.model.mdk.common.enums.TypeA;
16 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
17 import com.iailab.module.model.mdk.predict.PredictModelHandler;
18 import com.iailab.module.model.mdk.sample.SampleConstructor;
19 import com.iailab.module.model.mdk.sample.dto.SampleData;
20 import com.iailab.module.model.mdk.vo.PredictResultVO;
45520a 21 import com.iailab.module.model.mpk.common.MdkConstant;
1a2b62 22 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 23 import lombok.extern.slf4j.Slf4j;
24 import org.springframework.beans.factory.annotation.Autowired;
25 import org.springframework.stereotype.Component;
26
45520a 27 import java.util.Date;
D 28 import java.util.HashMap;
29 import java.util.List;
30 import java.util.Map;
7fd198 31
32 /**
33  * @author PanZhibao
34  * @Description
35  * @createTime 2024年09月01日
36  */
37 @Slf4j
38 @Component
39 public class PredictModelHandlerImpl implements PredictModelHandler {
40
41     @Autowired
42     private MmModelArithSettingsService mmModelArithSettingsService;
43
44     @Autowired
69bd5e 45     private MmItemOutputService mmItemOutputService;
7fd198 46
47     @Autowired
48     private SampleConstructor sampleConstructor;
49
4f1717 50     /**
51      * 根据模型预测,返回预测结果
52      *
53      * @param predictTime
54      * @param predictModel
55      * @return
56      * @throws ModelInvokeException
57      */
7fd198 58     @Override
b3674c 59     public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel, String itemName, String itemNo) throws ModelInvokeException {
7fd198 60         PredictResultVO result = new PredictResultVO();
61         if (predictModel == null) {
62             throw new ModelInvokeException("modelEntity is null");
63         }
64         String modelId = predictModel.getId();
65         try {
81ce77 66             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>());
7fd198 67             String modelPath = predictModel.getModelpath();
68             if (modelPath == null) {
373ab1 69                 log.info("模型路径不存在,modelId=" + modelId);
7fd198 70                 return null;
71             }
72             IAILModel newModelBean = composeNewModelBean(predictModel);
73             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
45520a 74             // 校验setting必须有pyFile,否则可能导致程序崩溃
D 75             if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
b3674c 76                 throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!");
45520a 77             }
D 78
7fd198 79             if (settings == null) {
80                 log.error("模型setting不存在,modelId=" + modelId);
81                 return null;
82             }
83             int portLength = sampleDataList.size();
84             Object[] param2Values = new Object[portLength + 2];
85             for (int i = 0; i < portLength; i++) {
4f1717 86                 param2Values[i] = sampleDataList.get(i).getMatrix();
7fd198 87             }
88             param2Values[portLength] = newModelBean.getDataMap().get("models");
4f1717 89             param2Values[portLength + 1] = settings;
7fd198 90
b3674c 91             log.info("####################### 预测模型 " + "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################");
6957a3 92             log.info("参数: " + JSON.toJSONString(param2Values));
7fd198 93
94             //IAILMDK.run
1a2b62 95             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
73a05d 96             //打印结果
6c30a0 97             log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getModelname() + ",modelResult=" + JSON.toJSONString(modelResult));
73a05d 98             //判断模型结果
4f1717 99             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
100                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
1178da 101                 throw new ModelResultErrorException("模型结果异常:" + modelResult);
b2aca2 102             }
4f1717 103             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
7fd198 104
373ab1 105             List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
a6e46f 106             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
373ab1 107             for (MmItemOutputEntity output : itemOutputList) {
108                 if (!modelResult.containsKey(output.getResultstr())) {
109                     continue;
110                 }
111                 OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
112                 switch (outResultType) {
113                     case D1:
114                         double[] temp1 = (double[]) modelResult.get(output.getResultstr());
115                         predictMatrixs.put(output, temp1);
116                         break;
117                     case D2:
118                         double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
119                         double[] tempColumn = new double[temp2.length];
120                         for (int i = 0; i < tempColumn.length; i++) {
121                             tempColumn[i] = temp2[i][output.getResultIndex()];
69bd5e 122                         }
373ab1 123                         predictMatrixs.put(output, tempColumn);
124                         break;
a6e46f 125                     case D:
D 126                         Double temp3 = (Double) modelResult.get(output.getResultstr());
51e472 127                         predictMatrixs.put(output, new double[]{temp3});
a6e46f 128                         break;
373ab1 129                     default:
130                         break;
1a2b62 131                 }
D 132             }
69bd5e 133             result.setPredictMatrixs(predictMatrixs);
1a2b62 134             result.setModelResult(modelResult);
7fd198 135             result.setPredictTime(predictTime);
50084d 136         } catch (ModelResultErrorException ex) {
efdc38 137 //            ex.printStackTrace();
D 138             log.error("模型结果异常", ex);
50084d 139             throw ex;
D 140         } catch (Exception ex) {
c9dd12 141 //            log.error("调用发生异常,异常信息为:{0}", ex.getMessage());
efdc38 142 //            ex.printStackTrace();
D 143             log.error("模型运行异常", ex);
ead005 144             throw new ModelInvokeException(ex.getMessage());
7fd198 145         }
146         return result;
147     }
148
149     /**
fdcde1 150      * 预测,模拟调整
151      *
152      * @param predictTime
153      * @param predictModel
154      * @param itemName
155      * @param itemNo
b3674c 156      * @param deviationList
fdcde1 157      * @return
158      * @throws ModelInvokeException
159      */
160     @Override
b3674c 161     public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel, String itemName, String itemNo, List<StAdjustDeviationDTO> deviationList) throws ModelInvokeException {
fdcde1 162         PredictResultVO result = new PredictResultVO();
163         if (predictModel == null) {
164             throw new ModelInvokeException("modelEntity is null");
165         }
166         String modelId = predictModel.getId();
167         try {
b3674c 168             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>(), deviationList);
fdcde1 169             String modelPath = predictModel.getModelpath();
170             if (modelPath == null) {
171                 log.info("模型路径不存在,modelId=" + modelId);
172                 return null;
173             }
174             IAILModel newModelBean = composeNewModelBean(predictModel);
175             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
176             // 校验setting必须有pyFile,否则可能导致程序崩溃
177             if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
b3674c 178                 throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!");
fdcde1 179             }
180
181             if (settings == null) {
182                 log.error("模型setting不存在,modelId=" + modelId);
183                 return null;
184             }
185             int portLength = sampleDataList.size();
186             Object[] param2Values = new Object[portLength + 2];
187             for (int i = 0; i < portLength; i++) {
188                 param2Values[i] = sampleDataList.get(i).getMatrix();
189             }
190             param2Values[portLength] = newModelBean.getDataMap().get("models");
191             param2Values[portLength + 1] = settings;
192
b3674c 193             log.info("####################### 模拟调整 " + "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################");
fdcde1 194             log.info("参数: " + JSON.toJSONString(param2Values));
195
196             //IAILMDK.run
197             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
198             //打印结果
6c30a0 199             log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getModelname() + ",modelResult=" + JSON.toJSONString(modelResult));
fdcde1 200             //判断模型结果
201             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
202                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
203                 throw new ModelResultErrorException("模型结果异常:" + modelResult);
204             }
205             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
206
207             List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
208             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
209             for (MmItemOutputEntity output : itemOutputList) {
210                 if (!modelResult.containsKey(output.getResultstr())) {
211                     continue;
212                 }
213                 OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
214                 switch (outResultType) {
215                     case D1:
216                         double[] temp1 = (double[]) modelResult.get(output.getResultstr());
217                         predictMatrixs.put(output, temp1);
218                         break;
219                     case D2:
220                         double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
221                         double[] tempColumn = new double[temp2.length];
222                         for (int i = 0; i < tempColumn.length; i++) {
223                             tempColumn[i] = temp2[i][output.getResultIndex()];
224                         }
225                         predictMatrixs.put(output, tempColumn);
226                         break;
227                     case D:
228                         Double temp3 = (Double) modelResult.get(output.getResultstr());
229                         predictMatrixs.put(output, new double[]{temp3});
230                         break;
231                     default:
232                         break;
233                 }
234             }
235             result.setPredictMatrixs(predictMatrixs);
236             result.setModelResult(modelResult);
237             result.setPredictTime(predictTime);
238         } catch (ModelResultErrorException ex) {
239             log.error("模型结果异常", ex);
240             throw ex;
241         } catch (Exception ex) {
242             log.error("调用发生异常,异常信息为:{0}", ex.getMessage());
243             throw new ModelInvokeException(ex.getMessage());
244         }
245         return result;
246     }
b3674c 247
fdcde1 248     /**
7fd198 249      * 构造IAILMDK.run()方法的newModelBean参数
250      *
251      * @param predictModel
252      * @return
253      */
254     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
255         IAILModel newModelBean = new IAILModel();
256         newModelBean.setClassName(predictModel.getClassname().trim());
257         newModelBean.setMethodName(predictModel.getMethodname().trim());
258         //构造参数类型
259         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
260         Class<?>[] paramsArray = new Class[paArStr.length];
261         for (int i = 0; i < paArStr.length; i++) {
262             if ("[[D".equals(paArStr[i])) {
263                 paramsArray[i] = double[][].class;
264             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
265                 paramsArray[i] = HashMap.class;
266             }
267         }
268         newModelBean.setParamsArray(paramsArray);
269         HashMap<String, Object> dataMap = new HashMap<>();
270         HashMap<String, String> models = new HashMap<>(1);
b2aca2 271         models.put("model_path", predictModel.getModelpath());
7fd198 272         dataMap.put("models", models);
273         newModelBean.setDataMap(dataMap);
274         return newModelBean;
275     }
276
277     /**
278      * 根据模型id获取参数map
279      *
280      * @param modelId
281      * @return
282      */
283     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
284         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
285         HashMap<String, Object> result = new HashMap<>();
286         for (MmModelArithSettingsEntity entry : list) {
287             String valueType = entry.getValuetype().trim(); //去除两端空格
288             if ("int".equals(valueType)) {
289                 int value = Integer.parseInt(entry.getValue());
290                 result.put(entry.getKey(), value);
291             } else if ("double".equals(valueType)) {
292                 double value = Double.parseDouble(entry.getValue());
293                 result.put(entry.getKey(), value);
294             } else if ("string".equals(valueType)) {
295                 String value = entry.getValue();
296                 result.put(entry.getKey(), value);
297             } else if ("decimalArray".equals(valueType)) {
298                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
299                 double[] value = new double[valueArray.size()];
300                 for (int i = 0; i < valueArray.size(); i++) {
301                     value[i] = Double.parseDouble(valueArray.get(i).toString());
302                 }
303                 result.put(entry.getKey(), value);
304             } else if ("decimal".equals(valueType)) {
305                 double value = Double.parseDouble(entry.getValue());
306                 result.put(entry.getKey(), value);
307             }
308         }
309         return result;
310     }
311 }