| | |
| | | package com.iailab.module.model.mdk.predict.impl; |
| | | |
| | | import com.alibaba.fastjson.JSON; |
| | | import com.alibaba.fastjson.JSONArray; |
| | | import com.alibaba.fastjson.JSONObject; |
| | | import com.iail.IAILMDK; |
| | | import com.iail.model.IAILModel; |
| | | import com.iailab.module.model.common.enums.CommonConstant; |
| | | import com.iailab.module.model.common.enums.OutResultType; |
| | | import com.iailab.module.model.common.exception.ModelResultErrorException; |
| | | import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity; |
| | | import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity; |
| | | import com.iailab.module.model.mcs.pre.entity.MmModelResultstrEntity; |
| | | import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; |
| | | import com.iailab.module.model.mcs.pre.enums.ItemRunStatusEnum; |
| | | import com.iailab.module.model.mcs.pre.service.MmItemOutputService; |
| | | import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService; |
| | | import com.iailab.module.model.mcs.pre.service.MmModelResultstrService; |
| | | import com.iailab.module.model.mdk.common.enums.TypeA; |
| | | import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; |
| | | import com.iailab.module.model.mdk.predict.PredictModelHandler; |
| | | import com.iailab.module.model.mdk.sample.SampleConstructor; |
| | | import com.iailab.module.model.mdk.sample.dto.SampleData; |
| | | import com.iailab.module.model.mdk.vo.PredictResultVO; |
| | | import com.iailab.module.model.mpk.common.MdkConstant; |
| | | import com.iailab.module.model.mpk.common.utils.DllUtils; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import org.springframework.stereotype.Component; |
| | | |
| | | import java.util.ArrayList; |
| | | import java.util.Date; |
| | | import java.util.HashMap; |
| | | import java.util.List; |
| | | import java.util.Map; |
| | | |
| | | /** |
| | | * @author PanZhibao |
| | |
| | | private MmModelArithSettingsService mmModelArithSettingsService; |
| | | |
| | | @Autowired |
| | | private MmModelResultstrService mmModelResultstrService; |
| | | private MmItemOutputService mmItemOutputService; |
| | | |
| | | @Autowired |
| | | private SampleConstructor sampleConstructor; |
| | | |
| | | /** |
| | | * 根据模型预测,返回预测结果 |
| | | * |
| | | * @param predictTime |
| | | * @param predictModel |
| | | * @return |
| | | * @throws ModelInvokeException |
| | | */ |
| | | @Override |
| | | public PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel) throws ModelInvokeException { |
| | | public synchronized PredictResultVO predictByModel(Date predictTime, MmPredictModelEntity predictModel,String itemName) throws ModelInvokeException { |
| | | PredictResultVO result = new PredictResultVO(); |
| | | if (predictModel == null) { |
| | | throw new ModelInvokeException("modelEntity is null"); |
| | | } |
| | | String modelId = predictModel.getId(); |
| | | |
| | | try { |
| | | List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime); |
| | | String modelPath = predictModel.getModelpath(); |
| | | if (modelPath == null) { |
| | | System.out.println("模型路径不存在,modelId=" + modelId); |
| | | log.info("模型路径不存在,modelId=" + modelId); |
| | | return null; |
| | | } |
| | | IAILModel newModelBean = composeNewModelBean(predictModel); |
| | | HashMap<String, Object> settings = getPredictSettingsByModelId(modelId); |
| | | // 校验setting必须有pyFile,否则可能导致程序崩溃 |
| | | if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { |
| | | throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); |
| | | } |
| | | |
| | | if (settings == null) { |
| | | log.error("模型setting不存在,modelId=" + modelId); |
| | | return null; |
| | |
| | | int portLength = sampleDataList.size(); |
| | | Object[] param2Values = new Object[portLength + 2]; |
| | | for (int i = 0; i < portLength; i++) { |
| | | param2Values[i]=sampleDataList.get(i).getMatrix(); |
| | | param2Values[i] = sampleDataList.get(i).getMatrix(); |
| | | } |
| | | param2Values[portLength] = newModelBean.getDataMap().get("models"); |
| | | param2Values[portLength+1] = settings; |
| | | param2Values[portLength + 1] = settings; |
| | | |
| | | log.info("#######################预测模型 " + predictModel.getItemid() + " ##########################"); |
| | | JSONObject jsonObjNewModelBean = new JSONObject(); |
| | | jsonObjNewModelBean.put("newModelBean", newModelBean); |
| | | log.info(String.valueOf(jsonObjNewModelBean)); |
| | | JSONObject jsonObjParam2Values = new JSONObject(); |
| | | jsonObjParam2Values.put("param2Values", param2Values); |
| | | log.info(String.valueOf(jsonObjParam2Values)); |
| | | log.info("####################### 预测模型 "+ "【itemId:" + predictModel.getItemid() + ",itemName" + itemName + "】 ##########################"); |
| | | // JSONObject jsonObjNewModelBean = new JSONObject(); |
| | | // jsonObjNewModelBean.put("newModelBean", newModelBean); |
| | | // log.info(String.valueOf(jsonObjNewModelBean)); |
| | | // JSONObject jsonObjParam2Values = new JSONObject(); |
| | | // jsonObjParam2Values.put("param2Values", param2Values); |
| | | log.info("参数: " + JSON.toJSONString(param2Values)); |
| | | |
| | | //IAILMDK.run |
| | | HashMap<String, Object> modelResult = IAILMDK.run(newModelBean, param2Values); |
| | | HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid()); |
| | | if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || |
| | | !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { |
| | | throw new ModelResultErrorException("模型结果异常:" + modelResult); |
| | | } |
| | | modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT); |
| | | //打印结果 |
| | | log.info("预测模型计算完成:modelId=" + modelId + ",modelName" + predictModel.getMethodname()); |
| | | JSONObject jsonObjResult = new JSONObject(); |
| | | jsonObjResult.put("result", result); |
| | | jsonObjResult.put("result", modelResult); |
| | | log.info(String.valueOf(jsonObjResult)); |
| | | |
| | | MmModelResultstrEntity modelResultstr = mmModelResultstrService.getInfo(predictModel.getResultstrid()); |
| | | log.info("模型计算完成:modelId=" + modelId + result); |
| | | double[][] temp = (double[][]) modelResult.get(modelResultstr.getResultstr()); |
| | | result.setPredictMatrix(temp); |
| | | List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid()); |
| | | Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(); |
| | | Map<MmItemOutputEntity, Double> predictDoubleValues = new HashMap<>(); |
| | | for (MmItemOutputEntity output : itemOutputList) { |
| | | if (!modelResult.containsKey(output.getResultstr())) { |
| | | continue; |
| | | } |
| | | OutResultType outResultType = OutResultType.getEumByCode(output.getResultType()); |
| | | switch (outResultType) { |
| | | case D1: |
| | | double[] temp1 = (double[]) modelResult.get(output.getResultstr()); |
| | | predictMatrixs.put(output, temp1); |
| | | break; |
| | | case D2: |
| | | double[][] temp2 = (double[][]) modelResult.get(output.getResultstr()); |
| | | double[] tempColumn = new double[temp2.length]; |
| | | for (int i = 0; i < tempColumn.length; i++) { |
| | | tempColumn[i] = temp2[i][output.getResultIndex()]; |
| | | } |
| | | predictMatrixs.put(output, tempColumn); |
| | | break; |
| | | case D: |
| | | Double temp3 = (Double) modelResult.get(output.getResultstr()); |
| | | predictDoubleValues.put(output, temp3); |
| | | break; |
| | | default: |
| | | break; |
| | | } |
| | | } |
| | | result.setPredictMatrixs(predictMatrixs); |
| | | result.setPredictDoubleValues(predictDoubleValues); |
| | | result.setModelResult(modelResult); |
| | | result.setPredictTime(predictTime); |
| | | } catch (Exception ex) { |
| | | log.error("IAILModel对象构造失败,modelId=" + modelId); |
| | | log.error(ex.getMessage()); |
| | | log.error("调用发生异常,异常信息为:{}" , ex); |
| | | log.error("调用发生异常,异常信息为:{}", ex); |
| | | ex.printStackTrace(); |
| | | |
| | | throw new ModelInvokeException(ex.getMessage()); |
| | | } |
| | | return result; |
| | | } |
| | |
| | | newModelBean.setParamsArray(paramsArray); |
| | | HashMap<String, Object> dataMap = new HashMap<>(); |
| | | HashMap<String, String> models = new HashMap<>(1); |
| | | models.put("paramFile", predictModel.getModelpath()); |
| | | models.put("model_path", predictModel.getModelpath()); |
| | | dataMap.put("models", models); |
| | | newModelBean.setDataMap(dataMap); |
| | | return newModelBean; |