From 2e0e42583419225b5dd38e97594de82accd594ad Mon Sep 17 00:00:00 2001
From: dengzedong <dengzedong@email>
Date: 星期二, 31 十二月 2024 17:23:29 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/master'

---
 iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java |  142 ++++++++++++++++++++---------------------------
 1 files changed, 60 insertions(+), 82 deletions(-)

diff --git a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java
index 34ecd54..e6b062d 100644
--- a/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java
+++ b/iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mdk/schedule/impl/ScheduleModelHandlerImpl.java
@@ -1,9 +1,10 @@
 package com.iailab.module.model.mdk.schedule.impl;
 
+import com.alibaba.fastjson.JSON;
 import com.alibaba.fastjson.JSONArray;
 import com.alibaba.fastjson.JSONObject;
-import com.iail.IAILMDK;
 import com.iail.model.IAILModel;
+import com.iailab.module.model.common.enums.CommonConstant;
 import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity;
 import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity;
 import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity;
@@ -16,13 +17,17 @@
 import com.iailab.module.model.mdk.sample.dto.SampleData;
 import com.iailab.module.model.mdk.schedule.ScheduleModelHandler;
 import com.iailab.module.model.mdk.vo.ScheduleResultVO;
+import com.iailab.module.model.mpk.common.MdkConstant;
+import com.iailab.module.model.mpk.common.utils.DllUtils;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Component;
 import org.springframework.util.CollectionUtils;
 
 import java.text.MessageFormat;
-import java.util.*;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
 
 /**
  * @author PanZhibao
@@ -49,64 +54,58 @@
     public ScheduleResultVO doSchedule(String schemeCode, Date scheduleTime) throws ModelInvokeException {
         ScheduleResultVO scheduleResult = new ScheduleResultVO();
         StScheduleSchemeEntity scheduleScheme = stScheduleSchemeService.getByCode(schemeCode);
-        StScheduleModelEntity scheduleModel = stScheduleModelService.selectById(scheduleScheme.getModelId());
+        StScheduleModelEntity scheduleModel = stScheduleModelService.get(scheduleScheme.getModelId());
         if (scheduleModel == null) {
             throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}",
                     ModelInvokeException.errorGetModelEntity, scheduleModel.getId()));
         }
         String modelId = scheduleModel.getId();
         try {
-            IAILModel newModelBean = new IAILModel();
             //1.根据模型id构造模型输入样本
-            List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime);
+            long now = System.currentTimeMillis();
+            List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime,scheduleScheme.getName());
+            log.info("构造模型输入样本消耗时长:" + (System.currentTimeMillis() - now) / 1000 + "秒");
             if (CollectionUtils.isEmpty(sampleDataList)) {
                 log.info("调度模型构造样本失败,schemeCode=" + schemeCode);
                 return null;
             }
 
-            //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型
-            String className = scheduleModel.getClassName() .trim();
-            String methodName = scheduleModel.getMethodName().trim();
-            newModelBean.setClassName(className);
-            newModelBean.setMethodName(methodName);
-
-            Class<?>[] paramsArray = new Class[3];
-            paramsArray[0] = double[][].class;
-            paramsArray[1] = double[][].class;
-            paramsArray[2] = HashMap.class;
-            newModelBean.setParamsArray(paramsArray);
-
-            //3.拼接settings参数
-            HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId);
-
-            //4.构造param2Values参数结构
-            int count = sampleDataList.size();
-            Object[] param2Values = new Object[count + 1];
-            for (int i = 0; i < count; i++) {
+            IAILModel newModelBean = composeNewModelBean(scheduleModel);
+            HashMap<String, Object> settings = getScheduleSettingsByModelId(modelId);
+            if (settings == null) {
+                log.error("模型setting不存在,modelId=" + modelId);
+                return null;
+            }
+            // 校验setting必须有pyFile,否则可能导致程序崩溃
+            if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
+                log.error("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
+                return null;
+            }
+            int portLength = sampleDataList.size();
+            Object[] param2Values = new Object[portLength + 1];
+            for (int i = 0; i < portLength; i++) {
                 param2Values[i] = sampleDataList.get(i).getMatrix();
             }
-            param2Values[count] = settings_predict;
+            param2Values[portLength] = settings;
 
-            //打印参数
-            log.info("##############调度模型:modelId=" + modelId + " ##########################");
-            JSONObject jsonObjNewModelBean = new JSONObject();
-            jsonObjNewModelBean.put("newModelBean", newModelBean);
-            log.info(String.valueOf(jsonObjNewModelBean));
-            JSONObject jsonObjParam2Values = new JSONObject();
-            jsonObjParam2Values.put("param2Values", param2Values);
-            log.info(String.valueOf(jsonObjParam2Values));
-
+            log.info("#######################调度模型 " + scheduleModel.getModelName() + " ##########################");
+            log.info("参数: " + JSON.toJSONString(param2Values));
             //IAILMDK.run
-            HashMap<String, Object> result = IAILMDK.run(newModelBean, param2Values);
-
+            HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid());
+            if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
+                    !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
+                throw new RuntimeException("模型结果异常:" + modelResult);
+            }
+            String statusCode = modelResult.get(CommonConstant.MDK_STATUS_CODE).toString();
+            modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
             //打印结果
             JSONObject jsonObjResult = new JSONObject();
-            jsonObjResult.put("result", result);
+            jsonObjResult.put("result", modelResult);
             log.info(String.valueOf(jsonObjResult));
-            log.info("调度模型计算完成:modelId=" + modelId + result);
 
             //5.返回调度结果
-            scheduleResult.setResult(result);
+            scheduleResult.setResultCode(statusCode);
+            scheduleResult.setResult(modelResult);
             scheduleResult.setModelId(modelId);
             scheduleResult.setSchemeId(scheduleScheme.getId());
             scheduleResult.setScheduleTime(scheduleTime);
@@ -125,60 +124,23 @@
      * @param modelId
      * @return
      */
