From b651cbfd94d8d636c01b61e483ed1cff98e1bcb9 Mon Sep 17 00:00:00 2001 From: 潘志宝 <979469083@qq.com> Date: 星期一, 23 十二月 2024 16:13:56 +0800 Subject: [PATCH] Merge remote-tracking branch 'origin/master' --- iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java | 137 ++++++++++++++++++--------------------------- 1 files changed, 56 insertions(+), 81 deletions(-) diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java index 3905e29..70a3172 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java @@ -1,8 +1,8 @@ package com.iailab.module.model.mdk.schedule.impl; +import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; -import com.iail.IAILMDK; import com.iail.model.IAILModel; import com.iailab.module.model.common.enums.CommonConstant; import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity; @@ -17,6 +17,7 @@ import com.iailab.module.model.mdk.sample.dto.SampleData; import com.iailab.module.model.mdk.schedule.ScheduleModelHandler; import com.iailab.module.model.mdk.vo.ScheduleResultVO; +import com.iailab.module.model.mpk.common.MdkConstant; import com.iailab.module.model.mpk.common.utils.DllUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -24,7 +25,9 @@ import org.springframework.util.CollectionUtils; import java.text.MessageFormat; -import java.util.*; +import java.util.Date; +import java.util.HashMap; +import java.util.List; /** * @author PanZhibao @@ -58,63 +61,56 @@ } String modelId = scheduleModel.getId(); try { - IAILModel newModelBean = new IAILModel(); //1.根据模型id构造模型输入样本 - List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime); + long now = System.currentTimeMillis(); + List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime,scheduleScheme.getName()); + log.info("构造模型输入样本消耗时长:" + (System.currentTimeMillis() - now) / 1000 + "秒"); if (CollectionUtils.isEmpty(sampleDataList)) { log.info("调度模型构造样本失败,schemeCode=" + schemeCode); return null; } - //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型 - String className = scheduleModel.getClassName() .trim(); - String methodName = scheduleModel.getMethodName().trim(); - newModelBean.setClassName(className); - newModelBean.setMethodName(methodName); - - Class<?>[] paramsArray = new Class[3]; - paramsArray[0] = double[][].class; - paramsArray[1] = double[][].class; - paramsArray[2] = HashMap.class; - newModelBean.setParamsArray(paramsArray); - - //3.拼接settings参数 - HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId); - - //4.构造param2Values参数结构 - int count = sampleDataList.size(); - Object[] param2Values = new Object[count + 1]; - for (int i = 0; i < count; i++) { + IAILModel newModelBean = composeNewModelBean(scheduleModel); + HashMap<String, Object> settings = getScheduleSettingsByModelId(modelId); + if (settings == null) { + log.error("模型setting不存在,modelId=" + modelId); + return null; + } + // 校验setting必须有pyFile,否则可能导致程序崩溃 + if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { + log.error("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); + return null; + } + int portLength = sampleDataList.size(); + Object[] param2Values = new Object[portLength + 1]; + for (int i = 0; i < portLength; i++) { param2Values[i] = sampleDataList.get(i).getMatrix(); } - param2Values[count] = settings_predict; + param2Values[portLength] = settings; - //打印参数 - log.info("##############调度模型:modelId=" + modelId + " ##########################"); - JSONObject jsonObjNewModelBean = new JSONObject(); - jsonObjNewModelBean.put("newModelBean", newModelBean); - log.info(String.valueOf(jsonObjNewModelBean)); - JSONObject jsonObjParam2Values = new JSONObject(); - jsonObjParam2Values.put("param2Values", param2Values); - log.info(String.valueOf(jsonObjParam2Values)); + log.info("#######################调度模型 " + scheduleModel.getModelName() + " ##########################"); +// JSONObject jsonObjNewModelBean = new JSONObject(); +// jsonObjNewModelBean.put("newModelBean", newModelBean); +// log.info(String.valueOf(jsonObjNewModelBean)); +// JSONObject jsonObjParam2Values = new JSONObject(); +// jsonObjParam2Values.put("param2Values", param2Values); + log.info("参数: " + JSON.toJSONString(param2Values)); //IAILMDK.run - HashMap<String, Object> result = IAILMDK.run(newModelBean, param2Values); - /*HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, predictModel.getMpkprojectid()); + HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid()); if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { throw new RuntimeException("模型结果异常:" + modelResult); } - modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);*/ + modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT); //打印结果 JSONObject jsonObjResult = new JSONObject(); - jsonObjResult.put("result", result); + jsonObjResult.put("result", modelResult); log.info(String.valueOf(jsonObjResult)); - log.info("调度模型计算完成:modelId=" + modelId + result); //5.返回调度结果 - scheduleResult.setResult(result); + scheduleResult.setResult(modelResult); scheduleResult.setModelId(modelId); scheduleResult.setSchemeId(scheduleScheme.getId()); scheduleResult.setScheduleTime(scheduleTime); @@ -133,60 +129,23 @@ * @param modelId * @return */ - private HashMap<String, Object> getPredictSettingsByModelId(String modelId) { + private HashMap<String, Object> getScheduleSettingsByModelId(String modelId) { List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId); if (CollectionUtils.isEmpty(list)) { return null; } HashMap<String, Object> result = new HashMap<>(); for (StScheduleModelSettingEntity entry : list) { - String valueType = entry.getValuetype().trim(); - String valueStr = entry.getValue().trim(); + String valueType = entry.getValuetype().trim(); //去除两端空格 if ("int".equals(valueType)) { - int value = Integer.parseInt(valueStr); + int value = Integer.parseInt(entry.getValue()); result.put(entry.getKey(), value); } else if ("double".equals(valueType)) { - double value = Double.parseDouble(valueStr); + double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } else if ("string".equals(valueType)) { - String value = valueStr; + String value = entry.getValue(); result.put(entry.getKey(), value); - } else if ("float".equals(valueType)) { - float value = Float.parseFloat(valueStr); - result.put(entry.getKey(), value); - } else if ("[[D".equals(valueType)) { - String valueStrTemp = entry.getValue(); - try { - //1.二位数组的行按照"/"来分割 - String[] rowList = valueStrTemp.split("/"); - int row = rowList.length; - int col = rowList[0].split(",").length; - double[][] value1 = new double[row][col]; - for (int i = 0; i < rowList.length; i++) { - //2.二位数组的列按照","来分割 - String[] colList = rowList[i].split(","); - for (int j = 0; j < colList.length; j++) { - value1[i][j] = Double.parseDouble(colList[j]); - } - } - //把从数据库的得到的参数的二维数组降为一维数组 - //int len =0; - double[] value = new double[row * col]; - /*for (int j = 0; j <value1.length ; j++) { - len+= value1.length; - }*/ - //value = new double[len]; - int index = 0; - for (int i = 0; i < value1.length; i++) { - for (int j = 0; j < value1[i].length; j++) { - value[index++] = value1[i][j]; - } - } - result.put(entry.getKey(), value); - } catch (Exception ex) { - System.out.println("二维数组类型的setting格式不正确"); - ex.printStackTrace(); - } } else if ("decimalArray".equals(valueType)) { JSONArray valueArray = JSONArray.parseArray(entry.getValue()); double[] value = new double[valueArray.size()]; @@ -196,10 +155,26 @@ result.put(entry.getKey(), value); } else if ("decimal".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); - //BigDecimal value = new BigDecimal(entry.getValue()); result.put(entry.getKey(), value); } } return result; } + + private IAILModel composeNewModelBean(StScheduleModelEntity model) { + IAILModel newModelBean = new IAILModel(); + newModelBean.setClassName(model.getClassName().trim()); + newModelBean.setMethodName(model.getMethodName().trim()); + //构造参数类型 + Class<?>[] paramsArray = new Class[model.getPortLength() + 1]; + for (int i = 0; i < model.getPortLength(); i++) { + paramsArray[i] = double[][].class; + } + paramsArray[model.getPortLength()] = HashMap.class; + newModelBean.setParamsArray(paramsArray); + // +// HashMap<String, Object> dataMap = new HashMap<>(); +// newModelBean.setDataMap(dataMap); + return newModelBean; + } } \ No newline at end of file -- Gitblit v1.9.3