From 373ab16760c6a6eeab92d0b92b8e67e7a1f7c398 Mon Sep 17 00:00:00 2001
From: 潘志宝 <979469083@qq.com>
Date: 星期五, 22 十一月 2024 17:19:40 +0800
Subject: [PATCH] 模型输出列向量

---
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java       |   51 ++++++++---------
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java          |    2 
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java                     |   32 ++++++++++
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java            |    2 
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java   |   25 ++++++-
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java  |   16 +----
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java |   18 ++++-
 7 files changed, 94 insertions(+), 52 deletions(-)

diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java
new file mode 100644
index 0000000..55aa358
--- /dev/null
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java
@@ -0,0 +1,32 @@
+package com.iailab.module.model.common.enums;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+/**
+ * @author PanZhibao
+ * @Description
+ * @createTime 2024年11月22日
+ */
+@Getter
+@AllArgsConstructor
+public enum OutResultType {
+    D1(1, "一维数组"),
+    D2(2, "二维数组");
+
+    private Integer code;
+    private String desc;
+
+    public static OutResultType getEumByCode(Integer code) {
+        if (code == null) {
+            return null;
+        }
+
+        for (OutResultType statusEnum : OutResultType.values()) {
+            if (statusEnum.getCode().equals(code)) {
+                return statusEnum;
+            }
+        }
+        return null;
+    }
+}
\ No newline at end of file
diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java
index ea67d95..1621867 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java
@@ -20,7 +20,7 @@
 
     MmItemOutputEntity getOutPutById(String outputid);
 
-    List<com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity> getByItemid(String itemid);
+    List<MmItemOutputEntity> getByItemid(String itemId);
 
     MmItemOutputEntity getByItemid(String itemid, String resultstr);
 
diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java
index db336ce..b65f7ac 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java
@@ -24,7 +24,7 @@
 
     List<MmPredictModelEntity> getNoSettingmapPredictModel(Map<String, Object> params);
 
-    List<MmPredictModelEntity> getActiveModelByItemId(String itemId);
+    MmPredictModelEntity getActiveModelByItemId(String itemId);
 
     void clearCache();
 }
diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java
index a7e7c44..a6f793e 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java
@@ -8,6 +8,7 @@
 import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
+import org.springframework.util.CollectionUtils;
 
 import java.util.List;
 import java.util.Map;
@@ -24,19 +25,21 @@
     private MmItemOutputDao mmItemOutputDao;
 
     private static Map<String, MmItemOutputEntity> outputMap = new ConcurrentHashMap<>();
+
+    private static Map<String, List<MmItemOutputEntity>> itemMap = new ConcurrentHashMap<>();
     
     @Override
     public void saveMmItemOutput(List<MmItemOutputEntity> mmItemOutput) {
         mmItemOutputDao.insert(mmItemOutput);
         // 清空缓存
-        outputMap.clear();
+        clearCache();
     }
 
     @Override
     public void update(MmItemOutputEntity mmItemOutput) {
         mmItemOutputDao.updateById(mmItemOutput);
         // 清空缓存
-        outputMap.clear();
+        clearCache();
     }
 
     public void deleteBatch(String[] itemIds) {
@@ -44,15 +47,27 @@
         queryWrapper.in("itemid", itemIds);
         mmItemOutputDao.delete(queryWrapper);
         // 清空缓存
+        clearCache();
+    }
+
+    private void clearCache() {
         outputMap.clear();
+        itemMap.clear();
     }
 
     @Override
