潘志宝
2024-11-19 97d38f7b7f7d95fe38cdbb79960106c15454b6ba
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.model.IAILModel;
4f1717 6 import com.iailab.module.model.common.enums.CommonConstant;
69bd5e 7 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
7fd198 8 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
9 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
69bd5e 10 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
7fd198 11 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
12 import com.iailab.module.model.mdk.common.enums.TypeA;
13 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
14 import com.iailab.module.model.mdk.predict.PredictModelHandler;
15 import com.iailab.module.model.mdk.sample.SampleConstructor;
16 import com.iailab.module.model.mdk.sample.dto.SampleData;
17 import com.iailab.module.model.mdk.vo.PredictResultVO;
1a2b62 18 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 19 import lombok.extern.slf4j.Slf4j;
20 import org.springframework.beans.factory.annotation.Autowired;
21 import org.springframework.stereotype.Component;
22
69bd5e 23 import java.util.*;
7fd198 24
25 /**
26  * @author PanZhibao
27  * @Description
28  * @createTime 2024年09月01日
29  */
30 @Slf4j
31 @Component
32 public class PredictModelHandlerImpl implements PredictModelHandler {
33
34     @Autowired
35     private MmModelArithSettingsService mmModelArithSettingsService;
36
37     @Autowired
69bd5e 38     private MmItemOutputService mmItemOutputService;
7fd198 39
40     @Autowired
41     private SampleConstructor sampleConstructor;
42
4f1717 43     /**
44      * 根据模型预测,返回预测结果
45      *
46      * @param predictTime
47      * @param predictModel
48      * @return
49      * @throws ModelInvokeException
50      */
7fd198 51     @Override
4f1717 52     public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException {
7fd198 53         PredictResultVO result = new PredictResultVO();
54         if (predictModel == null) {
55             throw new ModelInvokeException("modelEntity is null");
56         }
57         String modelId = predictModel.getId();
58
59         try {
60             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
61             String modelPath = predictModel.getModelpath();
62             if (modelPath == null) {
63                 System.out.println("模型路径不存在,modelId=" + modelId);
64                 return null;
65             }
66             IAILModel newModelBean = composeNewModelBean(predictModel);
67             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
68             if (settings == null) {
69                 log.error("模型setting不存在,modelId=" + modelId);
70                 return null;
71             }
72             int portLength = sampleDataList.size();
73             Object[] param2Values = new Object[portLength + 2];
74             for (int i = 0; i < portLength; i++) {
4f1717 75                 param2Values[i] = sampleDataList.get(i).getMatrix();
7fd198 76             }
77             param2Values[portLength] = newModelBean.getDataMap().get("models");
4f1717 78             param2Values[portLength + 1] = settings;
7fd198 79
80             log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################");
81             JSONObject jsonObjNewModelBean = new JSONObject();
82             jsonObjNewModelBean.put("newModelBean", newModelBean);
83             log.info(String.valueOf(jsonObjNewModelBean));
84             JSONObject jsonObjParam2Values = new JSONObject();
85             jsonObjParam2Values.put("param2Values", param2Values);
86             log.info(String.valueOf(jsonObjParam2Values));
87
88             //IAILMDK.run
1a2b62 89             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
4f1717 90             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
91                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
b2aca2 92                 throw new RuntimeException("模型结果异常:" + modelResult);
D 93             }
4f1717 94             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
7fd198 95             //打印结果
96             JSONObject jsonObjResult = new JSONObject();
1a2b62 97             jsonObjResult.put("result", modelResult);
7fd198 98             log.info(String.valueOf(jsonObjResult));
99
69bd5e 100             List<MmItemOutputEntity> ItemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
D 101             log.info("模型计算完成:modelId=" + modelId + modelResult);
102
4f1717 103             Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(ItemOutputList.size());
69bd5e 104
D 105             for (MmItemOutputEntity outputEntity : ItemOutputList) {
106                 String resultStr = outputEntity.getResultstr();
107                 if (modelResult.containsKey(resultStr)) {
108                     if (outputEntity.getResultType() == 1) {
109                         // 一维数组
110                         Double[] temp = (Double[]) modelResult.get(resultStr);
111                         double[] temp1 = new double[temp.length];
112                         for (int i = 0; i < temp.length; i++) {
113                             temp1[i] = temp[i].doubleValue();
114                         }
4f1717 115                         predictMatrixs.put(outputEntity, temp1);
116                     } else if (outputEntity.getResultType() == 2) {
69bd5e 117                         // 二维数组
D 118                         Double[][] temp = (Double[][]) modelResult.get(resultStr);
119                         Double[] temp2 = temp[outputEntity.getResultIndex()];
120                         double[] temp1 = new double[temp2.length];
121                         for (int i = 0; i < temp2.length; i++) {
122                             temp1[i] = temp2[i].doubleValue();
123                         }
4f1717 124                         predictMatrixs.put(outputEntity, temp1);
1a2b62 125                     }
D 126                 }
127             }
69bd5e 128             result.setPredictMatrixs(predictMatrixs);
1a2b62 129             result.setModelResult(modelResult);
7fd198 130             result.setPredictTime(predictTime);
131         } catch (Exception ex) {
4f1717 132             log.error("调用发生异常,异常信息为:{}", ex);
7fd198 133             ex.printStackTrace();
ead005 134             throw new ModelInvokeException(ex.getMessage());
7fd198 135         }
136         return result;
137     }
138
139     /**
140      * 构造IAILMDK.run()方法的newModelBean参数
141      *
142      * @param predictModel
143      * @return
144      */
145     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
146         IAILModel newModelBean = new IAILModel();
147         newModelBean.setClassName(predictModel.getClassname().trim());
148         newModelBean.setMethodName(predictModel.getMethodname().trim());
149         //构造参数类型
150         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
151         Class<?>[] paramsArray = new Class[paArStr.length];
152         for (int i = 0; i < paArStr.length; i++) {
153             if ("[[D".equals(paArStr[i])) {
154                 paramsArray[i] = double[][].class;
155             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
156                 paramsArray[i] = HashMap.class;
157             }
158         }
159         newModelBean.setParamsArray(paramsArray);
160         HashMap<String, Object> dataMap = new HashMap<>();
161         HashMap<String, String> models = new HashMap<>(1);
b2aca2 162         models.put("model_path", predictModel.getModelpath());
7fd198 163         dataMap.put("models", models);
164         newModelBean.setDataMap(dataMap);
165         return newModelBean;
166     }
167
168     /**
169      * 根据模型id获取参数map
170      *
171      * @param modelId
172      * @return
173      */
174     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
175         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
176         HashMap<String, Object> result = new HashMap<>();
177         for (MmModelArithSettingsEntity entry : list) {
178             String valueType = entry.getValuetype().trim(); //去除两端空格
179             if ("int".equals(valueType)) {
180                 int value = Integer.parseInt(entry.getValue());
181                 result.put(entry.getKey(), value);
182             } else if ("double".equals(valueType)) {
183                 double value = Double.parseDouble(entry.getValue());
184                 result.put(entry.getKey(), value);
185             } else if ("string".equals(valueType)) {
186                 String value = entry.getValue();
187                 result.put(entry.getKey(), value);
188             } else if ("decimalArray".equals(valueType)) {
189                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
190                 double[] value = new double[valueArray.size()];
191                 for (int i = 0; i < valueArray.size(); i++) {
192                     value[i] = Double.parseDouble(valueArray.get(i).toString());
193                 }
194                 result.put(entry.getKey(), value);
195             } else if ("decimal".equals(valueType)) {
196                 double value = Double.parseDouble(entry.getValue());
197                 result.put(entry.getKey(), value);
198             }
199         }
200         return result;
201     }
202 }