| | |
| | | package com.iailab.module.model.mdk.schedule.impl; |
| | | |
| | | import com.alibaba.fastjson.JSON; |
| | | import com.alibaba.fastjson.JSONArray; |
| | | import com.alibaba.fastjson.JSONObject; |
| | | import com.iail.IAILMDK; |
| | | import com.iail.model.IAILModel; |
| | | import com.iailab.module.model.common.enums.CommonConstant; |
| | | import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity; |
| | | import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity; |
| | | import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity; |
| | |
| | | import com.iailab.module.model.mdk.sample.dto.SampleData; |
| | | import com.iailab.module.model.mdk.schedule.ScheduleModelHandler; |
| | | import com.iailab.module.model.mdk.vo.ScheduleResultVO; |
| | | import com.iailab.module.model.mpk.common.MdkConstant; |
| | | import com.iailab.module.model.mpk.common.utils.DllUtils; |
| | | import lombok.extern.slf4j.Slf4j; |
| | | import org.springframework.beans.factory.annotation.Autowired; |
| | | import org.springframework.stereotype.Component; |
| | | import org.springframework.util.CollectionUtils; |
| | | |
| | | import java.text.MessageFormat; |
| | | import java.util.*; |
| | | import java.util.Date; |
| | | import java.util.HashMap; |
| | | import java.util.List; |
| | | |
| | | /** |
| | | * @author PanZhibao |
| | |
| | | } |
| | | String modelId = scheduleModel.getId(); |
| | | try { |
| | | IAILModel newModelBean = new IAILModel(); |
| | | //1.根据模型id构造模型输入样本 |
| | | List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime); |
| | | long now = System.currentTimeMillis(); |
| | | List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime,scheduleScheme.getName()); |
| | | log.info("构造模型输入样本消耗时长:" + (System.currentTimeMillis() - now) / 1000 + "秒"); |
| | | if (CollectionUtils.isEmpty(sampleDataList)) { |
| | | log.info("调度模型构造样本失败,schemeCode=" + schemeCode); |
| | | return null; |
| | | } |
| | | |
| | | //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型 |
| | | String className = scheduleModel.getClassName() .trim(); |
| | | String methodName = scheduleModel.getMethodName().trim(); |
| | | newModelBean.setClassName(className); |
| | | newModelBean.setMethodName(methodName); |
| | | |
| | | Class<?>[] paramsArray = new Class[3]; |
| | | paramsArray[0] = double[][].class; |
| | | paramsArray[1] = double[][].class; |
| | | paramsArray[2] = HashMap.class; |
| | | newModelBean.setParamsArray(paramsArray); |
| | | |
| | | //3.拼接settings参数 |
| | | HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId); |
| | | |
| | | //4.构造param2Values参数结构 |
| | | int count = sampleDataList.size(); |
| | | Object[] param2Values = new Object[count + 1]; |
| | | for (int i = 0; i < count; i++) { |
| | | IAILModel newModelBean = composeNewModelBean(scheduleModel); |
| | | HashMap<String, Object> settings = getScheduleSettingsByModelId(modelId); |
| | | if (settings == null) { |
| | | log.error("模型setting不存在,modelId=" + modelId); |
| | | return null; |
| | | } |
| | | // 校验setting必须有pyFile,否则可能导致程序崩溃 |
| | | if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { |
| | | log.error("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); |
| | | return null; |
| | | } |
| | | int portLength = sampleDataList.size(); |
| | | Object[] param2Values = new Object[portLength + 1]; |
| | | for (int i = 0; i < portLength; i++) { |
| | | param2Values[i] = sampleDataList.get(i).getMatrix(); |
| | | } |
| | | param2Values[count] = settings_predict; |
| | | param2Values[portLength] = settings; |
| | | |
| | | //打印参数 |
| | | log.info("##############调度模型:modelId=" + modelId + " ##########################"); |
| | | JSONObject jsonObjNewModelBean = new JSONObject(); |
| | | jsonObjNewModelBean.put("newModelBean", newModelBean); |
| | | log.info(String.valueOf(jsonObjNewModelBean)); |
| | | JSONObject jsonObjParam2Values = new JSONObject(); |
| | | jsonObjParam2Values.put("param2Values", param2Values); |
| | | log.info(String.valueOf(jsonObjParam2Values)); |
| | | |
| | | log.info("#######################调度模型 " + scheduleModel.getModelName() + " ##########################"); |
| | | log.info("参数: " + JSON.toJSONString(param2Values)); |
| | | //IAILMDK.run |
| | | HashMap<String, Object> result = IAILMDK.run(newModelBean, param2Values); |
| | | |
| | | HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid()); |
| | | if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || |
| | | !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { |
| | | throw new RuntimeException("模型结果异常:" + modelResult); |
| | | } |
| | | String statusCode = modelResult.get(CommonConstant.MDK_STATUS_CODE).toString(); |
| | | modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT); |
| | | //打印结果 |
| | | JSONObject jsonObjResult = new JSONObject(); |
| | | jsonObjResult.put("result", result); |
| | | jsonObjResult.put("result", modelResult); |
| | | log.info(String.valueOf(jsonObjResult)); |
| | | log.info("调度模型计算完成:modelId=" + modelId + result); |
| | | |
| | | //5.返回调度结果 |
| | | scheduleResult.setResult(result); |
| | | scheduleResult.setResultCode(statusCode); |
| | | scheduleResult.setResult(modelResult); |
| | | scheduleResult.setModelId(modelId); |
| | | scheduleResult.setSchemeId(scheduleScheme.getId()); |
| | | scheduleResult.setScheduleTime(scheduleTime); |
| | |
| | | * @param modelId |
| | | * @return |
| | | */ |
// Loads the settings stored for a schedule model and converts them into a key -> typed value map.
//
// Visible behaviour:
//   - settings are fetched with stScheduleModelSettingService.getByModelId(modelId);
//   - null is returned when that list is null or empty;
//   - each entry is converted according to its trimmed valuetype:
//       "int" -> int, "double" -> double, "string" -> String, "float" -> float,
//       "[[D" -> rows split on "/", columns split on ",", flattened row by row into a double[],
//       "decimalArray" -> double[] sized from the parsed JSON array,
//       "decimal" -> double;
//   - an entry whose valuetype matches none of the branches is not added to the result;
//   - a "[[D" value that fails to parse is reported on stdout and is not added to the result.
//
// NOTE(review): this span looks like two revisions of the same method interleaved by a diff
// rendering: there are two signature lines, and valueType and the per-branch "value" locals are
// declared twice. Confirm against the real file which lines are live before changing any code.
| | | private HashMap<String, Object> getPredictSettingsByModelId(String modelId) { |
| | | private HashMap<String, Object> getScheduleSettingsByModelId(String modelId) { |
| | | List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId); |
| | | if (CollectionUtils.isEmpty(list)) { |
| | | return null; |
| | | } |
| | | HashMap<String, Object> result = new HashMap<>(); |
| | | for (StScheduleModelSettingEntity entry : list) { |
| | | String valueType = entry.getValuetype().trim(); |
| | | String valueStr = entry.getValue().trim(); |
| | | String valueType = entry.getValuetype().trim(); //strip leading and trailing whitespace |
| | | if ("int".equals(valueType)) { |
| | | int value = Integer.parseInt(valueStr); |
| | | int value = Integer.parseInt(entry.getValue()); |
| | | result.put(entry.getKey(), value); |
| | | } else if ("double".equals(valueType)) { |
| | | double value = Double.parseDouble(valueStr); |
| | | double value = Double.parseDouble(entry.getValue()); |
| | | result.put(entry.getKey(), value); |
| | | } else if ("string".equals(valueType)) { |
| | | String value = valueStr; |
| | | String value = entry.getValue(); |
| | | result.put(entry.getKey(), value); |
| | | } else if ("float".equals(valueType)) { |
| | | float value = Float.parseFloat(valueStr); |
| | | result.put(entry.getKey(), value); |
| | | } else if ("[[D".equals(valueType)) { |
| | | String valueStrTemp = entry.getValue(); |
| | | try { |
| | | //1. rows of the 2-D array are separated by "/" |
| | | String[] rowList = valueStrTemp.split("/"); |
| | | int row = rowList.length; |
| | | int col = rowList[0].split(",").length; |
| | | double[][] value1 = new double[row][col]; |
| | | for (int i = 0; i < rowList.length; i++) { |
| | | //2. columns of the 2-D array are separated by "," |
| | | String[] colList = rowList[i].split(","); |
| | | for (int j = 0; j < colList.length; j++) { |
| | | value1[i][j] = Double.parseDouble(colList[j]); |
| | | } |
| | | } |
| | | //flatten the parsed 2-D setting value into a 1-D array, row by row |
| | | //int len =0; |
| | | double[] value = new double[row * col]; |
| | | /*for (int j = 0; j <value1.length ; j++) { |
| | | len+= value1.length; |
| | | }*/ |
| | | //value = new double[len]; |
| | | int index = 0; |
| | | for (int i = 0; i < value1.length; i++) { |
| | | for (int j = 0; j < value1[i].length; j++) { |
| | | value[index++] = value1[i][j]; |
| | | } |
| | | } |
| | | result.put(entry.getKey(), value); |
| | | } catch (Exception ex) { |
| | | System.out.println("二维数组类型的setting格式不正确"); |
| | | ex.printStackTrace(); |
| | | } |
| | | } else if ("decimalArray".equals(valueType)) { |
| | | JSONArray valueArray = JSONArray.parseArray(entry.getValue()); |
| | | double[] value = new double[valueArray.size()]; |
// NOTE(review): the statements that copy valueArray into value are not visible in this view
// (the next row is empty) - confirm against the real file.
| | |
| | | result.put(entry.getKey(), value); |
| | | } else if ("decimal".equals(valueType)) { |
| | | double value = Double.parseDouble(entry.getValue()); |
| | | //BigDecimal value = new BigDecimal(entry.getValue()); |
| | | result.put(entry.getKey(), value); |
| | | } |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | private IAILModel composeNewModelBean(StScheduleModelEntity model) { |
| | | IAILModel newModelBean = new IAILModel(); |
| | | newModelBean.setClassName(model.getClassName().trim()); |
| | | newModelBean.setMethodName(model.getMethodName().trim()); |
| | | //构造参数类型 |
| | | Class<?>[] paramsArray = new Class[model.getPortLength() + 1]; |
| | | for (int i = 0; i < model.getPortLength(); i++) { |
| | | paramsArray[i] = double[][].class; |
| | | } |
| | | paramsArray[model.getPortLength()] = HashMap.class; |
| | | newModelBean.setParamsArray(paramsArray); |
| | | // |
| | | // HashMap<String, Object> dataMap = new HashMap<>(); |
| | | // newModelBean.setDataMap(dataMap); |
| | | return newModelBean; |
| | | } |
| | | } |