package com.iailab.module.model.mdk.predict.impl;
|
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONArray;
|
import com.alibaba.fastjson.JSONObject;
|
import com.iail.model.IAILModel;
|
import com.iailab.module.model.common.enums.CommonConstant;
|
import com.iailab.module.model.common.enums.OutResultType;
|
import com.iailab.module.model.common.exception.ModelResultErrorException;
|
import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
|
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
|
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
|
import com.iailab.module.model.mcs.pre.enums.ItemRunStatusEnum;
|
import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
|
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
|
import com.iailab.module.model.mdk.common.enums.TypeA;
|
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
|
import com.iailab.module.model.mdk.predict.PredictModelHandler;
|
import com.iailab.module.model.mdk.sample.SampleConstructor;
|
import com.iailab.module.model.mdk.sample.dto.SampleData;
|
import com.iailab.module.model.mdk.vo.PredictResultVO;
|
import com.iailab.module.model.mpk.common.MdkConstant;
|
import com.iailab.module.model.mpk.common.utils.DllUtils;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.stereotype.Component;
|
|
import java.util.Date;
|
import java.util.HashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
/**
|
* @author PanZhibao
|
* @Description
|
* @createTime 2024年09月01日
|
*/
|
@Slf4j
|
@Component
|
public class PredictModelHandlerImpl implements PredictModelHandler {
|
|
@Autowired
|
private MmModelArithSettingsService mmModelArithSettingsService;
|
|
@Autowired
|
private MmItemOutputService mmItemOutputService;
|
|
@Autowired
|
private SampleConstructor sampleConstructor;
|
|
/**
|
* 根据模型预测,返回预测结果
|
*
|
* @param predictTime
|
* @param predictModel
|
* @return
|
* @throws ModelInvokeException
|
*/
|
@Override
|
public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel,String itemName,String itemNo) throws ModelInvokeException {
|
PredictResultVO result = new PredictResultVO();
|
if (predictModel == null) {
|
throw new ModelInvokeException("modelEntity is null");
|
}
|
String modelId = predictModel.getId();
|
try {
|
List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>());
|
String modelPath = predictModel.getModelpath();
|
if (modelPath == null) {
|
log.info("模型路径不存在,modelId=" + modelId);
|
return null;
|
}
|
IAILModel newModelBean = composeNewModelBean(predictModel);
|
HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
|
// 校验setting必须有pyFile,否则可能导致程序崩溃
|
if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
|
throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!");
|
}
|
|
if (settings == null) {
|
log.error("模型setting不存在,modelId=" + modelId);
|
return null;
|
}
|
int portLength = sampleDataList.size();
|
Object[] param2Values = new Object[portLength + 2];
|
for (int i = 0; i < portLength; i++) {
|
param2Values[i] = sampleDataList.get(i).getMatrix();
|
}
|
param2Values[portLength] = newModelBean.getDataMap().get("models");
|
param2Values[portLength + 1] = settings;
|
|
log.info("####################### 预测模型 "+ "【itemId:" + predictModel.getItemid() + ",itemName:" + itemName + ",itemNo:" + itemNo + "】 ##########################");
|
// JSONObject jsonObjNewModelBean = new JSONObject();
|
// jsonObjNewModelBean.put("newModelBean", newModelBean);
|
// log.info(String.valueOf(jsonObjNewModelBean));
|
// JSONObject jsonObjParam2Values = new JSONObject();
|
// jsonObjParam2Values.put("param2Values", param2Values);
|
log.info("参数: " + JSON.toJSONString(param2Values));
|
|
//IAILMDK.run
|
HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
|
//打印结果
|
log.info("预测模型计算完成:modelId=" + modelId + ",modelName=" + predictModel.getMethodname() + ",modelResult=" + JSON.toJSONString(modelResult));
|
//判断模型结果
|
if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
|
!modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
|
throw new ModelResultErrorException("模型结果异常:" + modelResult);
|
}
|
modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
|
|
List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
|
Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
|
Map<MmItemOutputEntity, Double> predictDoubleValues = new HashMap<>();
|
for (MmItemOutputEntity output : itemOutputList) {
|
if (!modelResult.containsKey(output.getResultstr())) {
|
continue;
|
}
|
OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
|
switch (outResultType) {
|
case D1:
|
double[] temp1 = (double[]) modelResult.get(output.getResultstr());
|
predictMatrixs.put(output, temp1);
|
break;
|
case D2:
|
double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
|
double[] tempColumn = new double[temp2.length];
|
for (int i = 0; i < tempColumn.length; i++) {
|
tempColumn[i] = temp2[i][output.getResultIndex()];
|
}
|
predictMatrixs.put(output, tempColumn);
|
break;
|
case D:
|
Double temp3 = (Double) modelResult.get(output.getResultstr());
|
predictDoubleValues.put(output, temp3);
|
break;
|
default:
|
break;
|
}
|
}
|
result.setPredictMatrixs(predictMatrixs);
|
result.setPredictDoubleValues(predictDoubleValues);
|
result.setModelResult(modelResult);
|
result.setPredictTime(predictTime);
|
} catch (ModelResultErrorException ex) {
|
// ex.printStackTrace();
|
log.error("模型结果异常", ex);
|
throw ex;
|
} catch (Exception ex) {
|
// log.error("调用发生异常,异常信息为:{0}", ex.getMessage());
|
// ex.printStackTrace();
|
log.error("模型运行异常", ex);
|
throw new ModelInvokeException(ex.getMessage());
|
}
|
return result;
|
}
|
|
/**
|
* 构造IAILMDK.run()方法的newModelBean参数
|
*
|
* @param predictModel
|
* @return
|
*/
|
private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
|
IAILModel newModelBean = new IAILModel();
|
newModelBean.setClassName(predictModel.getClassname().trim());
|
newModelBean.setMethodName(predictModel.getMethodname().trim());
|
//构造参数类型
|
String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
|
Class<?>[] paramsArray = new Class[paArStr.length];
|
for (int i = 0; i < paArStr.length; i++) {
|
if ("[[D".equals(paArStr[i])) {
|
paramsArray[i] = double[][].class;
|
} else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) {
|
paramsArray[i] = HashMap.class;
|
}
|
}
|
newModelBean.setParamsArray(paramsArray);
|
HashMap<String, Object> dataMap = new HashMap<>();
|
HashMap<String, String> models = new HashMap<>(1);
|
models.put("model_path", predictModel.getModelpath());
|
dataMap.put("models", models);
|
newModelBean.setDataMap(dataMap);
|
return newModelBean;
|
}
|
|
/**
|
* 根据模型id获取参数map
|
*
|
* @param modelId
|
* @return
|
*/
|
private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
|
List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
|
HashMap<String, Object> result = new HashMap<>();
|
for (MmModelArithSettingsEntity entry : list) {
|
String valueType = entry.getValuetype().trim(); //去除两端空格
|
if ("int".equals(valueType)) {
|
int value = Integer.parseInt(entry.getValue());
|
result.put(entry.getKey(), value);
|
} else if ("double".equals(valueType)) {
|
double value = Double.parseDouble(entry.getValue());
|
result.put(entry.getKey(), value);
|
} else if ("string".equals(valueType)) {
|
String value = entry.getValue();
|
result.put(entry.getKey(), value);
|
} else if ("decimalArray".equals(valueType)) {
|
JSONArray valueArray = JSONArray.parseArray(entry.getValue());
|
double[] value = new double[valueArray.size()];
|
for (int i = 0; i < valueArray.size(); i++) {
|
value[i] = Double.parseDouble(valueArray.get(i).toString());
|
}
|
result.put(entry.getKey(), value);
|
} else if ("decimal".equals(valueType)) {
|
double value = Double.parseDouble(entry.getValue());
|
result.put(entry.getKey(), value);
|
}
|
}
|
return result;
|
}
|
}
|