package com.iailab.module.model.mdk.predict.impl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.iail.model.IAILModel;
import com.iailab.module.model.enums.CommonConstant;
import com.iailab.module.model.common.enums.OutResultType;
import com.iailab.module.model.common.exception.ModelResultErrorException;
import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
import com.iailab.module.model.mdk.common.enums.TypeA;
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
import com.iailab.module.model.mdk.predict.PredictModelHandler;
import com.iailab.module.model.mdk.sample.SampleConstructor;
import com.iailab.module.model.mdk.sample.dto.SampleData;
import com.iailab.module.model.mdk.vo.PredictResultVO;
import com.iailab.module.model.mpk.common.MdkConstant;
import com.iailab.module.model.mpk.common.utils.DllUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Runs a predict model through {@code DllUtils.run} and maps the raw model output onto the
 * outputs configured for the model's item.
 *
 * @author PanZhibao
 * @createTime 2024-09-01
 */
@Slf4j
@Component
public class PredictModelHandlerImpl implements PredictModelHandler {

    @Autowired
    private MmModelArithSettingsService mmModelArithSettingsService;

    @Autowired
    private MmItemOutputService mmItemOutputService;

    @Autowired
    private SampleConstructor sampleConstructor;

    /**
     * Runs {@code predictModel} for {@code predictTime} and returns its prediction.
     *
     * <p>The method is {@code synchronized}, so calls on the same instance are serialized
     * (NOTE(review): presumably because the native model runtime is not re-entrant — confirm).
     *
     * @param predictTime  time the prediction is made for; passed to sample construction and
     *                     copied onto the result
     * @param predictModel model definition; must not be {@code null}
     * @param itemName     item name, used for sample construction and logging
     * @param itemNo       item number, used only for logging
     * @return the prediction result, or {@code null} when the model has no model path or no
     *         settings
     * @throws ModelInvokeException if {@code predictModel} is {@code null}, or if preparing or
     *                              running the model fails (the original failure is attached
     *                              as the cause)
     */
    @Override
    public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel,
                                                       String itemName, String itemNo) throws ModelInvokeException {
        if (predictModel == null) {
            throw new ModelInvokeException("modelEntity is null");
        }
        String modelId = predictModel.getId();
        PredictResultVO result = new PredictResultVO();
        try {
            // Check the cheap precondition before paying for sample construction.
            String modelPath = predictModel.getModelpath();
            if (modelPath == null) {
                log.info("模型路径不存在,modelId={}", modelId);
                return null;
            }
            List<SampleData> sampleDataList = sampleConstructor.constructSample(
                    TypeA.Predict.name(), modelId, predictTime, itemName, new HashMap<>());
            IAILModel newModelBean = composeNewModelBean(predictModel);

            HashMap<String, Object> settings = getPredictSettingsByModelId(modelId);
            // The null check must run before settings is dereferenced below.
            if (settings == null) {
                log.error("模型setting不存在,modelId={}", modelId);
                return null;
            }
            // Settings must contain the pyFile entry; without it the native call may crash the process.
            if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
                throw new IllegalStateException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!");
            }

            // Argument layout: one matrix per sample port, then the "models" map, then the settings map.
            int portLength = sampleDataList.size();
            Object[] param2Values = new Object[portLength + 2];
            for (int i = 0; i < portLength; i++) {
                param2Values[i] = sampleDataList.get(i).getMatrix();
            }
            param2Values[portLength] = newModelBean.getDataMap().get("models");
            param2Values[portLength + 1] = settings;

            log.info("####################### 预测模型 【itemId:{},itemName:{},itemNo:{}】 ##########################",
                    predictModel.getItemid(), itemName, itemNo);
            // Serializing the input matrices can be expensive; only do it when the line will be emitted.
            if (log.isInfoEnabled()) {
                log.info("参数: {}", JSON.toJSONString(param2Values));
            }

            HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid());
            if (log.isInfoEnabled()) {
                log.info("预测模型计算完成:modelId={},modelName={},modelResult={}",
                        modelId, predictModel.getMethodname(), JSON.toJSONString(modelResult));
            }

            // A valid result carries both the status code and the result payload, and the status is 100.
            // String.valueOf keeps a present-but-null status from surfacing as an NPE.
            if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE)
                    || !modelResult.containsKey(CommonConstant.MDK_RESULT)
                    || !String.valueOf(modelResult.get(CommonConstant.MDK_STATUS_CODE)).equals(CommonConstant.MDK_STATUS_100)) {
                throw new ModelResultErrorException("模型结果异常:" + modelResult);
            }
            @SuppressWarnings("unchecked") // the payload under MDK_RESULT is a string-keyed HashMap
            HashMap<String, Object> payload = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);

            List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
            Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>();
            Map<MmItemOutputEntity, Double> predictDoubleValues = new HashMap<>();
            collectOutputs(itemOutputList, payload, predictMatrixs, predictDoubleValues);

            result.setPredictMatrixs(predictMatrixs);
            result.setPredictDoubleValues(predictDoubleValues);
            result.setModelResult(payload);
            result.setPredictTime(predictTime);
        } catch (ModelResultErrorException ex) {
            log.error("模型结果异常", ex);
            throw ex;
        } catch (Exception ex) {
            log.error("模型运行异常", ex);
            ModelInvokeException wrapped = new ModelInvokeException(ex.getMessage());
            try {
                // Preserve the real failure so callers see its stack trace, not just the message.
                wrapped.initCause(ex);
            } catch (IllegalStateException ignored) {
                // The constructor already fixed a cause; the full trace has been logged above.
            }
            throw wrapped;
        }
        return result;
    }

    /**
     * Copies every configured output that is present in {@code modelResult} into the matching
     * result map. Outputs whose {@code resultstr} key is absent from the model result are skipped.
     *
     * <ul>
     *   <li>{@code D1}: the value is a {@code double[]} and is stored as-is.</li>
     *   <li>{@code D2}: the value is a {@code double[][]}; column {@code resultIndex} is extracted.</li>
     *   <li>{@code D}: the value is a {@code Double}, stored in the scalar map.</li>
     * </ul>
     */
    private void collectOutputs(List<MmItemOutputEntity> itemOutputList, HashMap<String, Object> modelResult,
                                Map<MmItemOutputEntity, double[]> predictMatrixs,
                                Map<MmItemOutputEntity, Double> predictDoubleValues) {
        for (MmItemOutputEntity output : itemOutputList) {
            String key = output.getResultstr();
            if (!modelResult.containsKey(key)) {
                continue;
            }
            Object value = modelResult.get(key);
            OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
            switch (outResultType) {
                case D1:
                    predictMatrixs.put(output, (double[]) value);
                    break;
                case D2: {
                    double[][] matrix = (double[][]) value;
                    double[] column = new double[matrix.length];
                    for (int i = 0; i < column.length; i++) {
                        column[i] = matrix[i][output.getResultIndex()];
                    }
                    predictMatrixs.put(output, column);
                    break;
                }
                case D:
                    predictDoubleValues.put(output, (Double) value);
                    break;
                default:
                    break;
            }
        }
    }

    /**
     * Builds the {@link IAILModel} descriptor that is passed to {@code DllUtils.run}.
     *
     * <p>Parameter types come from the comma-separated {@code modelparamstructure}.
     * {@code "[[D"} maps to {@code double[][]}, and {@code "Map"} or {@code "java.util.HashMap"}
     * map to {@link HashMap}. Any other token leaves a {@code null} entry, and a warning is logged.
     * The data map holds {@code "models" -> {"model_path": modelpath}}.
     *
     * @param predictModel model definition
     * @return the populated model descriptor
     */
    private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) {
        IAILModel newModelBean = new IAILModel();
        newModelBean.setClassName(predictModel.getClassname().trim());
        newModelBean.setMethodName(predictModel.getMethodname().trim());

        String[] paArStr = predictModel.getModelparamstructure().trim().split(",");
        Class<?>[] paramsArray = new Class<?>[paArStr.length];
        for (int i = 0; i < paArStr.length; i++) {
            // Trim each token so "[[D, Map" parses the same as "[[D,Map".
            String paramType = paArStr[i].trim();
            if ("[[D".equals(paramType)) {
                paramsArray[i] = double[][].class;
            } else if ("Map".equals(paramType) || "java.util.HashMap".equals(paramType)) {
                paramsArray[i] = HashMap.class;
            } else {
                log.warn("无法识别的模型参数类型【{}】,modelId={}", paramType, predictModel.getId());
            }
        }
        newModelBean.setParamsArray(paramsArray);

        HashMap<String, Object> models = new HashMap<>(1);
        models.put("model_path", predictModel.getModelpath());
        HashMap<String, Object> dataMap = new HashMap<>();
        dataMap.put("models", models);
        newModelBean.setDataMap(dataMap);
        return newModelBean;
    }

    /**
     * Loads the arithmetic settings of a model and converts each value by its declared type.
     *
     * <p>{@code int} maps to {@link Integer}. {@code double} and {@code decimal} map to
     * {@link Double}. {@code string} is kept as-is. {@code decimalArray} (a JSON array) maps to
     * {@code double[]}. Entries with any other value type are skipped. Parse failures propagate
     * to the caller.
     *
     * @param modelId model id
     * @return a new, mutable settings map (never {@code null})
     */
    private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
        List<MmModelArithSettingsEntity> list = mmModelArithSettingsService.getByModelId(modelId);
        HashMap<String, Object> result = new HashMap<>();
        for (MmModelArithSettingsEntity entry : list) {
            String valueType = entry.getValuetype().trim();
            String rawValue = entry.getValue();
            switch (valueType) {
                case "int":
                    result.put(entry.getKey(), Integer.parseInt(rawValue));
                    break;
                case "double":
                case "decimal":
                    result.put(entry.getKey(), Double.parseDouble(rawValue));
                    break;
                case "string":
                    result.put(entry.getKey(), rawValue);
                    break;
                case "decimalArray": {
                    JSONArray valueArray = JSONArray.parseArray(rawValue);
                    double[] values = new double[valueArray.size()];
                    for (int i = 0; i < values.length; i++) {
                        values[i] = Double.parseDouble(valueArray.get(i).toString());
                    }
                    result.put(entry.getKey(), values);
                    break;
                }
                default:
                    // Unknown value types are ignored.
                    break;
            }
        }
        return result;
    }
}