dengzedong
6 天以前 a6e46fe2b5729e7468b6f3c4e079232801c22520
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.model.IAILModel;
4f1717 6 import com.iailab.module.model.common.enums.CommonConstant;
373ab1 7 import com.iailab.module.model.common.enums.OutResultType;
69bd5e 8 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
7fd198 9 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
10 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
69bd5e 11 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
7fd198 12 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
13 import com.iailab.module.model.mdk.common.enums.TypeA;
14 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
15 import com.iailab.module.model.mdk.predict.PredictModelHandler;
16 import com.iailab.module.model.mdk.sample.SampleConstructor;
17 import com.iailab.module.model.mdk.sample.dto.SampleData;
18 import com.iailab.module.model.mdk.vo.PredictResultVO;
45520a 19 import com.iailab.module.model.mpk.common.MdkConstant;
1a2b62 20 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 21 import lombok.extern.slf4j.Slf4j;
22 import org.springframework.beans.factory.annotation.Autowired;
23 import org.springframework.stereotype.Component;
24
45520a 25 import java.util.Date;
D 26 import java.util.HashMap;
27 import java.util.List;
28 import java.util.Map;
7fd198 29
30 /**
31  * @author PanZhibao
32  * @Description
33  * @createTime 2024年09月01日
34  */
35 @Slf4j
36 @Component
37 public class PredictModelHandlerImpl implements PredictModelHandler {
38
39     @Autowired
40     private MmModelArithSettingsService mmModelArithSettingsService;
41
42     @Autowired
69bd5e 43     private MmItemOutputService mmItemOutputService;
7fd198 44
45     @Autowired
46     private SampleConstructor sampleConstructor;
47
4f1717 48     /**
49      * 根据模型预测,返回预测结果
50      *
51      * @param predictTime
52      * @param predictModel
53      * @return
54      * @throws ModelInvokeException
55      */
7fd198 56     @Override
4f1717 57     public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException {
7fd198 58         PredictResultVO result = new PredictResultVO();
59         if (predictModel == null) {
60             throw new ModelInvokeException("modelEntity is null");
61         }
62         String modelId = predictModel.getId();
63         try {
64             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
65             String modelPath = predictModel.getModelpath();
66             if (modelPath == null) {
373ab1 67                 log.info("模型路径不存在,modelId=" + modelId);
7fd198 68                 return null;
69             }
70             IAILModel newModelBean = composeNewModelBean(predictModel);
71             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
45520a 72             // 校验setting必须有pyFile,否则可能导致程序崩溃
D 73             if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
74                 throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
75             }
76
7fd198 77             if (settings == null) {
78                 log.error("模型setting不存在,modelId=" + modelId);
79                 return null;
80             }
81             int portLength = sampleDataList.size();
82             Object[] param2Values = new Object[portLength + 2];
83             for (int i = 0; i < portLength; i++) {
4f1717 84                 param2Values[i] = sampleDataList.get(i).getMatrix();
7fd198 85             }
86             param2Values[portLength] = newModelBean.getDataMap().get("models");
4f1717 87             param2Values[portLength + 1] = settings;
7fd198 88
b82ba2 89             log.info("####################### 预测模型 "+ "【itemId:" + predictModel.getItemid() + ",modelName" + predictModel.getMethodname() + "】 ##########################");
D 90 //            JSONObject jsonObjNewModelBean = new JSONObject();
91 //            jsonObjNewModelBean.put("newModelBean", newModelBean);
92 //            log.info(String.valueOf(jsonObjNewModelBean));
93 //            JSONObject jsonObjParam2Values = new JSONObject();
94 //            jsonObjParam2Values.put("param2Values", param2Values);
95 //            log.info(String.valueOf(jsonObjParam2Values));
7fd198 96
97             //IAILMDK.run
1a2b62 98             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
4f1717 99             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
100                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
b2aca2 101                 throw new RuntimeException("模型结果异常:" + modelResult);
D 102             }
4f1717 103             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
7fd198 104             //打印结果
b82ba2 105             log.info("预测模型计算完成:modelId=" + modelId + ",modelName" + predictModel.getMethodname());
7fd198 106             JSONObject jsonObjResult = new JSONObject();
1a2b62 107             jsonObjResult.put("result", modelResult);
7fd198 108             log.info(String.valueOf(jsonObjResult));
109
373ab1 110             List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
a6e46f 111             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
D 112             Map<MmItemOutputEntity, Double> predictDoubleValues = new HashMap<>();
373ab1 113             for (MmItemOutputEntity output : itemOutputList) {
114                 if (!modelResult.containsKey(output.getResultstr())) {
115                     continue;
116                 }
117                 OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
118                 switch (outResultType) {
119                     case D1:
120                         double[] temp1 = (double[]) modelResult.get(output.getResultstr());
121                         predictMatrixs.put(output, temp1);
122                         break;
123                     case D2:
124                         double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
125                         double[] tempColumn = new double[temp2.length];
126                         for (int i = 0; i < tempColumn.length; i++) {
127                             tempColumn[i] = temp2[i][output.getResultIndex()];
69bd5e 128                         }
373ab1 129                         predictMatrixs.put(output, tempColumn);
130                         break;
a6e46f 131                     case D:
D 132                         Double temp3 = (Double) modelResult.get(output.getResultstr());
133                         predictDoubleValues.put(output, temp3);
134                         break;
373ab1 135                     default:
136                         break;
1a2b62 137                 }
D 138             }
69bd5e 139             result.setPredictMatrixs(predictMatrixs);
a6e46f 140             result.setPredictDoubleValues(predictDoubleValues);
1a2b62 141             result.setModelResult(modelResult);
7fd198 142             result.setPredictTime(predictTime);
143         } catch (Exception ex) {
4f1717 144             log.error("调用发生异常,异常信息为:{}", ex);
7fd198 145             ex.printStackTrace();
ead005 146             throw new ModelInvokeException(ex.getMessage());
7fd198 147         }
148         return result;
149     }
150
151     /**
152      * 构造IAILMDK.run()方法的newModelBean参数
153      *
154      * @param predictModel
155      * @return
156      */
157     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
158         IAILModel newModelBean = new IAILModel();
159         newModelBean.setClassName(predictModel.getClassname().trim());
160         newModelBean.setMethodName(predictModel.getMethodname().trim());
161         //构造参数类型
162         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
163         Class<?>[] paramsArray = new Class[paArStr.length];
164         for (int i = 0; i < paArStr.length; i++) {
165             if ("[[D".equals(paArStr[i])) {
166                 paramsArray[i] = double[][].class;
167             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
168                 paramsArray[i] = HashMap.class;
169             }
170         }
171         newModelBean.setParamsArray(paramsArray);
172         HashMap<String, Object> dataMap = new HashMap<>();
173         HashMap<String, String> models = new HashMap<>(1);
b2aca2 174         models.put("model_path", predictModel.getModelpath());
7fd198 175         dataMap.put("models", models);
176         newModelBean.setDataMap(dataMap);
177         return newModelBean;
178     }
179
180     /**
181      * 根据模型id获取参数map
182      *
183      * @param modelId
184      * @return
185      */
186     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
187         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
188         HashMap<String, Object> result = new HashMap<>();
189         for (MmModelArithSettingsEntity entry : list) {
190             String valueType = entry.getValuetype().trim(); //去除两端空格
191             if ("int".equals(valueType)) {
192                 int value = Integer.parseInt(entry.getValue());
193                 result.put(entry.getKey(), value);
194             } else if ("double".equals(valueType)) {
195                 double value = Double.parseDouble(entry.getValue());
196                 result.put(entry.getKey(), value);
197             } else if ("string".equals(valueType)) {
198                 String value = entry.getValue();
199                 result.put(entry.getKey(), value);
200             } else if ("decimalArray".equals(valueType)) {
201                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
202                 double[] value = new double[valueArray.size()];
203                 for (int i = 0; i < valueArray.size(); i++) {
204                     value[i] = Double.parseDouble(valueArray.get(i).toString());
205                 }
206                 result.put(entry.getKey(), value);
207             } else if ("decimal".equals(valueType)) {
208                 double value = Double.parseDouble(entry.getValue());
209                 result.put(entry.getKey(), value);
210             }
211         }
212         return result;
213     }
214 }