潘志宝
2024-11-22 373ab16760c6a6eeab92d0b92b8e67e7a1f7c398
模型输出列向量
已修改6个文件
已添加1个文件
146 ■■■■■ 文件已修改
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java 32 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java 25 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java 18 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java 51 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/common/enums/OutResultType.java
对比新文件
@@ -0,0 +1,32 @@
package com.iailab.module.model.common.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
 * @author PanZhibao
 * @Description
 * @createTime 2024年11月22日
 */
@Getter
@AllArgsConstructor
public enum OutResultType {
    D1(1, "一维数组"),
    D2(2, "二维数组");
    private Integer code;
    private String desc;
    public static OutResultType getEumByCode(Integer code) {
        if (code == null) {
            return null;
        }
        for (OutResultType statusEnum : OutResultType.values()) {
            if (statusEnum.getCode().equals(code)) {
                return statusEnum;
            }
        }
        return null;
    }
}
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmItemOutputService.java
@@ -20,7 +20,7 @@
    MmItemOutputEntity getOutPutById(String outputid);
    List<com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity> getByItemid(String itemid);
    List<MmItemOutputEntity> getByItemid(String itemId);
    MmItemOutputEntity getByItemid(String itemid, String resultstr);
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/MmPredictModelService.java
@@ -24,7 +24,7 @@
    List<MmPredictModelEntity> getNoSettingmapPredictModel(Map<String, Object> params);
    List<MmPredictModelEntity> getActiveModelByItemId(String itemId);
    MmPredictModelEntity getActiveModelByItemId(String itemId);
    void clearCache();
}
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmItemOutputServiceImpl.java
@@ -8,6 +8,7 @@
import com.iailab.module.model.mcs.pre.service.MmItemOutputService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
@@ -24,19 +25,21 @@
    private MmItemOutputDao mmItemOutputDao;
    private static Map<String, MmItemOutputEntity> outputMap = new ConcurrentHashMap<>();
    private static Map<String, List<MmItemOutputEntity>> itemMap = new ConcurrentHashMap<>();
    
    @Override
    public void saveMmItemOutput(List<MmItemOutputEntity> mmItemOutput) {
        mmItemOutputDao.insert(mmItemOutput);
        // 清空缓存
        outputMap.clear();
        clearCache();
    }
    @Override
    public void update(MmItemOutputEntity mmItemOutput) {
        mmItemOutputDao.updateById(mmItemOutput);
        // 清空缓存
        outputMap.clear();
        clearCache();
    }
    public void deleteBatch(String[] itemIds) {
@@ -44,15 +47,27 @@
        queryWrapper.in("itemid", itemIds);
        mmItemOutputDao.delete(queryWrapper);
        // 清空缓存
        clearCache();
    }
    private void clearCache() {
        outputMap.clear();
        itemMap.clear();
    }
    @Override
    public List<MmItemOutputEntity> getByItemid(String itemid) {
    public List<MmItemOutputEntity> getByItemid(String itemId) {
        if (itemMap.containsKey(itemId)) {
            return itemMap.get(itemId);
        }
        QueryWrapper<MmItemOutputEntity> queryWrapper = new QueryWrapper<>();
        queryWrapper.eq("itemid", itemid).orderByAsc("outputorder");
        queryWrapper.eq("itemid", itemId).orderByAsc("outputorder");
        List<MmItemOutputEntity> list = mmItemOutputDao.selectList(queryWrapper);
        return list;
        if (CollectionUtils.isEmpty(list)) {
            return null;
        }
        itemMap.put(itemId, list);
        return itemMap.get(itemId);
    }
    @Override
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mcs/pre/service/impl/MmPredictModelServiceImpl.java
@@ -2,8 +2,6 @@
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.iailab.framework.common.service.impl.BaseServiceImpl;
import com.iailab.module.model.mcs.pre.dao.MmPredictMergeItemDao;
import com.iailab.module.model.mcs.pre.dao.MmPredictModelDao;
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
@@ -12,7 +10,6 @@
import org.springframework.util.CollectionUtils;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@@ -29,6 +26,8 @@
    private MmPredictModelDao mmPredictModelDao;
    private static Map<String, MmPredictModelEntity> modelEntityMap = new ConcurrentHashMap<>();
    private static Map<String, MmPredictModelEntity> activeModelMap = new ConcurrentHashMap<>();
    @Override
    public void savePredictModel(MmPredictModelEntity predictModel) {
@@ -57,6 +56,7 @@
    @Override
    public void clearCache() {
        modelEntityMap.clear();
        activeModelMap.clear();
    }
    @Override
@@ -99,7 +99,15 @@
    }
    @Override
    public List<MmPredictModelEntity> getActiveModelByItemId(String itemId) {
        return mmPredictModelDao.getActiveModelByItemId(itemId);
    public MmPredictModelEntity getActiveModelByItemId(String itemId) {
        if (activeModelMap.containsKey(itemId)) {
            return activeModelMap.get(itemId);
        }
        List<MmPredictModelEntity> list = mmPredictModelDao.getActiveModelByItemId(itemId);
        if (CollectionUtils.isEmpty(list)) {
            return null;
        }
        activeModelMap.put(itemId, list.get(0));
        return activeModelMap.get(itemId);
    }
}
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictItemNormalHandlerImpl.java
@@ -1,9 +1,7 @@
package com.iailab.module.model.mdk.predict.impl;
import com.iailab.framework.common.util.collection.CollectionUtils;
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
import com.iailab.module.model.mdk.common.enums.ItemPredictStatus;
import com.iailab.module.model.mdk.common.exceptions.ItemInvokeException;
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
import com.iailab.module.model.mdk.predict.PredictItemHandler;
@@ -13,12 +11,8 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.sql.Timestamp;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
/**
 * @author PanZhibao
@@ -43,22 +37,18 @@
     * @throws ItemInvokeException
     */
    @Override
    public PredictResultVO predict(Date predictTime, ItemVO predictItemDto) throws ItemInvokeException{
    public PredictResultVO predict(Date predictTime, ItemVO predictItemDto) throws ItemInvokeException {
        PredictResultVO predictResult = new PredictResultVO();
        String itemId = predictItemDto.getId();
        predictResult.setPredictId(itemId);
        try {
            // 获取预测项模型
            List<MmPredictModelEntity> predictModelList = mmPredictModelService.getActiveModelByItemId(itemId);
            if (CollectionUtils.isAnyEmpty(predictModelList)) {
            MmPredictModelEntity predictModel = mmPredictModelService.getActiveModelByItemId(itemId);
            if (predictModel == null) {
                throw new ModelInvokeException(MessageFormat.format("{0},itemId={1}",
                        ModelInvokeException.errorGetModelEntity, itemId));
            }
            MmPredictModelEntity predictModel = predictModelList.get(0);
            predictResult = predictModelHandler.predictByModel(predictTime, predictModel);
        } catch (Exception ex) {
            ex.printStackTrace();
            //预测项预测失败的状态
            throw new ItemInvokeException(MessageFormat.format("{0},itemId={1}",
                    ItemInvokeException.errorItemFailed, itemId));
        }
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/predict/impl/PredictModelHandlerImpl.java
@@ -4,6 +4,7 @@
import com.alibaba.fastjson.JSONObject;
import com.iail.model.IAILModel;
import com.iailab.module.model.common.enums.CommonConstant;
import com.iailab.module.model.common.enums.OutResultType;
import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
@@ -55,12 +56,11 @@
            throw new ModelInvokeException("modelEntity is null");
        }
        String modelId = predictModel.getId();
        try {
            List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Predict.name(), modelId, predictTime);
            String modelPath = predictModel.getModelpath();
            if (modelPath == null) {
                System.out.println("模型路径不存在,modelId=" + modelId);
                log.info("模型路径不存在,modelId=" + modelId);
                return null;
            }
            IAILModel newModelBean = composeNewModelBean(predictModel);
@@ -93,36 +93,33 @@
            }
            modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
            //打印结果
            log.info("模型计算完成:modelId=" + modelId + modelResult);
            JSONObject jsonObjResult = new JSONObject();
            jsonObjResult.put("result", modelResult);
            log.info(String.valueOf(jsonObjResult));
            List<MmItemOutputEntity> ItemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
            log.info("模型计算完成:modelId=" + modelId + modelResult);
            Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(ItemOutputList.size());
            for (MmItemOutputEntity outputEntity : ItemOutputList) {
                String resultStr = outputEntity.getResultstr();
                if (modelResult.containsKey(resultStr)) {
                    if (outputEntity.getResultType() == 1) {
                        // 一维数组
                        Double[] temp = (Double[]) modelResult.get(resultStr);
                        double[] temp1 = new double[temp.length];
                        for (int i = 0; i < temp.length; i++) {
                            temp1[i] = temp[i].doubleValue();
            List<MmItemOutputEntity> itemOutputList = mmItemOutputService.getByItemid(predictModel.getItemid());
            Map<MmItemOutputEntity, double[]> predictMatrixs = new HashMap<>(itemOutputList.size());
            for (MmItemOutputEntity output : itemOutputList) {
                if (!modelResult.containsKey(output.getResultstr())) {
                    continue;
                }
                OutResultType outResultType = OutResultType.getEumByCode(output.getResultType());
                switch (outResultType) {
                    case D1:
                        double[] temp1 = (double[]) modelResult.get(output.getResultstr());
                        predictMatrixs.put(output, temp1);
                        break;
                    case D2:
                        double[][] temp2 = (double[][]) modelResult.get(output.getResultstr());
                        double[] tempColumn = new double[temp2.length];
                        for (int i = 0; i < tempColumn.length; i++) {
                            tempColumn[i] = temp2[i][output.getResultIndex()];
                        }
                        predictMatrixs.put(outputEntity, temp1);
                    } else if (outputEntity.getResultType() == 2) {
                        // 二维数组
                        Double[][] temp = (Double[][]) modelResult.get(resultStr);
                        Double[] temp2 = temp[outputEntity.getResultIndex()];
                        double[] temp1 = new double[temp2.length];
                        for (int i = 0; i < temp2.length; i++) {
                            temp1[i] = temp2[i].doubleValue();
                        }
                        predictMatrixs.put(outputEntity, temp1);
                    }
                        predictMatrixs.put(output, tempColumn);
                        break;
                    default:
                        break;
                }
            }
            result.setPredictMatrixs(predictMatrixs);