package com.iailab.module.model.mdk.predict.impl; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.iail.IAILMDK; import com.iail.model.IAILModel; import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity; import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity; import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService; import com.iailab.module.model.mcs.pre.service.MmModelResultstrService; import com.iailab.module.model.mdk.common.enums.TypeA; import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; import com.iailab.module.model.mdk.predict.PredictModelHandler; import com.iailab.module.model.mdk.sample.SampleConstructor; import com.iailab.module.model.mdk.sample.dto.SampleData; import com.iailab.module.model.mdk.vo.PredictResultVO; import com.iailab.module.model.mpk.common.utils.DllUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.util.ArrayList; import java.util.Date; import java.util.HashMap; import java.util.List; /** * @author PanZhibao * @Description * @createTime 2024年09月01日 */ @Slf4j @Component public class PredictModelHandlerImpl implements PredictModelHandler { @Autowired private MmModelArithSettingsService mmModelArithSettingsService; @Autowired private MmModelResultstrService mmModelResultstrService; @Autowired private SampleConstructor sampleConstructor; @Override public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException { PredictResultVO result = new PredictResultVO(); if (predictModel == null) { throw new ModelInvokeException("modelEntity is null"); } String modelId = predictModel.getId(); try { List sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime); String modelPath = predictModel.getModelpath(); if (modelPath == null) { System.out.println("模型路径不存在,modelId=" + modelId); return null; } IAILModel newModelBean = composeNewModelBean(predictModel); HashMap settings = getPredictSettingsByModelId(modelId); if (settings == null) { log.error("模型setting不存在,modelId=" + modelId); return null; } int portLength = sampleDataList.size(); Object[] param2Values = new Object[portLength + 2]; for (int i = 0; i < portLength; i++) { param2Values[i]=sampleDataList.get(i).getMatrix(); } param2Values[portLength] = newModelBean.getDataMap().get("models"); param2Values[portLength+1] = settings; log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################"); JSONObject jsonObjNewModelBean = new JSONObject(); jsonObjNewModelBean.put("newModelBean", newModelBean); log.info(String.valueOf(jsonObjNewModelBean)); JSONObject jsonObjParam2Values = new JSONObject(); jsonObjParam2Values.put("param2Values", param2Values); log.info(String.valueOf(jsonObjParam2Values)); //IAILMDK.run // HashMap modelResult = IAILMDK.run(newModelBean, param2Values); HashMap modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid()); if(!modelResult.containsKey("status_code") || !modelResult.containsKey("result") || Integer.parseInt(modelResult.get("status_code").toString()) != 100) { throw new RuntimeException("模型结果异常:" + modelResult); } modelResult = (HashMap) modelResult.get("result"); //打印结果 JSONObject jsonObjResult = new JSONObject(); jsonObjResult.put("result", modelResult); log.info(String.valueOf(jsonObjResult)); MmModelResultstrEntity modelResultstr = mmModelResultstrService.getInfo(predictModel.getResultstrid()); log.info("模型计算完成:modelId=" + modelId + result); if (modelResult.containsKey(modelResultstr.getResultstr())) { Double[][] temp = (Double[][]) modelResult.get(modelResultstr.getResultstr()); double[][] temp1 = new double[temp.length][temp[0].length]; for (int i = 0; i < temp.length; i++) { for (int j = 0; j < temp[i].length; j++) { temp1[i][j] = temp[i][j].doubleValue(); } } result.setPredictMatrix(temp1); } result.setModelResult(modelResult); result.setPredictTime(predictTime); } catch (Exception ex) { log.error("IAILModel对象构造失败,modelId=" + modelId); log.error(ex.getMessage()); log.error("调用发生异常,异常信息为:{}" , ex); ex.printStackTrace(); } return result; } /** * 构造IAILMDK.run()方法的newModelBean参数 * * @param predictModel * @return */ private IAILModel composeNewModelBean(MmPredictModelEntity predictModel) { IAILModel newModelBean = new IAILModel(); newModelBean.setClassName(predictModel.getClassname().trim()); newModelBean.setMethodName(predictModel.getMethodname().trim()); //构造参数类型 String[] paArStr = predictModel.getModelparamstructure().trim().split(","); Class[] paramsArray = new Class[paArStr.length]; for (int i = 0; i < paArStr.length; i++) { if ("[[D".equals(paArStr[i])) { paramsArray[i] = double[][].class; } else if ("Map".equals(paArStr[i]) || "java.util.HashMap".equals(paArStr[i])) { paramsArray[i] = HashMap.class; } } newModelBean.setParamsArray(paramsArray); HashMap dataMap = new HashMap<>(); HashMap models = new HashMap<>(1); models.put("model_path", predictModel.getModelpath()); dataMap.put("models", models); newModelBean.setDataMap(dataMap); return newModelBean; } /** * 根据模型id获取参数map * * @param modelId * @return */ private HashMap getPredictSettingsByModelId(String modelId) { List list = mmModelArithSettingsService.getByModelId(modelId); HashMap result = new HashMap<>(); for (MmModelArithSettingsEntity entry : list) { String valueType = entry.getValuetype().trim(); //去除两端空格 if ("int".equals(valueType)) { int value = Integer.parseInt(entry.getValue()); result.put(entry.getKey(), value); } else if ("double".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } else if ("string".equals(valueType)) { String value = entry.getValue(); result.put(entry.getKey(), value); } else if ("decimalArray".equals(valueType)) { JSONArray valueArray = JSONArray.parseArray(entry.getValue()); double[] value = new double[valueArray.size()]; for (int i = 0; i < valueArray.size(); i++) { value[i] = Double.parseDouble(valueArray.get(i).toString()); } result.put(entry.getKey(), value); } else if ("decimal".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } } return result; } }