-    public List<MmItemOutputEntity> getByItemid(String itemid) {
+    public List<MmItemOutputEntity> getByItemid(String itemId) {
+        if (itemMap.containsKey(itemId)) {
+            return itemMap.get(itemId);
+        }
         QueryWrapper<MmItemOutputEntity> queryWrapper = new QueryWrapper<>();
-        queryWrapper.eq("itemid", itemid).orderByAsc("outputorder");
+        queryWrapper.eq("itemid", itemId).orderByAsc("outputorder");
         List<MmItemOutputEntity> list = mmItemOutputDao.selectList(queryWrapper);
-        return list;
+        if (CollectionUtils.isEmpty(list)) {
+            return null;
+        }
+        itemMap.put(itemId, list);
+        return itemMap.get(itemId);
     }
 
     @Override
diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java
index 892bae3..12221fc 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java
@@ -2,8 +2,6 @@
 
 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
-import com.iailab.framework.common.service.impl.BaseServiceImpl;
-import com.iailab.module.model.mcs.pre.dao.MmPredictMergeItemDao;
 import com.iailab.module.model.mcs.pre.dao.MmPredictModelDao;
 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
 import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
@@ -12,7 +10,6 @@
 import org.springframework.util.CollectionUtils;
 
 import java.math.BigDecimal;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
@@ -29,6 +26,8 @@
     private MmPredictModelDao mmPredictModelDao;
 
     private static Map<String, MmPredictModelEntity> modelEntityMap = new ConcurrentHashMap<>();
+
+    private static Map<String, MmPredictModelEntity> activeModelMap = new ConcurrentHashMap<>();
 
     @Override
     public void savePredictModel(MmPredictModelEntity predictModel) {
@@ -57,6 +56,7 @@
     @Override
     public void clearCache() {
         modelEntityMap.clear();
+        activeModelMap.clear();
     }
 
     @Override
@@ -99,7 +99,15 @@
     }
 
     @Override
-    public List<MmPredictModelEntity> getActiveModelByItemId(String itemId) {
-        return mmPredictModelDao.getActiveModelByItemId(itemId);
+    public MmPredictModelEntity getActiveModelByItemId(String itemId) {
+        if (activeModelMap.containsKey(itemId)) {
+            return activeModelMap.get(itemId);
+        }
+        List<MmPredictModelEntity> list = mmPredictModelDao.getActiveModelByItemId(itemId);
+        if (CollectionUtils.isEmpty(list)) {
+            return null;
+        }
+        activeModelMap.put(itemId, list.get(0));
+        return activeModelMap.get(itemId);
     }
 }
diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java
index 93e5084..855794b 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java
@@ -1,9 +1,7 @@
 package com.iailab.module.model.mdk.predict.impl;
 
-import com.iailab.framework.common.util.collection.CollectionUtils;
 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
 import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
-import com.iailab.module.model.mdk.common.enums.ItemPredictStatus;
 import com.iailab.module.model.mdk.common.exceptions.ItemInvokeException;
 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
 import com.iailab.module.model.mdk.predict.PredictItemHandler;
@@ -13,12 +11,8 @@
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Component;
 
-import java.sql.Timestamp;
 import java.text.MessageFormat;
-import java.util.ArrayList;
-import java.util.Calendar;
 import java.util.Date;
-import java.util.List;
 
 /**
  * @author PanZhibao
@@ -43,22 +37,18 @@
      * @throws ItemInvokeException
      */
     @Override
-    public PredictResultVO predict(Date predictTime, ItemVO predictItemDto) throws ItemInvokeException{
+    public PredictResultVO predict(Date predictTime, ItemVO predictItemDto) throws ItemInvokeException {
         PredictResultVO predictResult = new PredictResultVO();
         String itemId = predictItemDto.getId();
         predictResult.setPredictId(itemId);
         try {
-            // 获取预测项模型
-            List<MmPredictModelEntity> predictModelList = mmPredictModelService.getActiveModelByItemId(itemId);
-            if (CollectionUtils.isAnyEmpty(predictModelList)) {
+            MmPredictModelEntity predictModel = mmPredictModelService.getActiveModelByItemId(itemId);
+            if (predictModel == null) {
                 throw new ModelInvokeException(MessageFormat.format("{0},itemId={1}",
                         ModelInvokeException.errorGetModelEntity, itemId));
             }
-            MmPredictModelEntity predictModel = predictModelList.get(0);
             predictResult = predictModelHandler.predictByModel(predictTime, predictModel);
         } catch (Exception ex) {
-            ex.printStackTrace();
-            //预测项预测失败的状态
             throw new ItemInvokeException(MessageFormat.format("{0},itemId={1}",
                     ItemInvokeException.errorItemFailed, itemId));
         }
diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java
index 3bdf3d4..9d5359b 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java
@@ -4,6 +4,7 @@
 import com.alibaba.fastjson.JSONObject;
 import com.iail.model.IAILModel;
 import com.iailab.module.model.common.enums.CommonConstant;
+import com.iailab.module.model.common.enums.OutResultType;
 import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
@@ -55,12 +56,11 @@
             throw new ModelInvokeException("modelEntity is null");
         }
         String modelId = predictModel.getId();
-
         try {
             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
             String modelPath = predictModel.getModelpath();
             if (modelPath == null) {
-                System.out.println("模型路径不存在,modelId=" + modelId);
+                log.info("模型路径不存在,modelId=" + modelId);
                 return null;
             }
             IAILModel newModelBean = composeNewModelBean(predictModel);
@@ -93,36 +93,33 @@
             }
             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
             //打印结果
+            log.info("模型计算完成:modelId=" + modelId + modelResult);
             JSONObject jsonObjResult = new JSONObject();
             jsonObjResult.put("result", modelResult);
             log.info(String.valueOf(jsonObjResult));
 
-            List<MmItemOutputEntity> ItemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
-            log.info("模型计算完成:modelId=" + modelId + modelResult);
-
-            Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(ItemOutputList.size());
-
-            for (MmItemOutputEntity outputEntity : ItemOutputList) {
-                String resultStr = outputEntity.getResultstr();
-                if (modelResult.containsKey(resultStr)) {
-                    if (outputEntity.getResultType() == 1) {
-                        // 一维数组
-                        Double[] temp = (Double[]) modelResult.get(resultStr);
-                        double[] temp1 = new double[temp.length];
-                        for (int i = 0; i < temp.length; i++) {
-                            temp1[i] = temp[i].doubleValue();
+            List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
+            Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(itemOutputList.size());
+            for (MmItemOutputEntity output : itemOutputList) {
+                if (!modelResult.containsKey(output.getResultstr())) {
+                    continue;
+                }
+                OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
+                switch (outResultType) {
+                    case D1:
+                        double[] temp1 = (double[]) modelResult.get(output.getResultstr());
+                        predictMatrixs.put(output, temp1);
+                        break;
+                    case D2:
+                        double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
+                        double[] tempColumn = new double[temp2.length];
+                        for (int i = 0; i < tempColumn.length; i++) {
+                            tempColumn[i] = temp2[i][output.getResultIndex()];
                         }
-                        predictMatrixs.put(outputEntity, temp1);
-                    } else if (outputEntity.getResultType() == 2) {
-                        // 二维数组
-                        Double[][] temp = (Double[][]) modelResult.get(resultStr);
-                        Double[] temp2 = temp[outputEntity.getResultIndex()];
-                        double[] temp1 = new double[temp2.length];
-                        for (int i = 0; i < temp2.length; i++) {
-                            temp1[i] = temp2[i].doubleValue();
-                        }
-                        predictMatrixs.put(outputEntity, temp1);
-                    }
+                        predictMatrixs.put(output, tempColumn);
+                        break;
+                    default:
+                        break;
                 }
             }
             result.setPredictMatrixs(predictMatrixs);

--
Gitblit v1.9.3