houzhongjian
2024-12-04 a82313d17b2b5d1c02e880122efc1b701c401dcf
提交 | 用户 | 时间
054fb9 1 package com.iailab.module.model.mdk.schedule.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.IAILMDK;
6 import com.iail.model.IAILModel;
51c1c2 7 import com.iailab.module.model.common.enums.CommonConstant;
054fb9 8 import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity;
bbc1ee 9 import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity;
10 import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity;
054fb9 11 import com.iailab.module.model.mcs.sche.service.StScheduleModelService;
bbc1ee 12 import com.iailab.module.model.mcs.sche.service.StScheduleModelSettingService;
13 import com.iailab.module.model.mcs.sche.service.StScheduleSchemeService;
054fb9 14 import com.iailab.module.model.mdk.common.enums.TypeA;
15 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
16 import com.iailab.module.model.mdk.sample.SampleConstructor;
17 import com.iailab.module.model.mdk.sample.dto.SampleData;
18 import com.iailab.module.model.mdk.schedule.ScheduleModelHandler;
19 import com.iailab.module.model.mdk.vo.ScheduleResultVO;
51c1c2 20 import com.iailab.module.model.mpk.common.utils.DllUtils;
054fb9 21 import lombok.extern.slf4j.Slf4j;
22 import org.springframework.beans.factory.annotation.Autowired;
23 import org.springframework.stereotype.Component;
24 import org.springframework.util.CollectionUtils;
25
26 import java.text.MessageFormat;
27 import java.util.*;
28
29 /**
30  * @author PanZhibao
31  * @Description
32  * @createTime 2024年09月05日
33  */
34 @Slf4j
35 @Component
36 public class ScheduleModelHandlerImpl implements ScheduleModelHandler {
37
38     @Autowired
bbc1ee 39     private StScheduleSchemeService stScheduleSchemeService;
40
41     @Autowired
054fb9 42     private StScheduleModelService stScheduleModelService;
43
44     @Autowired
bbc1ee 45     private StScheduleModelSettingService stScheduleModelSettingService;
054fb9 46
47     @Autowired
48     private SampleConstructor sampleConstructor;
49
50     @Override
bbc1ee 51     public ScheduleResultVO doSchedule(String schemeCode, Date scheduleTime) throws ModelInvokeException {
054fb9 52         ScheduleResultVO scheduleResult = new ScheduleResultVO();
bbc1ee 53         StScheduleSchemeEntity scheduleScheme = stScheduleSchemeService.getByCode(schemeCode);
b425df 54         StScheduleModelEntity scheduleModel = stScheduleModelService.get(scheduleScheme.getModelId());
bbc1ee 55         if (scheduleModel == null) {
054fb9 56             throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}",
bbc1ee 57                     ModelInvokeException.errorGetModelEntity, scheduleModel.getId()));
054fb9 58         }
bbc1ee 59         String modelId = scheduleModel.getId();
054fb9 60         try {
61             IAILModel newModelBean = new IAILModel();
62             //1.根据模型id构造模型输入样本
63             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime);
64             if (CollectionUtils.isEmpty(sampleDataList)) {
bbc1ee 65                 log.info("调度模型构造样本失败,schemeCode=" + schemeCode);
054fb9 66                 return null;
67             }
68
69             //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型
bbc1ee 70             String className = scheduleModel.getClassName() .trim();
71             String methodName = scheduleModel.getMethodName().trim();
054fb9 72             newModelBean.setClassName(className);
73             newModelBean.setMethodName(methodName);
74
75             Class<?>[] paramsArray = new Class[3];
76             paramsArray[0] = double[][].class;
77             paramsArray[1] = double[][].class;
78             paramsArray[2] = HashMap.class;
79             newModelBean.setParamsArray(paramsArray);
80
81             //3.拼接settings参数
82             HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId);
83
84             //4.构造param2Values参数结构
85             int count = sampleDataList.size();
86             Object[] param2Values = new Object[count + 1];
87             for (int i = 0; i < count; i++) {
88                 param2Values[i] = sampleDataList.get(i).getMatrix();
89             }
90             param2Values[count] = settings_predict;
91
92             //打印参数
c204b3 93             log.info("##############调度模型:scheduleScheme=" + scheduleScheme.getCode() + " ##########################");
054fb9 94             JSONObject jsonObjNewModelBean = new JSONObject();
95             jsonObjNewModelBean.put("newModelBean", newModelBean);
96             log.info(String.valueOf(jsonObjNewModelBean));
97             JSONObject jsonObjParam2Values = new JSONObject();
98             jsonObjParam2Values.put("param2Values", param2Values);
99             log.info(String.valueOf(jsonObjParam2Values));
100
c204b3 101             //运行模型
102             HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid());
51c1c2 103             if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
104                     !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
105                 throw new RuntimeException("模型结果异常:" + modelResult);
106             }
c204b3 107             modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
054fb9 108
109             //打印结果
110             JSONObject jsonObjResult = new JSONObject();
c204b3 111             jsonObjResult.put("result", modelResult);
054fb9 112             log.info(String.valueOf(jsonObjResult));
113
114             //5.返回调度结果
c204b3 115             scheduleResult.setResult(modelResult);
054fb9 116             scheduleResult.setModelId(modelId);
bbc1ee 117             scheduleResult.setSchemeId(scheduleScheme.getId());
054fb9 118             scheduleResult.setScheduleTime(scheduleTime);
119         } catch (Exception ex) {
120             log.error("IAILMDK.run()执行失败");
121             log.error(ex.getMessage());
122             log.error("调用发生异常,异常信息为:{}", ex);
123             ex.printStackTrace();
124         }
125         return scheduleResult;
126     }
127
128     /**
129      * 根据模型id获取参数map
130      *
131      * @param modelId
132      * @return
133      */
134     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
bbc1ee 135         List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId);
054fb9 136         if (CollectionUtils.isEmpty(list)) {
137             return null;
138         }
139         HashMap<String, Object> result = new HashMap<>();
bbc1ee 140         for (StScheduleModelSettingEntity entry : list) {
054fb9 141             String valueType = entry.getValuetype().trim();
142             String valueStr = entry.getValue().trim();
143             if ("int".equals(valueType)) {
144                 int value = Integer.parseInt(valueStr);
145                 result.put(entry.getKey(), value);
146             } else if ("double".equals(valueType)) {
147                 double value = Double.parseDouble(valueStr);
148                 result.put(entry.getKey(), value);
149             } else if ("string".equals(valueType)) {
150                 String value = valueStr;
151                 result.put(entry.getKey(), value);
152             } else if ("float".equals(valueType)) {
153                 float value = Float.parseFloat(valueStr);
154                 result.put(entry.getKey(), value);
155             } else if ("[[D".equals(valueType)) {
156                 String valueStrTemp = entry.getValue();
157                 try {
158                     //1.二位数组的行按照"/"来分割
159                     String[] rowList = valueStrTemp.split("/");
160                     int row = rowList.length;
161                     int col = rowList[0].split(",").length;
162                     double[][] value1 = new double[row][col];
163                     for (int i = 0; i < rowList.length; i++) {
164                         //2.二位数组的列按照","来分割
165                         String[] colList = rowList[i].split(",");
166                         for (int j = 0; j < colList.length; j++) {
167                             value1[i][j] = Double.parseDouble(colList[j]);
168                         }
169                     }
170                     //把从数据库的得到的参数的二维数组降为一维数组
171                     //int len =0;
172                     double[] value = new double[row * col];
173                     /*for (int j = 0; j <value1.length ; j++) {
174                         len+= value1.length;
175                     }*/
176                     //value = new double[len];
177                     int index = 0;
178                     for (int i = 0; i < value1.length; i++) {
179                         for (int j = 0; j < value1[i].length; j++) {
180                             value[index++] = value1[i][j];
181                         }
182                     }
183                     result.put(entry.getKey(), value);
184                 } catch (Exception ex) {
185                     System.out.println("二维数组类型的setting格式不正确");
186                     ex.printStackTrace();
187                 }
188             } else if ("decimalArray".equals(valueType)) {
189                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
190                 double[] value = new double[valueArray.size()];
191                 for (int i = 0; i < valueArray.size(); i++) {
192                     value[i] = Double.parseDouble(valueArray.get(i).toString());
193                 }
194                 result.put(entry.getKey(), value);
195             } else if ("decimal".equals(valueType)) {
196                 double value = Double.parseDouble(entry.getValue());
197                 //BigDecimal value = new BigDecimal(entry.getValue());
198                 result.put(entry.getKey(), value);
199             }
200         }
201         return result;
202     }
203 }