提交 | 用户 | 时间
|
7fd198
|
1 |
package com.iailab.module.model.mdk.predict.impl; |
潘 |
2 |
|
6957a3
|
3 |
import com.alibaba.fastjson.JSON; |
7fd198
|
4 |
import com.alibaba.fastjson.JSONArray; |
潘 |
5 |
import com.iail.model.IAILModel; |
b3674c
|
6 |
import com.iailab.module.model.mdk.vo.StAdjustDeviationDTO; |
401492
|
7 |
import com.iailab.module.model.enums.CommonConstant; |
373ab1
|
8 |
import com.iailab.module.model.common.enums.OutResultType; |
1178da
|
9 |
import com.iailab.module.model.common.exception.ModelResultErrorException; |
69bd5e
|
10 |
import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity; |
7fd198
|
11 |
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity; |
潘 |
12 |
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; |
69bd5e
|
13 |
import com.iailab.module.model.mcs.pre.service.MmItemOutputService; |
7fd198
|
14 |
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService; |
潘 |
15 |
import com.iailab.module.model.mdk.common.enums.TypeA; |
|
16 |
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; |
|
17 |
import com.iailab.module.model.mdk.predict.PredictModelHandler; |
|
18 |
import com.iailab.module.model.mdk.sample.SampleConstructor; |
|
19 |
import com.iailab.module.model.mdk.sample.dto.SampleData; |
|
20 |
import com.iailab.module.model.mdk.vo.PredictResultVO; |
45520a
|
21 |
import com.iailab.module.model.mpk.common.MdkConstant; |
1a2b62
|
22 |
import com.iailab.module.model.mpk.common.utils.DllUtils; |
7fd198
|
23 |
import lombok.extern.slf4j.Slf4j; |
潘 |
24 |
import org.springframework.beans.factory.annotation.Autowired; |
|
25 |
import org.springframework.stereotype.Component; |
|
26 |
|
45520a
|
27 |
import java.util.Date; |
D |
28 |
import java.util.HashMap; |
|
29 |
import java.util.List; |
|
30 |
import java.util.Map; |
7fd198
|
31 |
|
潘 |
32 |
/** |
|
33 |
* @author PanZhibao |
|
34 |
* @Description |
|
35 |
* @createTime 2024年09月01日 |
|
36 |
*/ |
|
37 |
@Slf4j |
|
38 |
@Component |
|
39 |
public class PredictModelHandlerImpl implements PredictModelHandler { |
|
40 |
|
|
41 |
@Autowired |
|
42 |
private MmModelArithSettingsService mmModelArithSettingsService; |
|
43 |
|
|
44 |
@Autowired |
69bd5e
|
45 |
private MmItemOutputService mmItemOutputService; |
7fd198
|
46 |
|
潘 |
47 |
@Autowired |
|
48 |
private SampleConstructor sampleConstructor; |
|
49 |
|
4f1717
|
50 |
/** |
潘 |
51 |
* 根据模型预测,返回预测结果 |
|
52 |
* |
|
53 |
* @param predictTime |
|
54 |
* @param predictModel |
|
55 |
* @return |
|
56 |
* @throws ModelInvokeException |
|
57 |
*/ |
7fd198
|
58 |
@Override |
b3674c
|
59 |
public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel, String itemName, String itemNo) throws ModelInvokeException { |
7fd198
|
60 |
PredictResultVO result = new PredictResultVO(); |
潘 |
61 |
if (predictModel == null) { |
|
62 |
throw new ModelInvokeException("modelEntity is null"); |
|
63 |
} |
|
64 |
String modelId = predictModel.getId(); |
|
65 |
try { |
81ce77
|
66 |
List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>()); |
7fd198
|
67 |
String modelPath = predictModel.getModelpath(); |
潘 |
68 |
if (modelPath == null) { |
373ab1
|
69 |
log.info("模型路径不存在,modelId=" + modelId); |
7fd198
|
70 |
return null; |
潘 |
71 |
} |
|
72 |
IAILModel newModelBean = composeNewModelBean(predictModel); |
|
73 |
HashMap<String, Object> settings = getPredictSettingsByModelId(modelId); |
45520a
|
74 |
// 校验setting必须有pyFile,否则可能导致程序崩溃 |
D |
75 |
if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { |
b3674c
|
76 |
throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); |
45520a
|
77 |
} |
D |
78 |
|
7fd198
|
79 |
if (settings == null) { |
潘 |
80 |
log.error("模型setting不存在,modelId=" + modelId); |
|
81 |
return null; |
|
82 |
} |
|
83 |
int portLength = sampleDataList.size(); |
|
84 |
Object[] param2Values = new Object[portLength + 2]; |
|
85 |
for (int i = 0; i < portLength; i++) { |
4f1717
|
86 |
param2Values[i] = sampleDataList.get(i).getMatrix(); |
7fd198
|
87 |
} |
潘 |
88 |
param2Values[portLength] = newModelBean.getDataMap().get("models"); |
4f1717
|
89 |
param2Values[portLength + 1] = settings; |
7fd198
|
90 |
|
b3674c
|
91 |
log.info("####################### 预测模型 " + "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################"); |
6957a3
|
92 |
log.info("参数: " + JSON.toJSONString(param2Values)); |
7fd198
|
93 |
|
潘 |
94 |
//IAILMDK.run |
1a2b62
|
95 |
HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid()); |
73a05d
|
96 |
//打印结果 |
D |
97 |
log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getMethodname() + ",modelResult=" + JSON.toJSONString(modelResult)); |
|
98 |
//判断模型结果 |
4f1717
|
99 |
if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || |
潘 |
100 |
!modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { |
1178da
|
101 |
throw new ModelResultErrorException("模型结果异常:" + modelResult); |
b2aca2
|
102 |
} |
4f1717
|
103 |
modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT); |
7fd198
|
104 |
|
373ab1
|
105 |
List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid()); |
a6e46f
|
106 |
Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(); |
373ab1
|
107 |
for (MmItemOutputEntity output : itemOutputList) { |
潘 |
108 |
if (!modelResult.containsKey(output.getResultstr())) { |
|
109 |
continue; |
|
110 |
} |
|
111 |
OutResultType outResultType = OutResultType.getEumByCode(output.getResultType()); |
|
112 |
switch (outResultType) { |
|
113 |
case D1: |
|
114 |
double[] temp1 = (double[]) modelResult.get(output.getResultstr()); |
|
115 |
predictMatrixs.put(output, temp1); |
|
116 |
break; |
|
117 |
case D2: |
|
118 |
double[][] temp2 = (double[][]) modelResult.get(output.getResultstr()); |
|
119 |
double[] tempColumn = new double[temp2.length]; |
|
120 |
for (int i = 0; i < tempColumn.length; i++) { |
|
121 |
tempColumn[i] = temp2[i][output.getResultIndex()]; |
69bd5e
|
122 |
} |
373ab1
|
123 |
predictMatrixs.put(output, tempColumn); |
潘 |
124 |
break; |
a6e46f
|
125 |
case D: |
D |
126 |
Double temp3 = (Double) modelResult.get(output.getResultstr()); |
51e472
|
127 |
predictMatrixs.put(output, new double[]{temp3}); |
a6e46f
|
128 |
break; |
373ab1
|
129 |
default: |
潘 |
130 |
break; |
1a2b62
|
131 |
} |
D |
132 |
} |
69bd5e
|
133 |
result.setPredictMatrixs(predictMatrixs); |
1a2b62
|
134 |
result.setModelResult(modelResult); |
7fd198
|
135 |
result.setPredictTime(predictTime); |
50084d
|
136 |
} catch (ModelResultErrorException ex) { |
efdc38
|
137 |
// ex.printStackTrace(); |
D |
138 |
log.error("模型结果异常", ex); |
50084d
|
139 |
throw ex; |
D |
140 |
} catch (Exception ex) { |
c9dd12
|
141 |
// log.error("调用发生异常,异常信息为:{0}", ex.getMessage()); |
efdc38
|
142 |
// ex.printStackTrace(); |
D |
143 |
log.error("模型运行异常", ex); |
ead005
|
144 |
throw new ModelInvokeException(ex.getMessage()); |
7fd198
|
145 |
} |
潘 |
146 |
return result; |
|
147 |
} |
|
148 |
|
|
149 |
/** |
fdcde1
|
150 |
* 预测,模拟调整 |
潘 |
151 |
* |
|
152 |
* @param predictTime |
|
153 |
* @param predictModel |
|
154 |
* @param itemName |
|
155 |
* @param itemNo |
b3674c
|
156 |
* @param deviationList |
fdcde1
|
157 |
* @return |
潘 |
158 |
* @throws ModelInvokeException |
|
159 |
*/ |
|
160 |
@Override |
b3674c
|
161 |
public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel, String itemName, String itemNo, List<StAdjustDeviationDTO> deviationList) throws ModelInvokeException { |
fdcde1
|
162 |
PredictResultVO result = new PredictResultVO(); |
潘 |
163 |
if (predictModel == null) { |
|
164 |
throw new ModelInvokeException("modelEntity is null"); |
|
165 |
} |
|
166 |
String modelId = predictModel.getId(); |
|
167 |
try { |
b3674c
|
168 |
List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>(), deviationList); |
fdcde1
|
169 |
String modelPath = predictModel.getModelpath(); |
潘 |
170 |
if (modelPath == null) { |
|
171 |
log.info("模型路径不存在,modelId=" + modelId); |
|
172 |
return null; |
|
173 |
} |
|
174 |
IAILModel newModelBean = composeNewModelBean(predictModel); |
|
175 |
HashMap<String, Object> settings = getPredictSettingsByModelId(modelId); |
|
176 |
// 校验setting必须有pyFile,否则可能导致程序崩溃 |
|
177 |
if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { |
b3674c
|
178 |
throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); |
fdcde1
|
179 |
} |
潘 |
180 |
|
|
181 |
if (settings == null) { |
|
182 |
log.error("模型setting不存在,modelId=" + modelId); |
|
183 |
return null; |
|
184 |
} |
|
185 |
int portLength = sampleDataList.size(); |
|
186 |
Object[] param2Values = new Object[portLength + 2]; |
|
187 |
for (int i = 0; i < portLength; i++) { |
|
188 |
param2Values[i] = sampleDataList.get(i).getMatrix(); |
|
189 |
} |
|
190 |
param2Values[portLength] = newModelBean.getDataMap().get("models"); |
|
191 |
param2Values[portLength + 1] = settings; |
|
192 |
|
b3674c
|
193 |
log.info("####################### 模拟调整 " + "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################"); |
fdcde1
|
194 |
log.info("参数: " + JSON.toJSONString(param2Values)); |
潘 |
195 |
|
|
196 |
//IAILMDK.run |
|
197 |
HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid()); |
|
198 |
//打印结果 |
|
199 |
log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getMethodname() + ",modelResult=" + JSON.toJSONString(modelResult)); |
|
200 |
//判断模型结果 |
|
201 |
if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || |
|
202 |
!modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { |
|
203 |
throw new ModelResultErrorException("模型结果异常:" + modelResult); |
|
204 |
} |
|
205 |
modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT); |
|
206 |
|
|
207 |
List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid()); |
|
208 |
Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(); |
|
209 |
for (MmItemOutputEntity output : itemOutputList) { |
|
210 |
if (!modelResult.containsKey(output.getResultstr())) { |
|
211 |
continue; |
|
212 |
} |
|
213 |
OutResultType outResultType = OutResultType.getEumByCode(output.getResultType()); |
|
214 |
switch (outResultType) { |
|
215 |
case D1: |
|
216 |
double[] temp1 = (double[]) modelResult.get(output.getResultstr()); |
|
217 |
predictMatrixs.put(output, temp1); |
|
218 |
break; |
|
219 |
case D2: |
|
220 |
double[][] temp2 = (double[][]) modelResult.get(output.getResultstr()); |
|
221 |
double[] tempColumn = new double[temp2.length]; |
|
222 |
for (int i = 0; i < tempColumn.length; i++) { |
|
223 |
tempColumn[i] = temp2[i][output.getResultIndex()]; |
|
224 |
} |
|
225 |
predictMatrixs.put(output, tempColumn); |
|
226 |
break; |
|
227 |
case D: |
|
228 |
Double temp3 = (Double) modelResult.get(output.getResultstr()); |
|
229 |
predictMatrixs.put(output, new double[]{temp3}); |
|
230 |
break; |
|
231 |
default: |
|
232 |
break; |
|
233 |
} |
|
234 |
} |
|
235 |
result.setPredictMatrixs(predictMatrixs); |
|
236 |
result.setModelResult(modelResult); |
|
237 |
result.setPredictTime(predictTime); |
|
238 |
} catch (ModelResultErrorException ex) { |
|
239 |
log.error("模型结果异常", ex); |
|
240 |
throw ex; |
|
241 |
} catch (Exception ex) { |
|
242 |
log.error("调用发生异常,异常信息为:{0}", ex.getMessage()); |
|
243 |
throw new ModelInvokeException(ex.getMessage()); |
|
244 |
} |
|
245 |
return result; |
|
246 |
} |
b3674c
|
247 |
|
fdcde1
|
248 |
/** |
7fd198
|
249 |
* 构造IAILMDK.run()方法的newModelBean参数 |
潘 |
250 |
* |
|
251 |
* @param predictModel |
|
252 |
* @return |
|
253 |
*/ |
|
254 |
private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) { |
|
255 |
IAILModel newModelBean = new IAILModel(); |
|
256 |
newModelBean.setClassName(predictModel.getClassname().trim()); |
|
257 |
newModelBean.setMethodName(predictModel.getMethodname().trim()); |
|
258 |
//构造参数类型 |
|
259 |
String[] paArStr = predictModel.getModelparamstructure().trim().split(","); |
|
260 |
Class<?>[] paramsArray = new Class[paArStr.length]; |
|
261 |
for (int i = 0; i < paArStr.length; i++) { |
|
262 |
if ("[[D".equals(paArStr[i])) { |
|
263 |
paramsArray[i] = double[][].class; |
|
264 |
} else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) { |
|
265 |
paramsArray[i] = HashMap.class; |
|
266 |
} |
|
267 |
} |
|
268 |
newModelBean.setParamsArray(paramsArray); |
|
269 |
HashMap<String, Object> dataMap = new HashMap<>(); |
|
270 |
HashMap<String, String> models = new HashMap<>(1); |
b2aca2
|
271 |
models.put("model_path", predictModel.getModelpath()); |
7fd198
|
272 |
dataMap.put("models", models); |
潘 |
273 |
newModelBean.setDataMap(dataMap); |
|
274 |
return newModelBean; |
|
275 |
} |
|
276 |
|
|
277 |
/** |
|
278 |
* 根据模型id获取参数map |
|
279 |
* |
|
280 |
* @param modelId |
|
281 |
* @return |
|
282 |
*/ |
|
283 |
private HashMap<String, Object> getPredictSettingsByModelId(String modelId) { |
|
284 |
List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId); |
|
285 |
HashMap<String, Object> result = new HashMap<>(); |
|
286 |
for (MmModelArithSettingsEntity entry : list) { |
|
287 |
String valueType = entry.getValuetype().trim(); //去除两端空格 |
|
288 |
if ("int".equals(valueType)) { |
|
289 |
int value = Integer.parseInt(entry.getValue()); |
|
290 |
result.put(entry.getKey(), value); |
|
291 |
} else if ("double".equals(valueType)) { |
|
292 |
double value = Double.parseDouble(entry.getValue()); |
|
293 |
result.put(entry.getKey(), value); |
|
294 |
} else if ("string".equals(valueType)) { |
|
295 |
String value = entry.getValue(); |
|
296 |
result.put(entry.getKey(), value); |
|
297 |
} else if ("decimalArray".equals(valueType)) { |
|
298 |
JSONArray valueArray = JSONArray.parseArray(entry.getValue()); |
|
299 |
double[] value = new double[valueArray.size()]; |
|
300 |
for (int i = 0; i < valueArray.size(); i++) { |
|
301 |
value[i] = Double.parseDouble(valueArray.get(i).toString()); |
|
302 |
} |
|
303 |
result.put(entry.getKey(), value); |
|
304 |
} else if ("decimal".equals(valueType)) { |
|
305 |
double value = Double.parseDouble(entry.getValue()); |
|
306 |
result.put(entry.getKey(), value); |
|
307 |
} |
|
308 |
} |
|
309 |
return result; |
|
310 |
} |
|
311 |
} |