From 373ab16760c6a6eeab92d0b92b8e67e7a1f7c398 Mon Sep 17 00:00:00 2001 From: 潘志宝 <979469083@qq.com> Date: 星期五, 22 十一月 2024 17:19:40 +0800 Subject: [PATCH] 模型输出列向量 --- iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java | 51 ++++++++--------- iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java | 2 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java | 32 ++++++++++ iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java | 2 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java | 25 ++++++- iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java | 16 +---- iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java | 18 ++++- 7 files changed, 94 insertions(+), 52 deletions(-) diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java new file mode 100644 index 0000000..55aa358 --- /dev/null +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java @@ -0,0 +1,32 @@ +package com.iailab.module.model.common.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * @author PanZhibao + * @Description + * @createTime 2024年11月22日 + */ +@Getter +@AllArgsConstructor +public enum OutResultType { + D1(1, "一维数组"), + D2(2, "二维数组"); + + private Integer code; + private String desc; + + public static OutResultType getEumByCode(Integer code) { + if (code == null) { + return null; + } + + for (OutResultType statusEnum : OutResultType.values()) { + if (statusEnum.getCode().equals(code)) { + return statusEnum; + } + } + return null; + } +} \ No newline at end of file diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java index ea67d95..1621867 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java @@ -20,7 +20,7 @@ MmItemOutputEntity getOutPutById(String outputid); - List<com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity> getByItemid(String itemid); + List<MmItemOutputEntity> getByItemid(String itemId); MmItemOutputEntity getByItemid(String itemid, String resultstr); diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java index db336ce..b65f7ac 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java @@ -24,7 +24,7 @@ List<MmPredictModelEntity> getNoSettingmapPredictModel(Map<String, Object> params); - List<MmPredictModelEntity> getActiveModelByItemId(String itemId); + MmPredictModelEntity getActiveModelByItemId(String itemId); void clearCache(); } diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java index a7e7c44..a6f793e 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java @@ -8,6 +8,7 @@ import com.iailab.module.model.mcs.pre.service.MmItemOutputService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Map; @@ -24,19 +25,21 @@ private MmItemOutputDao mmItemOutputDao; private static Map<String, MmItemOutputEntity> outputMap = new ConcurrentHashMap<>(); + + private static Map<String, List<MmItemOutputEntity>> itemMap = new ConcurrentHashMap<>(); @Override public void saveMmItemOutput(List<MmItemOutputEntity> mmItemOutput) { mmItemOutputDao.insert(mmItemOutput); // 清空缓存 - outputMap.clear(); + clearCache(); } @Override public void update(MmItemOutputEntity mmItemOutput) { mmItemOutputDao.updateById(mmItemOutput); // 清空缓存 - outputMap.clear(); + clearCache(); } public void deleteBatch(String[] itemIds) { @@ -44,15 +47,27 @@ queryWrapper.in("itemid", itemIds); mmItemOutputDao.delete(queryWrapper); // 清空缓存 + clearCache(); + } + + private void clearCache() { outputMap.clear(); + itemMap.clear(); } @Override - public List<MmItemOutputEntity> getByItemid(String itemid) { + public List<MmItemOutputEntity> getByItemid(String itemId) { + if (itemMap.containsKey(itemId)) { + return itemMap.get(itemId); + } QueryWrapper<MmItemOutputEntity> queryWrapper = new QueryWrapper<>(); - queryWrapper.eq("itemid", itemid).orderByAsc("outputorder"); + queryWrapper.eq("itemid", itemId).orderByAsc("outputorder"); List<MmItemOutputEntity> list = mmItemOutputDao.selectList(queryWrapper); - return list; + if (CollectionUtils.isEmpty(list)) { + return null; + } + itemMap.put(itemId, list); + return itemMap.get(itemId); } @Override diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java index 892bae3..12221fc 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java @@ -2,8 +2,6 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; -import com.iailab.framework.common.service.impl.BaseServiceImpl; -import com.iailab.module.model.mcs.pre.dao.MmPredictMergeItemDao; import com.iailab.module.model.mcs.pre.dao.MmPredictModelDao; import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; import com.iailab.module.model.mcs.pre.service.MmPredictModelService; @@ -12,7 +10,6 @@ import org.springframework.util.CollectionUtils; import java.math.BigDecimal; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; @@ -29,6 +26,8 @@ private MmPredictModelDao mmPredictModelDao; private static Map<String, MmPredictModelEntity> modelEntityMap = new ConcurrentHashMap<>(); + + private static Map<String, MmPredictModelEntity> activeModelMap = new ConcurrentHashMap<>(); @Override public void savePredictModel(MmPredictModelEntity predictModel) { @@ -57,6 +56,7 @@ @Override public void clearCache() { modelEntityMap.clear(); + activeModelMap.clear(); } @Override @@ -99,7 +99,15 @@ } @Override - public List<MmPredictModelEntity> getActiveModelByItemId(String itemId) { - return mmPredictModelDao.getActiveModelByItemId(itemId); + public MmPredictModelEntity getActiveModelByItemId(String itemId) { + if (activeModelMap.containsKey(itemId)) { + return activeModelMap.get(itemId); + } + List<MmPredictModelEntity> list = mmPredictModelDao.getActiveModelByItemId(itemId); + if (CollectionUtils.isEmpty(list)) { + return null; + } + activeModelMap.put(itemId, list.get(0)); + return activeModelMap.get(itemId); } } diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java index 93e5084..855794b 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java @@ -1,9 +1,7 @@ package com.iailab.module.model.mdk.predict.impl; -import com.iailab.framework.common.util.collection.CollectionUtils; import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; import com.iailab.module.model.mcs.pre.service.MmPredictModelService; -import com.iailab.module.model.mdk.common.enums.ItemPredictStatus; import com.iailab.module.model.mdk.common.exceptions.ItemInvokeException; import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; import com.iailab.module.model.mdk.predict.PredictItemHandler; @@ -13,12 +11,8 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; -import java.sql.Timestamp; import java.text.MessageFormat; -import java.util.ArrayList; -import java.util.Calendar; import java.util.Date; -import java.util.List; /** * @author PanZhibao @@ -43,22 +37,18 @@ * @throws ItemInvokeException */ @Override - public PredictResultVO predict(Date predictTime, ItemVO predictItemDto) throws ItemInvokeException{ + public PredictResultVO predict(Date predictTime, ItemVO predictItemDto) throws ItemInvokeException { PredictResultVO predictResult = new PredictResultVO(); String itemId = predictItemDto.getId(); predictResult.setPredictId(itemId); try { - // 获取预测项模型 - List<MmPredictModelEntity> predictModelList = mmPredictModelService.getActiveModelByItemId(itemId); - if (CollectionUtils.isAnyEmpty(predictModelList)) { + MmPredictModelEntity predictModel = mmPredictModelService.getActiveModelByItemId(itemId); + if (predictModel == null) { throw new ModelInvokeException(MessageFormat.format("{0},itemId={1}", ModelInvokeException.errorGetModelEntity, itemId)); } - MmPredictModelEntity predictModel = predictModelList.get(0); predictResult = predictModelHandler.predictByModel(predictTime, predictModel); } catch (Exception ex) { - ex.printStackTrace(); - //预测项预测失败的状态 throw new ItemInvokeException(MessageFormat.format("{0},itemId={1}", ItemInvokeException.errorItemFailed, itemId)); } diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java index 3bdf3d4..9d5359b 100644 --- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java +++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java @@ -4,6 +4,7 @@ import com.alibaba.fastjson.JSONObject; import com.iail.model.IAILModel; import com.iailab.module.model.common.enums.CommonConstant; +import com.iailab.module.model.common.enums.OutResultType; import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity; import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity; import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity; @@ -55,12 +56,11 @@ throw new ModelInvokeException("modelEntity is null"); } String modelId = predictModel.getId(); - try { List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime); String modelPath = predictModel.getModelpath(); if (modelPath == null) { - System.out.println("模型路径不存在,modelId=" + modelId); + log.info("模型路径不存在,modelId=" + modelId); return null; } IAILModel newModelBean = composeNewModelBean(predictModel); @@ -93,36 +93,33 @@ } modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT); //打印结果 + log.info("模型计算完成:modelId=" + modelId + modelResult); JSONObject jsonObjResult = new JSONObject(); jsonObjResult.put("result", modelResult); log.info(String.valueOf(jsonObjResult)); - List<MmItemOutputEntity> ItemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid()); - log.info("模型计算完成:modelId=" + modelId + modelResult); - - Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(ItemOutputList.size()); - - for (MmItemOutputEntity outputEntity : ItemOutputList) { - String resultStr = outputEntity.getResultstr(); - if (modelResult.containsKey(resultStr)) { - if (outputEntity.getResultType() == 1) { - // 一维数组 - Double[] temp = (Double[]) modelResult.get(resultStr); - double[] temp1 = new double[temp.length]; - for (int i = 0; i < temp.length; i++) { - temp1[i] = temp[i].doubleValue(); + List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid()); + Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(itemOutputList.size()); + for (MmItemOutputEntity output : itemOutputList) { + if (!modelResult.containsKey(output.getResultstr())) { + continue; + } + OutResultType outResultType = OutResultType.getEumByCode(output.getResultType()); + switch (outResultType) { + case D1: + double[] temp1 = (double[]) modelResult.get(output.getResultstr()); + predictMatrixs.put(output, temp1); + break; + case D2: + double[][] temp2 = (double[][]) modelResult.get(output.getResultstr()); + double[] tempColumn = new double[temp2.length]; + for (int i = 0; i < tempColumn.length; i++) { + tempColumn[i] = temp2[i][output.getResultIndex()]; } - predictMatrixs.put(outputEntity, temp1); - } else if (outputEntity.getResultType() == 2) { - // 二维数组 - Double[][] temp = (Double[][]) modelResult.get(resultStr); - Double[] temp2 = temp[outputEntity.getResultIndex()]; - double[] temp1 = new double[temp2.length]; - for (int i = 0; i < temp2.length; i++) { - temp1[i] = temp2[i].doubleValue(); - } - predictMatrixs.put(outputEntity, temp1); - } + predictMatrixs.put(output, tempColumn); + break; + default: + break; } } result.setPredictMatrixs(predictMatrixs); -- Gitblit v1.9.3