iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/dao/MmPredictModelDao.java
// NOTE(review): this is a diff with its +/- markers stripped, so removed and added lines appear side by side.
// The two getSampleLength declarations below are the old (List-returning) and the new (single-entity) signature.
// Java cannot overload on return type alone, so only the MmPredictModelEntity version can exist in the real file.
// NOTE(review): the single-entity signature presumably expects the query to match at most one row — confirm in the mapper.
@@ -20,5 +20,5 @@ List<MmPredictModelEntity> getActiveModelByItemId(String itemId); List<MmPredictModelEntity> getSampleLength(String modelId); MmPredictModelEntity getSampleLength(String modelId); } iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmModelParamService.java
// Renames the misspelled getByModelidFromCatch to getByModelidFromCache. The caller in
// PredictSampleInfoConstructor switches to the new name in this same diff.
@@ -12,7 +12,7 @@ void saveList(List<MmModelParamEntity> list); List<MmModelParamEntity> getByModelidFromCatch(String modelId); List<MmModelParamEntity> getByModelidFromCache(String modelId); List<MmModelParamEntity> getByModelid(String modelid); iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmModelParamServiceImpl.java
// Implementation side of the getByModelidFromCatch -> getByModelidFromCache rename (body unchanged).
// Only the start of the method is visible: on a cache miss it loads getByModelid(modelId) and, when the
// result is non-null, presumably stores it in modelInputParamMap — TODO confirm against the full file.
@@ -44,7 +44,7 @@ } @Override public List<MmModelParamEntity> getByModelidFromCatch(String modelId) { public List<MmModelParamEntity> getByModelidFromCache(String modelId) { if (!modelInputParamMap.containsKey(modelId)) { List<MmModelParamEntity> list = getByModelid(modelId); if (list != null) { iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java
// getSampleLength now reads a single MmPredictModelEntity from the DAO instead of the first element of a list.
// It returns BigDecimal.ZERO when no entity is found. For a plain entity, ObjectUtils.isEmpty amounts to a null check.
// NOTE(review): if getPredictsamplength() can be null, this returns null, and
// PredictSampleInfoConstructor.getSampleColumn calls .intValue() on the result — confirm the column is non-nullable.
@@ -8,6 +8,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; import java.math.BigDecimal; import java.util.List; @@ -67,11 +68,11 @@ @Override public BigDecimal getSampleLength(String id) { BigDecimal result = BigDecimal.ZERO; List<MmPredictModelEntity> list = mmPredictModelDao.getSampleLength(id); if (CollectionUtils.isEmpty(list)) { MmPredictModelEntity entity = mmPredictModelDao.getSampleLength(id); if (ObjectUtils.isEmpty(entity)) { return result; } result = list.get(0).getPredictsamplength(); result = entity.getPredictsamplength(); return result; } iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/sche/service/StScheduleModelParamService.java
// Adds getByModelidFromCache: a cached variant of the schedule-model input-parameter lookup,
// implemented in StScheduleModelParamServiceImpl.
@@ -18,4 +18,6 @@ void deleteByModelId(String modelId); void saveList(String modelId, List<StScheduleModelParamSaveReqVO> saveList); List<StScheduleModelParamEntity> getByModelidFromCache(String modelId); } iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/sche/service/impl/StScheduleModelParamServiceImpl.java
// Adds a cache of schedule-model parameters: a static ConcurrentHashMap keyed by modelId, shared by every instance.
// The method ending in the third hunk (presumably saveList) calls clearCache() after inserting. That drops the
// cached entries of ALL models, not only modelId.
// getByModelidFromCache: on a miss, loads getByModelId(modelId) and caches it if non-null (an empty list is cached too).
// It returns null when the lookup returns null.
// NOTE(review): containsKey/put is not atomic, so concurrent misses may each hit the DB (harmless; the last put wins).
// ConcurrentHashMap.computeIfAbsent would avoid that. The cached List is returned directly, so a caller that
// mutates it corrupts the cache.
// NOTE(review): deleteByModelId is not visible here — confirm it also calls clearCache(); otherwise stale params can be served.
@@ -11,7 +11,9 @@ import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; /** * @author PanZhibao @@ -21,6 +23,8 @@ @Service public class StScheduleModelParamServiceImpl extends BaseServiceImpl<StScheduleModelParamDao, StScheduleModelParamEntity> implements StScheduleModelParamService { private static Map<String, List<StScheduleModelParamEntity>> modelInputParamMap = new ConcurrentHashMap<>(); @Override public List<StScheduleModelParamEntity> getByModelId(String modelId) { @@ -51,5 +55,23 @@ entity.setModelid(modelId); baseDao.insert(entity); }); clearCache(); } @Override public List<StScheduleModelParamEntity> getByModelidFromCache(String modelId) { if (!modelInputParamMap.containsKey(modelId)) { List<StScheduleModelParamEntity> list = getByModelId(modelId); if (list != null) { modelInputParamMap.put(modelId, list); } else { return null; } } return modelInputParamMap.get(modelId); } public void clearCache() { modelInputParamMap.clear(); } } iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java
@@ -16,12 +16,16 @@ import com.iailab.module.model.mdk.sample.SampleConstructor; import com.iailab.module.model.mdk.sample.dto.SampleData; import com.iailab.module.model.mdk.vo.PredictResultVO; import com.iailab.module.model.mpk.common.MdkConstant; import com.iailab.module.model.mpk.common.utils.DllUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.util.*; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; /** * @author PanZhibao @@ -65,6 +69,11 @@ } IAILModel newModelBean = composeNewModelBean(predictModel); HashMap<String, Object> settings = getPredictSettingsByModelId(modelId); // 校验setting必须有pyFile,否则可能导致程序崩溃 if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { throw new RuntimeException("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); } if (settings == null) { log.error("模型setting不存在,modelId=" + modelId); return null; iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/sample/PredictSampleInfoConstructor.java
// Removes the pass-through overrides (prepareSampleInfo, getStartTime, getEndTime), which only delegated to super.
// Switches the parameter lookup to the renamed getByModelidFromCache.
// NOTE(review): getSampleColumn calls .intValue() on mmPredictModelService.getSampleLength(modelId); that throws
// NullPointerException if the stored sample length is null.
@@ -6,7 +6,6 @@ import com.iailab.module.model.mcs.pre.service.MmPredictModelService; import com.iailab.module.model.mdk.sample.dto.ColumnItem; import com.iailab.module.model.mdk.sample.dto.ColumnItemPort; import com.iailab.module.model.mdk.sample.dto.SampleInfo; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -32,11 +31,6 @@ @Autowired private MmPredictItemService mmPredictItemService; @Override public SampleInfo prepareSampleInfo(String modelId, Date predictTime) { return super.prepareSampleInfo(modelId, predictTime); } /** * Returns the number of columns of the sample matrix * @@ -46,30 +40,6 @@ @Override protected Integer getSampleColumn(String modelId) { return mmPredictModelService.getSampleLength(modelId).intValue(); } /** * Returns the start time of the sample * * @param columnItem * @param predictTime * @return */ @Override protected Date getStartTime(ColumnItem columnItem, Date predictTime) { return super.getStartTime(columnItem, predictTime); } /** * Returns the end time of the sample * * @param columnItem * @param predictTime * @return */ @Override protected Date getEndTime(ColumnItem columnItem, Date predictTime) { return super.getEndTime(columnItem, predictTime); } /** @@ -85,7 +55,7 @@ List<ColumnItem> columnItemList = new ArrayList<>(); ColumnItem columnInfo = new ColumnItem(); ColumnItemPort curPort = new ColumnItemPort(); //current port List<MmModelParamEntity> modelInputParamEntityList = mmModelParamService.getByModelidFromCatch(modelId); List<MmModelParamEntity> modelInputParamEntityList = mmModelParamService.getByModelidFromCache(modelId); if (CollectionUtils.isEmpty(modelInputParamEntityList)) { return null; } iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/sample/SampleFactory.java
// Fixes createSampleInfo: for TypeA.Schedule it previously returned the predict constructor. It now returns the
// injected ScheduleSampleInfoConstructor, and the local is typed as the abstract SampleInfoConstructor so either
// subclass fits.
// Any typeA other than Predict/Schedule yields null. typeA.compareTo throws NullPointerException when typeA is
// null. modelId is not used in the visible body.
@@ -14,6 +14,9 @@ private PredictSampleInfoConstructor predictSampleInfoConstructor; @Autowired private ScheduleSampleInfoConstructor scheduleSampleInfoConstructor; @Autowired private PredictSampleDataConstructor predictSampleDataConstructor; /** @@ -24,11 +27,11 @@ * @return */ public SampleInfoConstructor createSampleInfo(String typeA, String modelId){ PredictSampleInfoConstructor sampleInfoConstructor = null; SampleInfoConstructor sampleInfoConstructor = null; if (typeA.compareTo(TypeA.Predict.name()) == 0) { sampleInfoConstructor = predictSampleInfoConstructor; } else if (typeA.compareTo(TypeA.Schedule.name()) == 0) { sampleInfoConstructor = predictSampleInfoConstructor; sampleInfoConstructor = scheduleSampleInfoConstructor; } return sampleInfoConstructor; } iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/sample/SampleInfoConstructor.java
// prepareSampleInfo is narrowed from public to protected. Protected still permits calls from the same package
// (com.iailab.module.model.mdk.sample) and from subclasses — NOTE(review): confirm no caller outside that package uses it.
// The setSampleColumn/setSampleCycle calls are commented out, so only columnInfo is populated on the returned SampleInfo.
// getColumnInfo/getSampleCycle are only moved up next to getSampleColumn; their abstract signatures are unchanged.
@@ -46,14 +46,14 @@ * @param predictTime * @return */ public SampleInfo prepareSampleInfo(String modelId, Date predictTime) { protected SampleInfo prepareSampleInfo(String modelId, Date predictTime) { SampleInfo sampleInfo = new SampleInfo(); //call the sample column count method sampleInfo.setSampleColumn(getSampleColumn(modelId)); // sampleInfo.setSampleColumn(getSampleColumn(modelId)); //column info of the sample sampleInfo.setColumnInfo(getColumnInfo(modelId, predictTime)); //sampling cycle of the sample sampleInfo.setSampleCycle(getSampleCycle(modelId)); // sampleInfo.setSampleCycle(getSampleCycle(modelId)); return sampleInfo; } @@ -64,6 +64,24 @@ * @return */ protected abstract Integer getSampleColumn(String modelId); /** * Column info of the sample * * @param modelId * @param predictTime * @return */ protected abstract List<ColumnItemPort> getColumnInfo(String modelId, Date predictTime); /** * Sampling cycle of the sample * * @param modelId * @return */ protected abstract Integer getSampleCycle(String modelId); /** * Get the start time @@ -176,23 +194,6 @@ } return granularity; } /** * Column info of the sample * * @param modelId * @param predictTime * @return */ protected abstract List<ColumnItemPort> getColumnInfo(String modelId, Date predictTime); /** * Sampling cycle of the sample * * @param modelId * @return */ protected abstract Integer getSampleCycle(String modelId); /** * Compute the value-fetch time iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/sample/ScheduleSampleInfoConstructor.java
对比新文件 @@ -0,0 +1,79 @@
package com.iailab.module.model.mdk.sample;

import com.iailab.module.model.mcs.sche.entity.StScheduleModelParamEntity;
import com.iailab.module.model.mcs.sche.service.StScheduleModelParamService;
import com.iailab.module.model.mdk.sample.dto.ColumnItem;
import com.iailab.module.model.mdk.sample.dto.ColumnItemPort;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.Date;
import java.util.List;

/**
 * Sample-info constructor for schedule models.
 *
 * <p>Only the column layout is derived here. The model's input parameters, read through the cached
 * {@code StScheduleModelParamService.getByModelidFromCache}, are grouped into ports. Schedule models
 * supply neither a sample column count nor a sampling cycle.
 */
@Component
public class ScheduleSampleInfoConstructor extends SampleInfoConstructor {

    @Autowired
    private StScheduleModelParamService stScheduleModelParamService;

    /**
     * Schedule models have no fixed sample column count.
     *
     * @param modelId schedule model id (not used)
     * @return always {@code null}
     */
    @Override
    protected Integer getSampleColumn(String modelId) {
        return null;
    }

    /**
     * Groups the schedule model's input parameters into ports.
     *
     * <p>Consecutive parameters that share a port order are collected into one {@link ColumnItemPort}.
     * A port's data length is taken from the first parameter of that port. The grouping relies on the
     * parameters arriving sorted by ascending port order — NOTE(review): confirm that
     * {@code getByModelId} orders its query this way.
     *
     * @param modelId     schedule model id
     * @param predictTime reference time used to compute each column's start and end time
     * @return the ports in encounter order, or {@code null} when the model has no input parameters
     */
    @Override
    protected List<ColumnItemPort> getColumnInfo(String modelId, Date predictTime) {
        List<StScheduleModelParamEntity> params = stScheduleModelParamService.getByModelidFromCache(modelId);
        if (CollectionUtils.isEmpty(params)) {
            return null;
        }

        StScheduleModelParamEntity head = params.get(0);
        int portOrder = head.getModelparamportorder();
        int portDataLength = head.getDatalength();
        List<ColumnItem> portColumns = new ArrayList<>();
        List<ColumnItemPort> ports = new ArrayList<>();

        for (StScheduleModelParamEntity param : params) {
            // The column is built before checking for a port boundary.
            ColumnItem column = toColumnItem(param, predictTime);

            // Entering a new port: close the group collected so far and start a fresh one.
            if (portOrder != param.getModelparamportorder()) {
                ports.add(newPort(portColumns, portDataLength, portOrder));
                portColumns = new ArrayList<>();
                portDataLength = param.getDatalength();
                portOrder = param.getModelparamportorder();
            }
            portColumns.add(column);
        }

        // The final port is never closed inside the loop.
        ports.add(newPort(portColumns, portDataLength, portOrder));
        return ports;
    }

    /**
     * Schedule models have no sampling cycle.
     *
     * @param modelId schedule model id (not used)
     * @return always {@code null}
     */
    @Override
    protected Integer getSampleCycle(String modelId) {
        return null;
    }

    /**
     * Converts one parameter entity into a {@link ColumnItem}.
     *
     * <p>The identity fields are filled in first. The start time, end time and granularity are then
     * computed in that order, each from the partially populated item.
     */
    private ColumnItem toColumnItem(StScheduleModelParamEntity param, Date predictTime) {
        ColumnItem column = new ColumnItem();
        column.setParamType(param.getModelparamtype());
        column.setParamId(param.getModelparamid());
        column.setDataLength(param.getDatalength());
        column.setModelParamOrder(param.getModelparamorder());
        column.setModelParamPortOrder(param.getModelparamportorder());
        column.setStartTime(getStartTime(column, predictTime));
        column.setEndTime(getEndTime(column, predictTime));
        column.setGranularity(super.getGranularity(column));
        return column;
    }

    /** Packs one port's columns together with its data length and port order. */
    private static ColumnItemPort newPort(List<ColumnItem> columns, int dataLength, int portOrder) {
        ColumnItemPort port = new ColumnItemPort();
        port.setColumnItemList(columns);
        port.setDataLength(dataLength);
        port.setPortOrder(portOrder);
        return port;
    }
}
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java
@@ -2,7 +2,6 @@ import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.iail.IAILMDK; import com.iail.model.IAILModel; import com.iailab.module.model.common.enums.CommonConstant; import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity; @@ -17,6 +16,7 @@ import com.iailab.module.model.mdk.sample.dto.SampleData; import com.iailab.module.model.mdk.schedule.ScheduleModelHandler; import com.iailab.module.model.mdk.vo.ScheduleResultVO; import com.iailab.module.model.mpk.common.MdkConstant; import com.iailab.module.model.mpk.common.utils.DllUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -24,7 +24,9 @@ import org.springframework.util.CollectionUtils; import java.text.MessageFormat; import java.util.*; import java.util.Date; import java.util.HashMap; import java.util.List; /** * @author PanZhibao @@ -58,7 +60,6 @@ } String modelId = scheduleModel.getId(); try { IAILModel newModelBean = new IAILModel(); //1.根据模型id构造模型输入样本 List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime); if (CollectionUtils.isEmpty(sampleDataList)) { @@ -66,31 +67,25 @@ return null; } //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型 String className = scheduleModel.getClassName() .trim(); String methodName = scheduleModel.getMethodName().trim(); newModelBean.setClassName(className); newModelBean.setMethodName(methodName); Class<?>[] paramsArray = new Class[3]; paramsArray[0] = double[][].class; paramsArray[1] = double[][].class; paramsArray[2] = HashMap.class; newModelBean.setParamsArray(paramsArray); //3.拼接settings参数 HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId); //4.构造param2Values参数结构 int count = sampleDataList.size(); Object[] param2Values = new Object[count + 1]; for (int i = 0; i < count; i++) { IAILModel newModelBean = composeNewModelBean(scheduleModel); HashMap<String, Object> settings = 
getScheduleSettingsByModelId(modelId); if (settings == null) { log.error("模型setting不存在,modelId=" + modelId); return null; } // 校验setting必须有pyFile,否则可能导致程序崩溃 if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { log.error("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); return null; } int portLength = sampleDataList.size(); Object[] param2Values = new Object[portLength + 1]; for (int i = 0; i < portLength; i++) { param2Values[i] = sampleDataList.get(i).getMatrix(); } param2Values[count] = settings_predict; param2Values[portLength] = settings; //打印参数 log.info("##############调度模型:scheduleScheme=" + scheduleScheme.getCode() + " ##########################"); log.info("#######################调度模型 " + scheduleModel.getModelName() + " ##########################"); JSONObject jsonObjNewModelBean = new JSONObject(); jsonObjNewModelBean.put("newModelBean", newModelBean); log.info(String.valueOf(jsonObjNewModelBean)); @@ -98,7 +93,7 @@ jsonObjParam2Values.put("param2Values", param2Values); log.info(String.valueOf(jsonObjParam2Values)); //运行模型 //IAILMDK.run HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid()); if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { @@ -131,60 +126,23 @@ * @param modelId * @return */ private HashMap<String, Object> getPredictSettingsByModelId(String modelId) { private HashMap<String, Object> getScheduleSettingsByModelId(String modelId) { List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId); if (CollectionUtils.isEmpty(list)) { return null; } HashMap<String, Object> result = new HashMap<>(); for (StScheduleModelSettingEntity entry : list) { String valueType = entry.getValuetype().trim(); String valueStr = entry.getValue().trim(); String valueType = entry.getValuetype().trim(); 
//去除两端空格 if ("int".equals(valueType)) { int value = Integer.parseInt(valueStr); int value = Integer.parseInt(entry.getValue()); result.put(entry.getKey(), value); } else if ("double".equals(valueType)) { double value = Double.parseDouble(valueStr); double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } else if ("string".equals(valueType)) { String value = valueStr; String value = entry.getValue(); result.put(entry.getKey(), value); } else if ("float".equals(valueType)) { float value = Float.parseFloat(valueStr); result.put(entry.getKey(), value); } else if ("[[D".equals(valueType)) { String valueStrTemp = entry.getValue(); try { //1.二位数组的行按照"/"来分割 String[] rowList = valueStrTemp.split("/"); int row = rowList.length; int col = rowList[0].split(",").length; double[][] value1 = new double[row][col]; for (int i = 0; i < rowList.length; i++) { //2.二位数组的列按照","来分割 String[] colList = rowList[i].split(","); for (int j = 0; j < colList.length; j++) { value1[i][j] = Double.parseDouble(colList[j]); } } //把从数据库的得到的参数的二维数组降为一维数组 //int len =0; double[] value = new double[row * col]; /*for (int j = 0; j <value1.length ; j++) { len+= value1.length; }*/ //value = new double[len]; int index = 0; for (int i = 0; i < value1.length; i++) { for (int j = 0; j < value1[i].length; j++) { value[index++] = value1[i][j]; } } result.put(entry.getKey(), value); } catch (Exception ex) { System.out.println("二维数组类型的setting格式不正确"); ex.printStackTrace(); } } else if ("decimalArray".equals(valueType)) { JSONArray valueArray = JSONArray.parseArray(entry.getValue()); double[] value = new double[valueArray.size()]; @@ -194,10 +152,26 @@ result.put(entry.getKey(), value); } else if ("decimal".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); //BigDecimal value = new BigDecimal(entry.getValue()); result.put(entry.getKey(), value); } } return result; } private IAILModel composeNewModelBean(StScheduleModelEntity model) { IAILModel newModelBean = new 
IAILModel(); newModelBean.setClassName(model.getClassName().trim()); newModelBean.setMethodName(model.getMethodName().trim()); //构造参数类型 Class<?>[] paramsArray = new Class[model.getPortLength() + 1]; for (int i = 0; i < model.getPortLength(); i++) { paramsArray[i] = double[][].class; } paramsArray[model.getPortLength()] = HashMap.class; newModelBean.setParamsArray(paramsArray); // // HashMap<String, Object> dataMap = new HashMap<>(); // newModelBean.setDataMap(dataMap); return newModelBean; } }