提交 | 用户 | 时间
|
7fd198
|
1 |
package com.iailab.module.model.mdk.predict.impl; |
潘 |
2 |
|
|
3 |
import com.alibaba.fastjson.JSONArray; |
|
4 |
import com.alibaba.fastjson.JSONObject; |
|
5 |
import com.iail.IAILMDK; |
|
6 |
import com.iail.model.IAILModel; |
|
7 |
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity; |
|
8 |
import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity; |
|
9 |
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; |
|
10 |
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService; |
|
11 |
import com.iailab.module.model.mcs.pre.service.MmModelResultstrService; |
|
12 |
import com.iailab.module.model.mdk.common.enums.TypeA; |
|
13 |
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; |
|
14 |
import com.iailab.module.model.mdk.predict.PredictModelHandler; |
|
15 |
import com.iailab.module.model.mdk.sample.SampleConstructor; |
|
16 |
import com.iailab.module.model.mdk.sample.dto.SampleData; |
|
17 |
import com.iailab.module.model.mdk.vo.PredictResultVO; |
|
18 |
import lombok.extern.slf4j.Slf4j; |
|
19 |
import org.springframework.beans.factory.annotation.Autowired; |
|
20 |
import org.springframework.stereotype.Component; |
|
21 |
|
|
22 |
import java.util.ArrayList; |
|
23 |
import java.util.Date; |
|
24 |
import java.util.HashMap; |
|
25 |
import java.util.List; |
|
26 |
|
|
27 |
/** |
|
28 |
* @author PanZhibao |
|
29 |
* @Description |
|
30 |
* @createTime 2024年09月01日 |
|
31 |
*/ |
|
32 |
@Slf4j |
|
33 |
@Component |
|
34 |
public class PredictModelHandlerImpl implements PredictModelHandler { |
|
35 |
|
|
36 |
@Autowired |
|
37 |
private MmModelArithSettingsService mmModelArithSettingsService; |
|
38 |
|
|
39 |
@Autowired |
|
40 |
private MmModelResultstrService mmModelResultstrService; |
|
41 |
|
|
42 |
@Autowired |
|
43 |
private SampleConstructor sampleConstructor; |
|
44 |
|
|
45 |
@Override |
|
46 |
public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException { |
|
47 |
PredictResultVO result = new PredictResultVO(); |
|
48 |
if (predictModel == null) { |
|
49 |
throw new ModelInvokeException("modelEntity is null"); |
|
50 |
} |
|
51 |
String modelId = predictModel.getId(); |
|
52 |
|
|
53 |
try { |
|
54 |
List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime); |
|
55 |
String modelPath = predictModel.getModelpath(); |
|
56 |
if (modelPath == null) { |
|
57 |
System.out.println("模型路径不存在,modelId=" + modelId); |
|
58 |
return null; |
|
59 |
} |
|
60 |
IAILModel newModelBean = composeNewModelBean(predictModel); |
|
61 |
HashMap<String, Object> settings = getPredictSettingsByModelId(modelId); |
|
62 |
if (settings == null) { |
|
63 |
log.error("模型setting不存在,modelId=" + modelId); |
|
64 |
return null; |
|
65 |
} |
|
66 |
int portLength = sampleDataList.size(); |
|
67 |
Object[] param2Values = new Object[portLength + 2]; |
|
68 |
for (int i = 0; i < portLength; i++) { |
|
69 |
param2Values[i]=sampleDataList.get(i).getMatrix(); |
|
70 |
} |
|
71 |
param2Values[portLength] = newModelBean.getDataMap().get("models"); |
|
72 |
param2Values[portLength+1] = settings; |
|
73 |
|
|
74 |
log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################"); |
|
75 |
JSONObject jsonObjNewModelBean = new JSONObject(); |
|
76 |
jsonObjNewModelBean.put("newModelBean", newModelBean); |
|
77 |
log.info(String.valueOf(jsonObjNewModelBean)); |
|
78 |
JSONObject jsonObjParam2Values = new JSONObject(); |
|
79 |
jsonObjParam2Values.put("param2Values", param2Values); |
|
80 |
log.info(String.valueOf(jsonObjParam2Values)); |
|
81 |
|
|
82 |
//IAILMDK.run |
|
83 |
HashMap<String, Object> modelResult = IAILMDK.run(newModelBean, param2Values); |
|
84 |
//打印结果 |
|
85 |
JSONObject jsonObjResult = new JSONObject(); |
|
86 |
jsonObjResult.put("result", result); |
|
87 |
log.info(String.valueOf(jsonObjResult)); |
|
88 |
|
|
89 |
MmModelResultstrEntity modelResultstr = mmModelResultstrService.getInfo(predictModel.getResultstrid()); |
|
90 |
log.info("模型计算完成:modelId=" + modelId + result); |
|
91 |
double[][] temp = (double[][]) modelResult.get(modelResultstr.getResultstr()); |
|
92 |
result.setPredictMatrix(temp); |
|
93 |
result.setPredictTime(predictTime); |
|
94 |
} catch (Exception ex) { |
|
95 |
log.error("IAILModel对象构造失败,modelId=" + modelId); |
|
96 |
log.error(ex.getMessage()); |
|
97 |
log.error("调用发生异常,异常信息为:{}" , ex); |
|
98 |
ex.printStackTrace(); |
|
99 |
|
|
100 |
} |
|
101 |
return result; |
|
102 |
} |
|
103 |
|
|
104 |
/** |
|
105 |
* 构造IAILMDK.run()方法的newModelBean参数 |
|
106 |
* |
|
107 |
* @param predictModel |
|
108 |
* @return |
|
109 |
*/ |
|
110 |
private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) { |
|
111 |
IAILModel newModelBean = new IAILModel(); |
|
112 |
newModelBean.setClassName(predictModel.getClassname().trim()); |
|
113 |
newModelBean.setMethodName(predictModel.getMethodname().trim()); |
|
114 |
//构造参数类型 |
|
115 |
String[] paArStr = predictModel.getModelparamstructure().trim().split(","); |
|
116 |
Class<?>[] paramsArray = new Class[paArStr.length]; |
|
117 |
for (int i = 0; i < paArStr.length; i++) { |
|
118 |
if ("[[D".equals(paArStr[i])) { |
|
119 |
paramsArray[i] = double[][].class; |
|
120 |
} else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) { |
|
121 |
paramsArray[i] = HashMap.class; |
|
122 |
} |
|
123 |
} |
|
124 |
newModelBean.setParamsArray(paramsArray); |
|
125 |
HashMap<String, Object> dataMap = new HashMap<>(); |
|
126 |
HashMap<String, String> models = new HashMap<>(1); |
|
127 |
models.put("paramFile", predictModel.getModelpath()); |
|
128 |
dataMap.put("models", models); |
|
129 |
newModelBean.setDataMap(dataMap); |
|
130 |
return newModelBean; |
|
131 |
} |
|
132 |
|
|
133 |
/** |
|
134 |
* 根据模型id获取参数map |
|
135 |
* |
|
136 |
* @param modelId |
|
137 |
* @return |
|
138 |
*/ |
|
139 |
private HashMap<String, Object> getPredictSettingsByModelId(String modelId) { |
|
140 |
List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId); |
|
141 |
HashMap<String, Object> result = new HashMap<>(); |
|
142 |
for (MmModelArithSettingsEntity entry : list) { |
|
143 |
String valueType = entry.getValuetype().trim(); //去除两端空格 |
|
144 |
if ("int".equals(valueType)) { |
|
145 |
int value = Integer.parseInt(entry.getValue()); |
|
146 |
result.put(entry.getKey(), value); |
|
147 |
} else if ("double".equals(valueType)) { |
|
148 |
double value = Double.parseDouble(entry.getValue()); |
|
149 |
result.put(entry.getKey(), value); |
|
150 |
} else if ("string".equals(valueType)) { |
|
151 |
String value = entry.getValue(); |
|
152 |
result.put(entry.getKey(), value); |
|
153 |
} else if ("decimalArray".equals(valueType)) { |
|
154 |
JSONArray valueArray = JSONArray.parseArray(entry.getValue()); |
|
155 |
double[] value = new double[valueArray.size()]; |
|
156 |
for (int i = 0; i < valueArray.size(); i++) { |
|
157 |
value[i] = Double.parseDouble(valueArray.get(i).toString()); |
|
158 |
} |
|
159 |
result.put(entry.getKey(), value); |
|
160 |
} else if ("decimal".equals(valueType)) { |
|
161 |
double value = Double.parseDouble(entry.getValue()); |
|
162 |
result.put(entry.getKey(), value); |
|
163 |
} |
|
164 |
} |
|
165 |
return result; |
|
166 |
} |
|
167 |
} |