潘志宝
2024-10-10 aa96f9cab5f43c372d96a0844792d90d2dad2d5e
提交 | 用户 | 时间
054fb9 1 package com.iailab.module.model.mdk.schedule.impl;
2
3 import com.alibaba.fastjson.JSONArray;
4 import com.alibaba.fastjson.JSONObject;
5 import com.iail.IAILMDK;
6 import com.iail.model.IAILModel;
7 import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity;
bbc1ee 8 import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity;
9 import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity;
054fb9 10 import com.iailab.module.model.mcs.sche.service.StScheduleModelService;
bbc1ee 11 import com.iailab.module.model.mcs.sche.service.StScheduleModelSettingService;
12 import com.iailab.module.model.mcs.sche.service.StScheduleSchemeService;
054fb9 13 import com.iailab.module.model.mdk.common.enums.TypeA;
14 import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
15 import com.iailab.module.model.mdk.sample.SampleConstructor;
16 import com.iailab.module.model.mdk.sample.dto.SampleData;
17 import com.iailab.module.model.mdk.schedule.ScheduleModelHandler;
18 import com.iailab.module.model.mdk.vo.ScheduleResultVO;
19 import lombok.extern.slf4j.Slf4j;
20 import org.springframework.beans.factory.annotation.Autowired;
21 import org.springframework.stereotype.Component;
22 import org.springframework.util.CollectionUtils;
23
24 import java.text.MessageFormat;
25 import java.util.*;
26
27 /**
28  * @author PanZhibao
29  * @Description
30  * @createTime 2024年09月05日
31  */
32 @Slf4j
33 @Component
34 public class ScheduleModelHandlerImpl implements ScheduleModelHandler {
35
36     @Autowired
bbc1ee 37     private StScheduleSchemeService stScheduleSchemeService;
38
39     @Autowired
054fb9 40     private StScheduleModelService stScheduleModelService;
41
42     @Autowired
bbc1ee 43     private StScheduleModelSettingService stScheduleModelSettingService;
054fb9 44
45     @Autowired
46     private SampleConstructor sampleConstructor;
47
48     @Override
bbc1ee 49     public ScheduleResultVO doSchedule(String schemeCode, Date scheduleTime) throws ModelInvokeException {
054fb9 50         ScheduleResultVO scheduleResult = new ScheduleResultVO();
bbc1ee 51         StScheduleSchemeEntity scheduleScheme = stScheduleSchemeService.getByCode(schemeCode);
b425df 52         StScheduleModelEntity scheduleModel = stScheduleModelService.get(scheduleScheme.getModelId());
bbc1ee 53         if (scheduleModel == null) {
054fb9 54             throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}",
bbc1ee 55                     ModelInvokeException.errorGetModelEntity, scheduleModel.getId()));
054fb9 56         }
bbc1ee 57         String modelId = scheduleModel.getId();
054fb9 58         try {
59             IAILModel newModelBean = new IAILModel();
60             //1.根据模型id构造模型输入样本
61             List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime);
62             if (CollectionUtils.isEmpty(sampleDataList)) {
bbc1ee 63                 log.info("调度模型构造样本失败,schemeCode=" + schemeCode);
054fb9 64                 return null;
65             }
66
67             //2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型
bbc1ee 68             String className = scheduleModel.getClassName() .trim();
69             String methodName = scheduleModel.getMethodName().trim();
054fb9 70             newModelBean.setClassName(className);
71             newModelBean.setMethodName(methodName);
72
73             Class<?>[] paramsArray = new Class[3];
74             paramsArray[0] = double[][].class;
75             paramsArray[1] = double[][].class;
76             paramsArray[2] = HashMap.class;
77             newModelBean.setParamsArray(paramsArray);
78
79             //3.拼接settings参数
80             HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId);
81
82             //4.构造param2Values参数结构
83             int count = sampleDataList.size();
84             Object[] param2Values = new Object[count + 1];
85             for (int i = 0; i < count; i++) {
86                 param2Values[i] = sampleDataList.get(i).getMatrix();
87             }
88             param2Values[count] = settings_predict;
89
90             //打印参数
91             log.info("##############调度模型:modelId=" + modelId + " ##########################");
92             JSONObject jsonObjNewModelBean = new JSONObject();
93             jsonObjNewModelBean.put("newModelBean", newModelBean);
94             log.info(String.valueOf(jsonObjNewModelBean));
95             JSONObject jsonObjParam2Values = new JSONObject();
96             jsonObjParam2Values.put("param2Values", param2Values);
97             log.info(String.valueOf(jsonObjParam2Values));
98
99             //IAILMDK.run
100             HashMap<String, Object> result = IAILMDK.run(newModelBean, param2Values);
101
102             //打印结果
103             JSONObject jsonObjResult = new JSONObject();
104             jsonObjResult.put("result", result);
105             log.info(String.valueOf(jsonObjResult));
106             log.info("调度模型计算完成:modelId=" + modelId + result);
107
108             //5.返回调度结果
109             scheduleResult.setResult(result);
110             scheduleResult.setModelId(modelId);
bbc1ee 111             scheduleResult.setSchemeId(scheduleScheme.getId());
054fb9 112             scheduleResult.setScheduleTime(scheduleTime);
113         } catch (Exception ex) {
114             log.error("IAILMDK.run()执行失败");
115             log.error(ex.getMessage());
116             log.error("调用发生异常,异常信息为:{}", ex);
117             ex.printStackTrace();
118         }
119         return scheduleResult;
120     }
121
122     /**
123      * 根据模型id获取参数map
124      *
125      * @param modelId
126      * @return
127      */
128     private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
bbc1ee 129         List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId);
054fb9 130         if (CollectionUtils.isEmpty(list)) {
131             return null;
132         }
133         HashMap<String, Object> result = new HashMap<>();
bbc1ee 134         for (StScheduleModelSettingEntity entry : list) {
054fb9 135             String valueType = entry.getValuetype().trim();
136             String valueStr = entry.getValue().trim();
137             if ("int".equals(valueType)) {
138                 int value = Integer.parseInt(valueStr);
139                 result.put(entry.getKey(), value);
140             } else if ("double".equals(valueType)) {
141                 double value = Double.parseDouble(valueStr);
142                 result.put(entry.getKey(), value);
143             } else if ("string".equals(valueType)) {
144                 String value = valueStr;
145                 result.put(entry.getKey(), value);
146             } else if ("float".equals(valueType)) {
147                 float value = Float.parseFloat(valueStr);
148                 result.put(entry.getKey(), value);
149             } else if ("[[D".equals(valueType)) {
150                 String valueStrTemp = entry.getValue();
151                 try {
152                     //1.二位数组的行按照"/"来分割
153                     String[] rowList = valueStrTemp.split("/");
154                     int row = rowList.length;
155                     int col = rowList[0].split(",").length;
156                     double[][] value1 = new double[row][col];
157                     for (int i = 0; i < rowList.length; i++) {
158                         //2.二位数组的列按照","来分割
159                         String[] colList = rowList[i].split(",");
160                         for (int j = 0; j < colList.length; j++) {
161                             value1[i][j] = Double.parseDouble(colList[j]);
162                         }
163                     }
164                     //把从数据库的得到的参数的二维数组降为一维数组
165                     //int len =0;
166                     double[] value = new double[row * col];
167                     /*for (int j = 0; j <value1.length ; j++) {
168                         len+= value1.length;
169                     }*/
170                     //value = new double[len];
171                     int index = 0;
172                     for (int i = 0; i < value1.length; i++) {
173                         for (int j = 0; j < value1[i].length; j++) {
174                             value[index++] = value1[i][j];
175                         }
176                     }
177                     result.put(entry.getKey(), value);
178                 } catch (Exception ex) {
179                     System.out.println("二维数组类型的setting格式不正确");
180                     ex.printStackTrace();
181                 }
182             } else if ("decimalArray".equals(valueType)) {
183                 JSONArray valueArray = JSONArray.parseArray(entry.getValue());
184                 double[] value = new double[valueArray.size()];
185                 for (int i = 0; i < valueArray.size(); i++) {
186                     value[i] = Double.parseDouble(valueArray.get(i).toString());
187                 }
188                 result.put(entry.getKey(), value);
189             } else if ("decimal".equals(valueType)) {
190                 double value = Double.parseDouble(entry.getValue());
191                 //BigDecimal value = new BigDecimal(entry.getValue());
192                 result.put(entry.getKey(), value);
193             }
194         }
195         return result;
196     }
197 }