Jay
2025-02-24 c3c7a6918f0e2dfe597c339117e4185b641be95f
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
6957a3 3 import com.alibaba.fastjson.JSON;
7fd198 4 import com.alibaba.fastjson.JSONArray;
5 import com.iail.model.IAILModel;
401492 6 import com.iailab.module.model.enums.CommonConstant;
373ab1 7 import com.iailab.module.model.common.enums.OutResultType;
1178da 8 import com.iailab.module.model.common.exception.ModelResultErrorException;
69bd5e 9 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
7fd198 10 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
11 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
69bd5e 12 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
7fd198 13 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
14 import com.iailab.module.model.mdk.common.enums.TypeA;
15 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
16 import com.iailab.module.model.mdk.predict.PredictModelHandler;
17 import com.iailab.module.model.mdk.sample.SampleConstructor;
18 import com.iailab.module.model.mdk.sample.dto.SampleData;
19 import com.iailab.module.model.mdk.vo.PredictResultVO;
45520a 20 import com.iailab.module.model.mpk.common.MdkConstant;
1a2b62 21 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 22 import lombok.extern.slf4j.Slf4j;
23 import org.springframework.beans.factory.annotation.Autowired;
24 import org.springframework.stereotype.Component;
25
45520a 26 import java.util.Date;
D 27 import java.util.HashMap;
28 import java.util.List;
29 import java.util.Map;
7fd198 30
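/**
 * {@link PredictModelHandler} implementation that builds the input samples for a
 * prediction model, runs the model through {@code DllUtils.run}, and maps the raw
 * result onto the outputs configured for the item.
 *
 * @author PanZhibao
 * @createTime 2024-09-01
 */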
36 @Slf4j
37 @Component
38 public class PredictModelHandlerImpl implements PredictModelHandler {
39
40     @Autowired
41     private MmModelArithSettingsService mmModelArithSettingsService;
42
43     @Autowired
69bd5e 44     private MmItemOutputService mmItemOutputService;
7fd198 45
46     @Autowired
47     private SampleConstructor sampleConstructor;
48
4f1717 49     /**
50      * 根据模型预测,返回预测结果
51      *
52      * @param predictTime
53      * @param predictModel
54      * @return
55      * @throws ModelInvokeException
56      */
7fd198 57     @Override
28c2db 58     public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel,String itemName,String itemNo) throws ModelInvokeException {
7fd198 59         PredictResultVO result = new PredictResultVO();
60         if (predictModel == null) {
61             throw new ModelInvokeException("modelEntity is null");
62         }
63         String modelId = predictModel.getId();
64         try {
81ce77 65             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>());
7fd198 66             String modelPath = predictModel.getModelpath();
67             if (modelPath == null) {
373ab1 68                 log.info("模型路径不存在,modelId=" + modelId);
7fd198 69                 return null;
70             }
71             IAILModel newModelBean = composeNewModelBean(predictModel);
72             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
45520a 73             // 校验setting必须有pyFile,否则可能导致程序崩溃
D 74             if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
75                 throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
76             }
77
7fd198 78             if (settings == null) {
79                 log.error("模型setting不存在,modelId=" + modelId);
80                 return null;
81             }
82             int portLength = sampleDataList.size();
83             Object[] param2Values = new Object[portLength + 2];
84             for (int i = 0; i < portLength; i++) {
4f1717 85                 param2Values[i] = sampleDataList.get(i).getMatrix();
7fd198 86             }
87             param2Values[portLength] = newModelBean.getDataMap().get("models");
4f1717 88             param2Values[portLength + 1] = settings;
7fd198 89
595ccf 90             log.info("####################### 预测模型 "+ "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################");
b82ba2 91 //            JSONObject jsonObjNewModelBean = new JSONObject();
D 92 //            jsonObjNewModelBean.put("newModelBean", newModelBean);
93 //            log.info(String.valueOf(jsonObjNewModelBean));
94 //            JSONObject jsonObjParam2Values = new JSONObject();
95 //            jsonObjParam2Values.put("param2Values", param2Values);
6957a3 96             log.info("参数: " + JSON.toJSONString(param2Values));
7fd198 97
98             //IAILMDK.run
1a2b62 99             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
73a05d 100             //打印结果
D 101             log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getMethodname() + ",modelResult=" + JSON.toJSONString(modelResult));
102             //判断模型结果
4f1717 103             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
104                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
1178da 105                 throw new ModelResultErrorException("模型结果异常:" + modelResult);
b2aca2 106             }
4f1717 107             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
7fd198 108
373ab1 109             List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
a6e46f 110             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
373ab1 111             for (MmItemOutputEntity output : itemOutputList) {
112                 if (!modelResult.containsKey(output.getResultstr())) {
113                     continue;
114                 }
115                 OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
116                 switch (outResultType) {
117                     case D1:
118                         double[] temp1 = (double[]) modelResult.get(output.getResultstr());
119                         predictMatrixs.put(output, temp1);
120                         break;
121                     case D2:
122                         double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
123                         double[] tempColumn = new double[temp2.length];
124                         for (int i = 0; i < tempColumn.length; i++) {
125                             tempColumn[i] = temp2[i][output.getResultIndex()];
69bd5e 126                         }
373ab1 127                         predictMatrixs.put(output, tempColumn);
128                         break;
a6e46f 129                     case D:
D 130                         Double temp3 = (Double) modelResult.get(output.getResultstr());
51e472 131                         predictMatrixs.put(output, new double[]{temp3});
a6e46f 132                         break;
373ab1 133                     default:
134                         break;
1a2b62 135                 }
D 136             }
69bd5e 137             result.setPredictMatrixs(predictMatrixs);
1a2b62 138             result.setModelResult(modelResult);
7fd198 139             result.setPredictTime(predictTime);
50084d 140         } catch (ModelResultErrorException ex) {
efdc38 141 //            ex.printStackTrace();
D 142             log.error("模型结果异常", ex);
50084d 143             throw ex;
D 144         } catch (Exception ex) {
c9dd12 145 //            log.error("调用发生异常,异常信息为:{0}", ex.getMessage());
efdc38 146 //            ex.printStackTrace();
D 147             log.error("模型运行异常", ex);
ead005 148             throw new ModelInvokeException(ex.getMessage());
7fd198 149         }
150         return result;
151     }
152
153     /**
fdcde1 154      * 预测,模拟调整
155      *
156      * @param predictTime
157      * @param predictModel
158      * @param itemName
159      * @param itemNo
160      * @return
161      * @throws ModelInvokeException
162      */
163     @Override
164     public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel,String itemName,String itemNo, double[][] deviation) throws ModelInvokeException {
165         PredictResultVO result = new PredictResultVO();
166         if (predictModel == null) {
167             throw new ModelInvokeException("modelEntity is null");
168         }
169         String modelId = predictModel.getId();
170         try {
171             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>());
172             String modelPath = predictModel.getModelpath();
173             if (modelPath == null) {
174                 log.info("模型路径不存在,modelId=" + modelId);
175                 return null;
176             }
177             IAILModel newModelBean = composeNewModelBean(predictModel);
178             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
179             // 校验setting必须有pyFile,否则可能导致程序崩溃
180             if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
181                 throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
182             }
183
184             if (settings == null) {
185                 log.error("模型setting不存在,modelId=" + modelId);
186                 return null;
187             }
188             int portLength = sampleDataList.size();
189             Object[] param2Values = new Object[portLength + 2];
190             for (int i = 0; i < portLength; i++) {
191                 param2Values[i] = sampleDataList.get(i).getMatrix();
192             }
193             param2Values[portLength] = newModelBean.getDataMap().get("models");
194             param2Values[portLength + 1] = settings;
195
196             log.info("####################### 模拟调整 "+ "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################");
197             log.info("参数: " + JSON.toJSONString(param2Values));
198
199             //IAILMDK.run
200             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
201             //打印结果
202             log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getMethodname() + ",modelResult=" + JSON.toJSONString(modelResult));
203             //判断模型结果
204             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
205                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
206                 throw new ModelResultErrorException("模型结果异常:" + modelResult);
207             }
208             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
209
210             List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
211             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
212             for (MmItemOutputEntity output : itemOutputList) {
213                 if (!modelResult.containsKey(output.getResultstr())) {
214                     continue;
215                 }
216                 OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
217                 switch (outResultType) {
218                     case D1:
219                         double[] temp1 = (double[]) modelResult.get(output.getResultstr());
220                         predictMatrixs.put(output, temp1);
221                         break;
222                     case D2:
223                         double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
224                         double[] tempColumn = new double[temp2.length];
225                         for (int i = 0; i < tempColumn.length; i++) {
226                             tempColumn[i] = temp2[i][output.getResultIndex()];
227                         }
228                         predictMatrixs.put(output, tempColumn);
229                         break;
230                     case D:
231                         Double temp3 = (Double) modelResult.get(output.getResultstr());
232                         predictMatrixs.put(output, new double[]{temp3});
233                         break;
234                     default:
235                         break;
236                 }
237             }
238             result.setPredictMatrixs(predictMatrixs);
239             result.setModelResult(modelResult);
240             result.setPredictTime(predictTime);
241         } catch (ModelResultErrorException ex) {
242             log.error("模型结果异常", ex);
243             throw ex;
244         } catch (Exception ex) {
245             log.error("调用发生异常,异常信息为:{0}", ex.getMessage());
246             throw new ModelInvokeException(ex.getMessage());
247         }
248         return result;
249     }
250     /**
7fd198 251      * 构造IAILMDK.run()方法的newModelBean参数
252      *
253      * @param predictModel
254      * @return
255      */
256     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
257         IAILModel newModelBean = new IAILModel();
258         newModelBean.setClassName(predictModel.getClassname().trim());
259         newModelBean.setMethodName(predictModel.getMethodname().trim());
260         //构造参数类型
261         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
262         Class<?>[] paramsArray = new Class[paArStr.length];
263         for (int i = 0; i < paArStr.length; i++) {
264             if ("[[D".equals(paArStr[i])) {
265                 paramsArray[i] = double[][].class;
266             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
267                 paramsArray[i] = HashMap.class;
268             }
269         }
270         newModelBean.setParamsArray(paramsArray);
271         HashMap<String, Object> dataMap = new HashMap<>();
272         HashMap<String, String> models = new HashMap<>(1);
b2aca2 273         models.put("model_path", predictModel.getModelpath());
7fd198 274         dataMap.put("models", models);
275         newModelBean.setDataMap(dataMap);
276         return newModelBean;
277     }
278
279     /**
280      * 根据模型id获取参数map
281      *
282      * @param modelId
283      * @return
284      */
285     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
286         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
287         HashMap<String, Object> result = new HashMap<>();
288         for (MmModelArithSettingsEntity entry : list) {
289             String valueType = entry.getValuetype().trim(); //去除两端空格
290             if ("int".equals(valueType)) {
291                 int value = Integer.parseInt(entry.getValue());
292                 result.put(entry.getKey(), value);
293             } else if ("double".equals(valueType)) {
294                 double value = Double.parseDouble(entry.getValue());
295                 result.put(entry.getKey(), value);
296             } else if ("string".equals(valueType)) {
297                 String value = entry.getValue();
298                 result.put(entry.getKey(), value);
299             } else if ("decimalArray".equals(valueType)) {
300                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
301                 double[] value = new double[valueArray.size()];
302                 for (int i = 0; i < valueArray.size(); i++) {
303                     value[i] = Double.parseDouble(valueArray.get(i).toString());
304                 }
305                 result.put(entry.getKey(), value);
306             } else if ("decimal".equals(valueType)) {
307                 double value = Double.parseDouble(entry.getValue());
308                 result.put(entry.getKey(), value);
309             }
310         }
311         return result;
312     }
313 }