-    private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
+    private HashMap<String, Object> getScheduleSettingsByModelId(String modelId) {
         List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId);
         if (CollectionUtils.isEmpty(list)) {
             return null;
         }
         HashMap<String, Object> result = new HashMap<>();
         for (StScheduleModelSettingEntity entry : list) {
-            String valueType = entry.getValuetype().trim();
-            String valueStr = entry.getValue().trim();
+            String valueType = entry.getValuetype().trim(); //去除两端空格
             if ("int".equals(valueType)) {
-                int value = Integer.parseInt(valueStr);
+                int value = Integer.parseInt(entry.getValue());
                 result.put(entry.getKey(), value);
             } else if ("double".equals(valueType)) {
-                double value = Double.parseDouble(valueStr);
+                double value = Double.parseDouble(entry.getValue());
                 result.put(entry.getKey(), value);
             } else if ("string".equals(valueType)) {
-                String value = valueStr;
+                String value = entry.getValue();
                 result.put(entry.getKey(), value);
-            } else if ("float".equals(valueType)) {
-                float value = Float.parseFloat(valueStr);
-                result.put(entry.getKey(), value);
-            } else if ("[[D".equals(valueType)) {
-                String valueStrTemp = entry.getValue();
-                try {
-                    //1.二位数组的行按照"/"来分割
-                    String[] rowList = valueStrTemp.split("/");
-                    int row = rowList.length;
-                    int col = rowList[0].split(",").length;
-                    double[][] value1 = new double[row][col];
-                    for (int i = 0; i < rowList.length; i++) {
-                        //2.二位数组的列按照","来分割
-                        String[] colList = rowList[i].split(",");
-                        for (int j = 0; j < colList.length; j++) {
-                            value1[i][j] = Double.parseDouble(colList[j]);
-                        }
-                    }
-                    //把从数据库的得到的参数的二维数组降为一维数组
-                    //int len =0;
-                    double[] value = new double[row * col];
-                    /*for (int j = 0; j <value1.length ; j++) {
-                        len+= value1.length;
-                    }*/
-                    //value = new double[len];
-                    int index = 0;
-                    for (int i = 0; i < value1.length; i++) {
-                        for (int j = 0; j < value1[i].length; j++) {
-                            value[index++] = value1[i][j];
-                        }
-                    }
-                    result.put(entry.getKey(), value);
-                } catch (Exception ex) {
-                    System.out.println("二维数组类型的setting格式不正确");
-                    ex.printStackTrace();
-                }
             } else if ("decimalArray".equals(valueType)) {
                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
                 double[] value = new double[valueArray.size()];
@@ -188,10 +150,26 @@
                 result.put(entry.getKey(), value);
             } else if ("decimal".equals(valueType)) {
                 double value = Double.parseDouble(entry.getValue());
-                //BigDecimal value = new BigDecimal(entry.getValue());
                 result.put(entry.getKey(), value);
             }
         }
         return result;
     }
+
+    private IAILModel composeNewModelBean(StScheduleModelEntity model) {
+        IAILModel newModelBean = new IAILModel();
+        newModelBean.setClassName(model.getClassName().trim());
+        newModelBean.setMethodName(model.getMethodName().trim());
+        //构造参数类型
+        Class<?>[] paramsArray = new Class[model.getPortLength() + 1];
+        for (int i = 0; i < model.getPortLength(); i++) {
+            paramsArray[i] = double[][].class;
+        }
+        paramsArray[model.getPortLength()] = HashMap.class;
+        newModelBean.setParamsArray(paramsArray);
+        //
+//        HashMap<String, Object> dataMap = new HashMap<>();
+//        newModelBean.setDataMap(dataMap);
+        return newModelBean;
+    }
 }
\ No newline at end of file

--
Gitblit v1.9.3