dengzedong
2024-12-24 aa0382e44311f9f7e62a688c8fcaa9c69a512e0f
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
6957a3 3 import com.alibaba.fastjson.JSON;
7fd198 4 import com.alibaba.fastjson.JSONArray;
5 import com.alibaba.fastjson.JSONObject;
6 import com.iail.model.IAILModel;
4f1717 7 import com.iailab.module.model.common.enums.CommonConstant;
373ab1 8 import com.iailab.module.model.common.enums.OutResultType;
1178da 9 import com.iailab.module.model.common.exception.ModelResultErrorException;
69bd5e 10 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
7fd198 11 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
12 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
7f0bcd 13 import com.iailab.module.model.mcs.pre.enums.ItemRunStatusEnum;
69bd5e 14 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
7fd198 15 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
16 import com.iailab.module.model.mdk.common.enums.TypeA;
17 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
18 import com.iailab.module.model.mdk.predict.PredictModelHandler;
19 import com.iailab.module.model.mdk.sample.SampleConstructor;
20 import com.iailab.module.model.mdk.sample.dto.SampleData;
21 import com.iailab.module.model.mdk.vo.PredictResultVO;
45520a 22 import com.iailab.module.model.mpk.common.MdkConstant;
1a2b62 23 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 24 import lombok.extern.slf4j.Slf4j;
25 import org.springframework.beans.factory.annotation.Autowired;
26 import org.springframework.stereotype.Component;
27
45520a 28 import java.util.Date;
D 29 import java.util.HashMap;
30 import java.util.List;
31 import java.util.Map;
7fd198 32
33 /**
34  * @author PanZhibao
35  * @Description
36  * @createTime 2024年09月01日
37  */
38 @Slf4j
39 @Component
40 public class PredictModelHandlerImpl implements PredictModelHandler {
41
42     @Autowired
43     private MmModelArithSettingsService mmModelArithSettingsService;
44
45     @Autowired
69bd5e 46     private MmItemOutputService mmItemOutputService;
7fd198 47
48     @Autowired
49     private SampleConstructor sampleConstructor;
50
4f1717 51     /**
52      * 根据模型预测,返回预测结果
53      *
54      * @param predictTime
55      * @param predictModel
56      * @return
57      * @throws ModelInvokeException
58      */
7fd198 59     @Override
28c2db 60     public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel,String itemName,String itemNo) throws ModelInvokeException {
7fd198 61         PredictResultVO result = new PredictResultVO();
62         if (predictModel == null) {
63             throw new ModelInvokeException("modelEntity is null");
64         }
65         String modelId = predictModel.getId();
66         try {
50084d 67             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName);
7fd198 68             String modelPath = predictModel.getModelpath();
69             if (modelPath == null) {
373ab1 70                 log.info("模型路径不存在,modelId=" + modelId);
7fd198 71                 return null;
72             }
73             IAILModel newModelBean = composeNewModelBean(predictModel);
74             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
45520a 75             // 校验setting必须有pyFile,否则可能导致程序崩溃
D 76             if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
77                 throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
78             }
79
7fd198 80             if (settings == null) {
81                 log.error("模型setting不存在,modelId=" + modelId);
82                 return null;
83             }
84             int portLength = sampleDataList.size();
85             Object[] param2Values = new Object[portLength + 2];
86             for (int i = 0; i < portLength; i++) {
4f1717 87                 param2Values[i] = sampleDataList.get(i).getMatrix();
7fd198 88             }
89             param2Values[portLength] = newModelBean.getDataMap().get("models");
4f1717 90             param2Values[portLength + 1] = settings;
7fd198 91
595ccf 92             log.info("####################### 预测模型 "+ "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################");
b82ba2 93 //            JSONObject jsonObjNewModelBean = new JSONObject();
D 94 //            jsonObjNewModelBean.put("newModelBean", newModelBean);
95 //            log.info(String.valueOf(jsonObjNewModelBean));
96 //            JSONObject jsonObjParam2Values = new JSONObject();
97 //            jsonObjParam2Values.put("param2Values", param2Values);
6957a3 98             log.info("参数: " + JSON.toJSONString(param2Values));
7fd198 99
100             //IAILMDK.run
1a2b62 101             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
4f1717 102             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
103                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
1178da 104                 throw new ModelResultErrorException("模型结果异常:" + modelResult);
b2aca2 105             }
4f1717 106             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
7fd198 107             //打印结果
108             JSONObject jsonObjResult = new JSONObject();
1a2b62 109             jsonObjResult.put("result", modelResult);
595ccf 110             log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getMethodname() + ",modelResult=" + String.valueOf(jsonObjResult));
7fd198 111
373ab1 112             List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
a6e46f 113             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
D 114             Map<MmItemOutputEntity, Double> predictDoubleValues = new HashMap<>();
373ab1 115             for (MmItemOutputEntity output : itemOutputList) {
116                 if (!modelResult.containsKey(output.getResultstr())) {
117                     continue;
118                 }
119                 OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
120                 switch (outResultType) {
121                     case D1:
122                         double[] temp1 = (double[]) modelResult.get(output.getResultstr());
123                         predictMatrixs.put(output, temp1);
124                         break;
125                     case D2:
126                         double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
127                         double[] tempColumn = new double[temp2.length];
128                         for (int i = 0; i < tempColumn.length; i++) {
129                             tempColumn[i] = temp2[i][output.getResultIndex()];
69bd5e 130                         }
373ab1 131                         predictMatrixs.put(output, tempColumn);
132                         break;
a6e46f 133                     case D:
D 134                         Double temp3 = (Double) modelResult.get(output.getResultstr());
135                         predictDoubleValues.put(output, temp3);
136                         break;
373ab1 137                     default:
138                         break;
1a2b62 139                 }
D 140             }
69bd5e 141             result.setPredictMatrixs(predictMatrixs);
a6e46f 142             result.setPredictDoubleValues(predictDoubleValues);
1a2b62 143             result.setModelResult(modelResult);
7fd198 144             result.setPredictTime(predictTime);
50084d 145         } catch (ModelResultErrorException ex) {
7fd198 146             ex.printStackTrace();
50084d 147             throw ex;
D 148         } catch (Exception ex) {
c9dd12 149 //            log.error("调用发生异常,异常信息为:{0}", ex.getMessage());
D 150             ex.printStackTrace();
ead005 151             throw new ModelInvokeException(ex.getMessage());
7fd198 152         }
153         return result;
154     }
155
156     /**
157      * 构造IAILMDK.run()方法的newModelBean参数
158      *
159      * @param predictModel
160      * @return
161      */
162     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
163         IAILModel newModelBean = new IAILModel();
164         newModelBean.setClassName(predictModel.getClassname().trim());
165         newModelBean.setMethodName(predictModel.getMethodname().trim());
166         //构造参数类型
167         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
168         Class<?>[] paramsArray = new Class[paArStr.length];
169         for (int i = 0; i < paArStr.length; i++) {
170             if ("[[D".equals(paArStr[i])) {
171                 paramsArray[i] = double[][].class;
172             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
173                 paramsArray[i] = HashMap.class;
174             }
175         }
176         newModelBean.setParamsArray(paramsArray);
177         HashMap<String, Object> dataMap = new HashMap<>();
178         HashMap<String, String> models = new HashMap<>(1);
b2aca2 179         models.put("model_path", predictModel.getModelpath());
7fd198 180         dataMap.put("models", models);
181         newModelBean.setDataMap(dataMap);
182         return newModelBean;
183     }
184
185     /**
186      * 根据模型id获取参数map
187      *
188      * @param modelId
189      * @return
190      */
191     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
192         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
193         HashMap<String, Object> result = new HashMap<>();
194         for (MmModelArithSettingsEntity entry : list) {
195             String valueType = entry.getValuetype().trim(); //去除两端空格
196             if ("int".equals(valueType)) {
197                 int value = Integer.parseInt(entry.getValue());
198                 result.put(entry.getKey(), value);
199             } else if ("double".equals(valueType)) {
200                 double value = Double.parseDouble(entry.getValue());
201                 result.put(entry.getKey(), value);
202             } else if ("string".equals(valueType)) {
203                 String value = entry.getValue();
204                 result.put(entry.getKey(), value);
205             } else if ("decimalArray".equals(valueType)) {
206                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
207                 double[] value = new double[valueArray.size()];
208                 for (int i = 0; i < valueArray.size(); i++) {
209                     value[i] = Double.parseDouble(valueArray.get(i).toString());
210                 }
211                 result.put(entry.getKey(), value);
212             } else if ("decimal".equals(valueType)) {
213                 double value = Double.parseDouble(entry.getValue());
214                 result.put(entry.getKey(), value);
215             }
216         }
217         return result;
218     }
219 }