潘志宝
2024-09-06 c06f48bded461209f117167fbf89ed57a3f37ef4
提交 | 用户 | 时间
054fb9 1 package com.iailab.module.model.mdk.schedule.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.IAILMDK;
6 import com.iail.model.IAILModel;
7 import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity;
8 import com.iailab.module.model.mcs.sche.entity.StScheduleParamSettingEntity;
9 import com.iailab.module.model.mcs.sche.service.StScheduleModelService;
10 import com.iailab.module.model.mcs.sche.service.StScheduleParamSettingService;
11 import com.iailab.module.model.mcs.sche.service.StScheduleService;
12 import com.iailab.module.model.mdk.common.enums.TypeA;
13 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
14 import com.iailab.module.model.mdk.sample.SampleConstructor;
15 import com.iailab.module.model.mdk.sample.dto.SampleData;
16 import com.iailab.module.model.mdk.schedule.ScheduleModelHandler;
17 import com.iailab.module.model.mdk.vo.ScheduleResultVO;
18 import lombok.extern.slf4j.Slf4j;
19 import org.springframework.beans.factory.annotation.Autowired;
20 import org.springframework.stereotype.Component;
21 import org.springframework.util.CollectionUtils;
22
23 import java.text.MessageFormat;
24 import java.util.*;
25
26 /**
27  * @author PanZhibao
28  * @Description
29  * @createTime 2024年09月05日
30  */
31 @Slf4j
32 @Component
33 public class ScheduleModelHandlerImpl implements ScheduleModelHandler {
34
35     @Autowired
36     private StScheduleModelService stScheduleModelService;
37
38     @Autowired
39     private StScheduleService stScheduleService;
40
41     @Autowired
42     private SampleConstructor sampleConstructor;
43
44     @Autowired
45     private StScheduleParamSettingService stScheduleParamSettingService;
46
47     @Override
48     public ScheduleResultVO doSchedule(String scheduleCode, Date scheduleTime) throws ModelInvokeException {
49         ScheduleResultVO scheduleResult = new ScheduleResultVO();
50
51         // todo
52         StScheduleModelEntity schModelEntity = stScheduleModelService.selectById(scheduleCode);
53         if (schModelEntity == null) {
54             throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}",
55                     ModelInvokeException.errorGetModelEntity, schModelEntity.getId()));
56         }
57         String modelId = schModelEntity.getId();
58         try {
59             IAILModel newModelBean = new IAILModel();
60             //1.根据模型id构造模型输入样本
61             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime);
62             if (CollectionUtils.isEmpty(sampleDataList)) {
63                 log.info("调度模型构造样本失败,scheduleCode=" + scheduleCode);
64                 return null;
65             }
66
67             //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型
68             String className = schModelEntity.getClassname().trim();
69             String methodName = schModelEntity.getMethodname().trim();
70             newModelBean.setClassName(className);
71             newModelBean.setMethodName(methodName);
72
73             Class<?>[] paramsArray = new Class[3];
74             paramsArray[0] = double[][].class;
75             paramsArray[1] = double[][].class;
76             paramsArray[2] = HashMap.class;
77             newModelBean.setParamsArray(paramsArray);
78
79             //3.拼接settings参数
80             HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId);
81
82             //4.构造param2Values参数结构
83             int count = sampleDataList.size();
84             Object[] param2Values = new Object[count + 1];
85             for (int i = 0; i < count; i++) {
86                 param2Values[i] = sampleDataList.get(i).getMatrix();
87             }
88             param2Values[count] = settings_predict;
89
90             //打印参数
91             log.info("##############调度模型:modelId=" + modelId + " ##########################");
92             JSONObject jsonObjNewModelBean = new JSONObject();
93             jsonObjNewModelBean.put("newModelBean", newModelBean);
94             log.info(String.valueOf(jsonObjNewModelBean));
95             JSONObject jsonObjParam2Values = new JSONObject();
96             jsonObjParam2Values.put("param2Values", param2Values);
97             log.info(String.valueOf(jsonObjParam2Values));
98
99             //IAILMDK.run
100             HashMap<String, Object> result = IAILMDK.run(newModelBean, param2Values);
101
102             //打印结果
103             JSONObject jsonObjResult = new JSONObject();
104             jsonObjResult.put("result", result);
105             log.info(String.valueOf(jsonObjResult));
106             log.info("调度模型计算完成:modelId=" + modelId + result);
107
108             //5.返回调度结果
109             scheduleResult.setResult(result);
110             scheduleResult.setModelId(modelId);
111             scheduleResult.setScheduleId(schModelEntity.getId());
112             scheduleResult.setScheduleTime(scheduleTime);
113         } catch (Exception ex) {
114             log.error("IAILMDK.run()执行失败");
115             log.error(ex.getMessage());
116             log.error("调用发生异常,异常信息为:{}", ex);
117             ex.printStackTrace();
118         }
119         return scheduleResult;
120     }
121
122     /**
123      * 根据模型id获取参数map
124      *
125      * @param modelId
126      * @return
127      */
128     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
129         List<StScheduleParamSettingEntity> list = stScheduleParamSettingService.getByModelid(modelId);
130         if (CollectionUtils.isEmpty(list)) {
131             return null;
132         }
133         HashMap<String, Object> result = new HashMap<>();
134         for (StScheduleParamSettingEntity entry : list) {
135             String valueType = entry.getValuetype().trim();
136             String valueStr = entry.getValue().trim();
137             if ("int".equals(valueType)) {
138                 int value = Integer.parseInt(valueStr);
139                 result.put(entry.getKey(), value);
140             } else if ("double".equals(valueType)) {
141                 double value = Double.parseDouble(valueStr);
142                 result.put(entry.getKey(), value);
143             } else if ("string".equals(valueType)) {
144                 String value = valueStr;
145                 result.put(entry.getKey(), value);
146             } else if ("float".equals(valueType)) {
147                 float value = Float.parseFloat(valueStr);
148                 result.put(entry.getKey(), value);
149             } else if ("[[D".equals(valueType)) {
150                 String valueStrTemp = entry.getValue();
151                 try {
152                     //1.二位数组的行按照"/"来分割
153                     String[] rowList = valueStrTemp.split("/");
154                     int row = rowList.length;
155                     int col = rowList[0].split(",").length;
156                     double[][] value1 = new double[row][col];
157                     for (int i = 0; i < rowList.length; i++) {
158                         //2.二位数组的列按照","来分割
159                         String[] colList = rowList[i].split(",");
160                         for (int j = 0; j < colList.length; j++) {
161                             value1[i][j] = Double.parseDouble(colList[j]);
162                         }
163                     }
164                     //把从数据库的得到的参数的二维数组降为一维数组
165                     //int len =0;
166                     double[] value = new double[row * col];
167                     /*for (int j = 0; j <value1.length ; j++) {
168                         len+= value1.length;
169                     }*/
170                     //value = new double[len];
171                     int index = 0;
172                     for (int i = 0; i < value1.length; i++) {
173                         for (int j = 0; j < value1[i].length; j++) {
174                             value[index++] = value1[i][j];
175                         }
176                     }
177                     result.put(entry.getKey(), value);
178                 } catch (Exception ex) {
179                     System.out.println("二维数组类型的setting格式不正确");
180                     ex.printStackTrace();
181                 }
182             } else if ("decimalArray".equals(valueType)) {
183                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
184                 double[] value = new double[valueArray.size()];
185                 for (int i = 0; i < valueArray.size(); i++) {
186                     value[i] = Double.parseDouble(valueArray.get(i).toString());
187                 }
188                 result.put(entry.getKey(), value);
189             } else if ("decimal".equals(valueType)) {
190                 double value = Double.parseDouble(entry.getValue());
191                 //BigDecimal value = new BigDecimal(entry.getValue());
192                 result.put(entry.getKey(), value);
193             }
194         }
195         return result;
196     }
197 }