dengzedong
2024-11-12 69bd5efa5e2849a560e604d0aa608d5492113915
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.predict.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.IAILMDK;
6 import com.iail.model.IAILModel;
69bd5e 7 import com.iailab.module.model.mcs.pre.controller.admin.MmItemOutputController;
D 8 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
7fd198 9 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
10 import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity;
11 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
69bd5e 12 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
7fd198 13 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
14 import com.iailab.module.model.mcs.pre.service.MmModelResultstrService;
15 import com.iailab.module.model.mdk.common.enums.TypeA;
16 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
17 import com.iailab.module.model.mdk.predict.PredictModelHandler;
18 import com.iailab.module.model.mdk.sample.SampleConstructor;
19 import com.iailab.module.model.mdk.sample.dto.SampleData;
20 import com.iailab.module.model.mdk.vo.PredictResultVO;
1a2b62 21 import com.iailab.module.model.mpk.common.utils.DllUtils;
7fd198 22 import lombok.extern.slf4j.Slf4j;
23 import org.springframework.beans.factory.annotation.Autowired;
69bd5e 24 import org.springframework.security.core.parameters.P;
7fd198 25 import org.springframework.stereotype.Component;
26
69bd5e 27 import java.util.*;
7fd198 28
29 /**
30  * @author PanZhibao
31  * @Description
32  * @createTime 2024年09月01日
33  */
34 @Slf4j
35 @Component
36 public class PredictModelHandlerImpl implements PredictModelHandler {
37
38     @Autowired
39     private MmModelArithSettingsService mmModelArithSettingsService;
40
41     @Autowired
42     private MmModelResultstrService mmModelResultstrService;
69bd5e 43     @Autowired
D 44     private MmItemOutputService mmItemOutputService;
7fd198 45
46     @Autowired
47     private SampleConstructor sampleConstructor;
48
49     @Override
50     public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException {
51         PredictResultVO result = new PredictResultVO();
52         if (predictModel == null) {
53             throw new ModelInvokeException("modelEntity is null");
54         }
55         String modelId = predictModel.getId();
56
57         try {
58             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
59             String modelPath = predictModel.getModelpath();
60             if (modelPath == null) {
61                 System.out.println("模型路径不存在,modelId=" + modelId);
62                 return null;
63             }
64             IAILModel newModelBean = composeNewModelBean(predictModel);
65             HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
66             if (settings == null) {
67                 log.error("模型setting不存在,modelId=" + modelId);
68                 return null;
69             }
70             int portLength = sampleDataList.size();
71             Object[] param2Values = new Object[portLength + 2];
72             for (int i = 0; i < portLength; i++) {
73                 param2Values[i]=sampleDataList.get(i).getMatrix();
74             }
75             param2Values[portLength] = newModelBean.getDataMap().get("models");
76             param2Values[portLength+1] = settings;
77
78             log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################");
79             JSONObject jsonObjNewModelBean = new JSONObject();
80             jsonObjNewModelBean.put("newModelBean", newModelBean);
81             log.info(String.valueOf(jsonObjNewModelBean));
82             JSONObject jsonObjParam2Values = new JSONObject();
83             jsonObjParam2Values.put("param2Values", param2Values);
84             log.info(String.valueOf(jsonObjParam2Values));
85
86             //IAILMDK.run
1a2b62 87 //            HashMap<String, Object> modelResult = IAILMDK.run(newModelBean, param2Values);
D 88             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
b2aca2 89             if(!modelResult.containsKey("status_code") || !modelResult.containsKey("result") || Integer.parseInt(modelResult.get("status_code").toString()) != 100) {
D 90                 throw new RuntimeException("模型结果异常:" + modelResult);
91             }
92
93             modelResult = (HashMap<String, Object>) modelResult.get("result");
7fd198 94             //打印结果
95             JSONObject jsonObjResult = new JSONObject();
1a2b62 96             jsonObjResult.put("result", modelResult);
7fd198 97             log.info(String.valueOf(jsonObjResult));
98
69bd5e 99             List<MmItemOutputEntity> ItemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
D 100             log.info("模型计算完成:modelId=" + modelId + modelResult);
101
102             Map<MmItemOutputEntity,double[]> predictMatrixs = new HashMap<>(ItemOutputList.size());
103
104             for (MmItemOutputEntity outputEntity : ItemOutputList) {
105                 String resultStr = outputEntity.getResultstr();
106                 if (modelResult.containsKey(resultStr)) {
107                     if (outputEntity.getResultType() == 1) {
108                         // 一维数组
109                         Double[] temp = (Double[]) modelResult.get(resultStr);
110                         double[] temp1 = new double[temp.length];
111                         for (int i = 0; i < temp.length; i++) {
112                             temp1[i] = temp[i].doubleValue();
113                         }
114                         predictMatrixs.put(outputEntity,temp1);
115                     }else if (outputEntity.getResultType() == 2) {
116                         // 二维数组
117                         Double[][] temp = (Double[][]) modelResult.get(resultStr);
118                         Double[] temp2 = temp[outputEntity.getResultIndex()];
119                         double[] temp1 = new double[temp2.length];
120                         for (int i = 0; i < temp2.length; i++) {
121                             temp1[i] = temp2[i].doubleValue();
122                         }
123                         predictMatrixs.put(outputEntity,temp1);
1a2b62 124                     }
D 125                 }
126             }
69bd5e 127
D 128             result.setPredictMatrixs(predictMatrixs);
1a2b62 129             result.setModelResult(modelResult);
7fd198 130             result.setPredictTime(predictTime);
131         } catch (Exception ex) {
132             log.error("IAILModel对象构造失败,modelId=" + modelId);
133             log.error(ex.getMessage());
134             log.error("调用发生异常,异常信息为:{}" , ex);
135             ex.printStackTrace();
136
137         }
138         return result;
139     }
140
141     /**
142      * 构造IAILMDK.run()方法的newModelBean参数
143      *
144      * @param predictModel
145      * @return
146      */
147     private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
148         IAILModel newModelBean = new IAILModel();
149         newModelBean.setClassName(predictModel.getClassname().trim());
150         newModelBean.setMethodName(predictModel.getMethodname().trim());
151         //构造参数类型
152         String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
153         Class<?>[] paramsArray = new Class[paArStr.length];
154         for (int i = 0; i < paArStr.length; i++) {
155             if ("[[D".equals(paArStr[i])) {
156                 paramsArray[i] = double[][].class;
157             } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
158                 paramsArray[i] = HashMap.class;
159             }
160         }
161         newModelBean.setParamsArray(paramsArray);
162         HashMap<String, Object> dataMap = new HashMap<>();
163         HashMap<String, String> models = new HashMap<>(1);
b2aca2 164         models.put("model_path", predictModel.getModelpath());
7fd198 165         dataMap.put("models", models);
166         newModelBean.setDataMap(dataMap);
167         return newModelBean;
168     }
169
170     /**
171      * 根据模型id获取参数map
172      *
173      * @param modelId
174      * @return
175      */
176     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
177         List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
178         HashMap<String, Object> result = new HashMap<>();
179         for (MmModelArithSettingsEntity entry : list) {
180             String valueType = entry.getValuetype().trim(); //去除两端空格
181             if ("int".equals(valueType)) {
182                 int value = Integer.parseInt(entry.getValue());
183                 result.put(entry.getKey(), value);
184             } else if ("double".equals(valueType)) {
185                 double value = Double.parseDouble(entry.getValue());
186                 result.put(entry.getKey(), value);
187             } else if ("string".equals(valueType)) {
188                 String value = entry.getValue();
189                 result.put(entry.getKey(), value);
190             } else if ("decimalArray".equals(valueType)) {
191                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
192                 double[] value = new double[valueArray.size()];
193                 for (int i = 0; i < valueArray.size(); i++) {
194                     value[i] = Double.parseDouble(valueArray.get(i).toString());
195                 }
196                 result.put(entry.getKey(), value);
197             } else if ("decimal".equals(valueType)) {
198                 double value = Double.parseDouble(entry.getValue());
199                 result.put(entry.getKey(), value);
200             }
201         }
202         return result;
203     }
204 }