package com.iailab.module.model.mdk.predict.impl; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.iail.model.IAILModel; import com.iailab.module.model.common.enums.CommonConstant; import com.iailab.module.model.common.enums.OutResultType; import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity; import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity; import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; import com.iailab.module.model.mcs.pre.service.MmItemOutputService; import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService; import com.iailab.module.model.mdk.common.enums.TypeA; import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; import com.iailab.module.model.mdk.predict.PredictModelHandler; import com.iailab.module.model.mdk.sample.SampleConstructor; import com.iailab.module.model.mdk.sample.dto.SampleData; import com.iailab.module.model.mdk.vo.PredictResultVO; import com.iailab.module.model.mpk.common.utils.DllUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.util.*; /** * @author PanZhibao * @Description * @createTime 2024年09月01日 */ @Slf4j @Component public class PredictModelHandlerImpl implements PredictModelHandler { @Autowired private MmModelArithSettingsService mmModelArithSettingsService; @Autowired private MmItemOutputService mmItemOutputService; @Autowired private SampleConstructor sampleConstructor; /** * 根据模型预测,返回预测结果 * * @param predictTime * @param predictModel * @return * @throws ModelInvokeException */ @Override public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException { PredictResultVO result = new PredictResultVO(); if (predictModel == null) { throw new ModelInvokeException("modelEntity is null"); } String modelId = predictModel.getId(); try { List sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime); String modelPath = predictModel.getModelpath(); if (modelPath == null) { log.info("模型路径不存在,modelId=" + modelId); return null; } IAILModel newModelBean = composeNewModelBean(predictModel); HashMap settings = getPredictSettingsByModelId(modelId); if (settings == null) { log.error("模型setting不存在,modelId=" + modelId); return null; } int portLength = sampleDataList.size(); Object[] param2Values = new Object[portLength + 2]; for (int i = 0; i < portLength; i++) { param2Values[i] = sampleDataList.get(i).getMatrix(); } param2Values[portLength] = newModelBean.getDataMap().get("models"); param2Values[portLength + 1] = settings; log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################"); JSONObject jsonObjNewModelBean = new JSONObject(); jsonObjNewModelBean.put("newModelBean", newModelBean); log.info(String.valueOf(jsonObjNewModelBean)); JSONObject jsonObjParam2Values = new JSONObject(); jsonObjParam2Values.put("param2Values", param2Values); log.info(String.valueOf(jsonObjParam2Values)); //IAILMDK.run HashMap modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid()); if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { throw new RuntimeException("模型结果异常:" + modelResult); } modelResult = (HashMap) modelResult.get(CommonConstant.MDK_RESULT); //打印结果 log.info("模型计算完成:modelId=" + modelId + modelResult); JSONObject jsonObjResult = new JSONObject(); jsonObjResult.put("result", modelResult); log.info(String.valueOf(jsonObjResult)); List itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid()); Map predictMatrixs = new HashMap<>(itemOutputList.size()); for (MmItemOutputEntity output : itemOutputList) { if (!modelResult.containsKey(output.getResultstr())) { continue; } OutResultType outResultType = OutResultType.getEumByCode(output.getResultType()); switch (outResultType) { case D1: double[] temp1 = (double[]) modelResult.get(output.getResultstr()); predictMatrixs.put(output, temp1); break; case D2: double[][] temp2 = (double[][]) modelResult.get(output.getResultstr()); double[] tempColumn = new double[temp2.length]; for (int i = 0; i < tempColumn.length; i++) { tempColumn[i] = temp2[i][output.getResultIndex()]; } predictMatrixs.put(output, tempColumn); break; default: break; } } result.setPredictMatrixs(predictMatrixs); result.setModelResult(modelResult); result.setPredictTime(predictTime); } catch (Exception ex) { log.error("调用发生异常,异常信息为:{}", ex); ex.printStackTrace(); throw new ModelInvokeException(ex.getMessage()); } return result; } /** * 构造IAILMDK.run()方法的newModelBean参数 * * @param predictModel * @return */ private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) { IAILModel newModelBean = new IAILModel(); newModelBean.setClassName(predictModel.getClassname().trim()); newModelBean.setMethodName(predictModel.getMethodname().trim()); //构造参数类型 String[] paArStr = predictModel.getModelparamstructure().trim().split(","); Class[] paramsArray = new Class[paArStr.length]; for (int i = 0; i < paArStr.length; i++) { if ("[[D".equals(paArStr[i])) { paramsArray[i] = double[][].class; } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) { paramsArray[i] = HashMap.class; } } newModelBean.setParamsArray(paramsArray); HashMap dataMap = new HashMap<>(); HashMap models = new HashMap<>(1); models.put("model_path", predictModel.getModelpath()); dataMap.put("models", models); newModelBean.setDataMap(dataMap); return newModelBean; } /** * 根据模型id获取参数map * * @param modelId * @return */ private HashMap getPredictSettingsByModelId(String modelId) { List list = mmModelArithSettingsService.getByModelId(modelId); HashMap result = new HashMap<>(); for (MmModelArithSettingsEntity entry : list) { String valueType = entry.getValuetype().trim(); //去除两端空格 if ("int".equals(valueType)) { int value = Integer.parseInt(entry.getValue()); result.put(entry.getKey(), value); } else if ("double".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } else if ("string".equals(valueType)) { String value = entry.getValue(); result.put(entry.getKey(), value); } else if ("decimalArray".equals(valueType)) { JSONArray valueArray = JSONArray.parseArray(entry.getValue()); double[] value = new double[valueArray.size()]; for (int i = 0; i < valueArray.size(); i++) { value[i] = Double.parseDouble(valueArray.get(i).toString()); } result.put(entry.getKey(), value); } else if ("decimal".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } } return result; } }