package com.iailab.module.model.mdk.predict.impl;
|
|
import com.alibaba.fastjson.JSONArray;
|
import com.alibaba.fastjson.JSONObject;
|
import com.iail.IAILMDK;
|
import com.iail.model.IAILModel;
|
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
|
import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity;
|
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
|
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
|
import com.iailab.module.model.mcs.pre.service.MmModelResultstrService;
|
import com.iailab.module.model.mdk.common.enums.TypeA;
|
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
|
import com.iailab.module.model.mdk.predict.PredictModelHandler;
|
import com.iailab.module.model.mdk.sample.SampleConstructor;
|
import com.iailab.module.model.mdk.sample.dto.SampleData;
|
import com.iailab.module.model.mdk.vo.PredictResultVO;
|
import com.iailab.module.model.mpk.common.utils.DllUtils;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.stereotype.Component;
|
|
import java.util.ArrayList;
|
import java.util.Date;
|
import java.util.HashMap;
|
import java.util.List;
|
|
/**
|
* @author PanZhibao
|
* @Description
|
* @createTime 2024年09月01日
|
*/
|
@Slf4j
|
@Component
|
public class PredictModelHandlerImpl implements PredictModelHandler {
|
|
@Autowired
|
private MmModelArithSettingsService mmModelArithSettingsService;
|
|
@Autowired
|
private MmModelResultstrService mmModelResultstrService;
|
|
@Autowired
|
private SampleConstructor sampleConstructor;
|
|
@Override
|
public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException {
|
PredictResultVO result = new PredictResultVO();
|
if (predictModel == null) {
|
throw new ModelInvokeException("modelEntity is null");
|
}
|
String modelId = predictModel.getId();
|
|
try {
|
List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
|
String modelPath = predictModel.getModelpath();
|
if (modelPath == null) {
|
System.out.println("模型路径不存在,modelId=" + modelId);
|
return null;
|
}
|
IAILModel newModelBean = composeNewModelBean(predictModel);
|
HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
|
if (settings == null) {
|
log.error("模型setting不存在,modelId=" + modelId);
|
return null;
|
}
|
int portLength = sampleDataList.size();
|
Object[] param2Values = new Object[portLength + 2];
|
for (int i = 0; i < portLength; i++) {
|
param2Values[i]=sampleDataList.get(i).getMatrix();
|
}
|
param2Values[portLength] = newModelBean.getDataMap().get("models");
|
param2Values[portLength+1] = settings;
|
|
log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################");
|
JSONObject jsonObjNewModelBean = new JSONObject();
|
jsonObjNewModelBean.put("newModelBean", newModelBean);
|
log.info(String.valueOf(jsonObjNewModelBean));
|
JSONObject jsonObjParam2Values = new JSONObject();
|
jsonObjParam2Values.put("param2Values", param2Values);
|
log.info(String.valueOf(jsonObjParam2Values));
|
|
//IAILMDK.run
|
// HashMap<String, Object> modelResult = IAILMDK.run(newModelBean, param2Values);
|
HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
|
if(!modelResult.containsKey("status_code") || !modelResult.containsKey("result") || Integer.parseInt(modelResult.get("status_code").toString()) != 100) {
|
throw new RuntimeException("模型结果异常:" + modelResult);
|
}
|
|
modelResult = (HashMap<String, Object>) modelResult.get("result");
|
//打印结果
|
JSONObject jsonObjResult = new JSONObject();
|
jsonObjResult.put("result", modelResult);
|
log.info(String.valueOf(jsonObjResult));
|
|
MmModelResultstrEntity modelResultstr = mmModelResultstrService.getInfo(predictModel.getResultstrid());
|
log.info("模型计算完成:modelId=" + modelId + result);
|
if (modelResult.containsKey(modelResultstr.getResultstr())) {
|
Double[][] temp = (Double[][]) modelResult.get(modelResultstr.getResultstr());
|
double[][] temp1 = new double[temp.length][temp[0].length];
|
for (int i = 0; i < temp.length; i++) {
|
for (int j = 0; j < temp[i].length; j++) {
|
temp1[i][j] = temp[i][j].doubleValue();
|
}
|
}
|
result.setPredictMatrix(temp1);
|
}
|
result.setModelResult(modelResult);
|
result.setPredictTime(predictTime);
|
} catch (Exception ex) {
|
log.error("IAILModel对象构造失败,modelId=" + modelId);
|
log.error(ex.getMessage());
|
log.error("调用发生异常,异常信息为:{}" , ex);
|
ex.printStackTrace();
|
|
}
|
return result;
|
}
|
|
/**
|
* 构造IAILMDK.run()方法的newModelBean参数
|
*
|
* @param predictModel
|
* @return
|
*/
|
private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
|
IAILModel newModelBean = new IAILModel();
|
newModelBean.setClassName(predictModel.getClassname().trim());
|
newModelBean.setMethodName(predictModel.getMethodname().trim());
|
//构造参数类型
|
String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
|
Class<?>[] paramsArray = new Class[paArStr.length];
|
for (int i = 0; i < paArStr.length; i++) {
|
if ("[[D".equals(paArStr[i])) {
|
paramsArray[i] = double[][].class;
|
} else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
|
paramsArray[i] = HashMap.class;
|
}
|
}
|
newModelBean.setParamsArray(paramsArray);
|
HashMap<String, Object> dataMap = new HashMap<>();
|
HashMap<String, String> models = new HashMap<>(1);
|
models.put("model_path", predictModel.getModelpath());
|
dataMap.put("models", models);
|
newModelBean.setDataMap(dataMap);
|
return newModelBean;
|
}
|
|
/**
|
* 根据模型id获取参数map
|
*
|
* @param modelId
|
* @return
|
*/
|
private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
|
List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
|
HashMap<String, Object> result = new HashMap<>();
|
for (MmModelArithSettingsEntity entry : list) {
|
String valueType = entry.getValuetype().trim(); //去除两端空格
|
if ("int".equals(valueType)) {
|
int value = Integer.parseInt(entry.getValue());
|
result.put(entry.getKey(), value);
|
} else if ("double".equals(valueType)) {
|
double value = Double.parseDouble(entry.getValue());
|
result.put(entry.getKey(), value);
|
} else if ("string".equals(valueType)) {
|
String value = entry.getValue();
|
result.put(entry.getKey(), value);
|
} else if ("decimalArray".equals(valueType)) {
|
JSONArray valueArray = JSONArray.parseArray(entry.getValue());
|
double[] value = new double[valueArray.size()];
|
for (int i = 0; i < valueArray.size(); i++) {
|
value[i] = Double.parseDouble(valueArray.get(i).toString());
|
}
|
result.put(entry.getKey(), value);
|
} else if ("decimal".equals(valueType)) {
|
double value = Double.parseDouble(entry.getValue());
|
result.put(entry.getKey(), value);
|
}
|
}
|
return result;
|
}
|
}
|