潘志宝
2024-12-23 5bf42aa9950058f391805e6fb8d7376f4378924b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package com.iailab.module.model.mdk.schedule.impl;
 
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.iail.model.IAILModel;
import com.iailab.module.model.common.enums.CommonConstant;
import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity;
import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity;
import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity;
import com.iailab.module.model.mcs.sche.service.StScheduleModelService;
import com.iailab.module.model.mcs.sche.service.StScheduleModelSettingService;
import com.iailab.module.model.mcs.sche.service.StScheduleSchemeService;
import com.iailab.module.model.mdk.common.enums.TypeA;
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
import com.iailab.module.model.mdk.sample.SampleConstructor;
import com.iailab.module.model.mdk.sample.dto.SampleData;
import com.iailab.module.model.mdk.schedule.ScheduleModelHandler;
import com.iailab.module.model.mdk.vo.ScheduleResultVO;
import com.iailab.module.model.mpk.common.MdkConstant;
import com.iailab.module.model.mpk.common.utils.DllUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
 
import java.text.MessageFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
 
/**
 * @author PanZhibao
 * @Description
 * @createTime 2024年09月05日
 */
@Slf4j
@Component
public class ScheduleModelHandlerImpl implements ScheduleModelHandler {
 
    @Autowired
    private StScheduleSchemeService stScheduleSchemeService;
 
    @Autowired
    private StScheduleModelService stScheduleModelService;
 
    @Autowired
    private StScheduleModelSettingService stScheduleModelSettingService;
 
    @Autowired
    private SampleConstructor sampleConstructor;
 
    @Override
    public ScheduleResultVO doSchedule(String schemeCode, Date scheduleTime) throws ModelInvokeException {
        ScheduleResultVO scheduleResult = new ScheduleResultVO();
        StScheduleSchemeEntity scheduleScheme = stScheduleSchemeService.getByCode(schemeCode);
        StScheduleModelEntity scheduleModel = stScheduleModelService.get(scheduleScheme.getModelId());
        if (scheduleModel == null) {
            throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}",
                    ModelInvokeException.errorGetModelEntity, scheduleModel.getId()));
        }
        String modelId = scheduleModel.getId();
        try {
            //1.根据模型id构造模型输入样本
            long now = System.currentTimeMillis();
            List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime,scheduleScheme.getName());
            log.info("构造模型输入样本消耗时长:" + (System.currentTimeMillis() - now) / 1000 + "秒");
            if (CollectionUtils.isEmpty(sampleDataList)) {
                log.info("调度模型构造样本失败,schemeCode=" + schemeCode);
                return null;
            }
 
            IAILModel newModelBean = composeNewModelBean(scheduleModel);
            HashMap<String, Object> settings = getScheduleSettingsByModelId(modelId);
            if (settings == null) {
                log.error("模型setting不存在,modelId=" + modelId);
                return null;
            }
            // 校验setting必须有pyFile,否则可能导致程序崩溃
            if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) {
                log.error("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
                return null;
            }
            int portLength = sampleDataList.size();
            Object[] param2Values = new Object[portLength + 1];
            for (int i = 0; i < portLength; i++) {
                param2Values[i] = sampleDataList.get(i).getMatrix();
            }
            param2Values[portLength] = settings;
 
            log.info("#######################调度模型 " + scheduleModel.getModelName() + " ##########################");
//            JSONObject jsonObjNewModelBean = new JSONObject();
//            jsonObjNewModelBean.put("newModelBean", newModelBean);
//            log.info(String.valueOf(jsonObjNewModelBean));
//            JSONObject jsonObjParam2Values = new JSONObject();
//            jsonObjParam2Values.put("param2Values", param2Values);
            log.info("参数: " + JSON.toJSONString(param2Values));
 
            //IAILMDK.run
            HashMap<String, Object> modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid());
            if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) ||
                    !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) {
                throw new RuntimeException("模型结果异常:" + modelResult);
            }
            modelResult = (HashMap<String, Object>) modelResult.get(CommonConstant.MDK_RESULT);
 
            //打印结果
            JSONObject jsonObjResult = new JSONObject();
            jsonObjResult.put("result", modelResult);
            log.info(String.valueOf(jsonObjResult));
 
            //5.返回调度结果
            scheduleResult.setResult(modelResult);
            scheduleResult.setModelId(modelId);
            scheduleResult.setSchemeId(scheduleScheme.getId());
            scheduleResult.setScheduleTime(scheduleTime);
        } catch (Exception ex) {
            log.error("IAILMDK.run()执行失败");
            log.error(ex.getMessage());
            log.error("调用发生异常,异常信息为:{}", ex);
            ex.printStackTrace();
        }
        return scheduleResult;
    }
 
    /**
     * 根据模型id获取参数map
     *
     * @param modelId
     * @return
     */
    private HashMap<String, Object> getScheduleSettingsByModelId(String modelId) {
        List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId);
        if (CollectionUtils.isEmpty(list)) {
            return null;
        }
        HashMap<String, Object> result = new HashMap<>();
        for (StScheduleModelSettingEntity entry : list) {
            String valueType = entry.getValuetype().trim(); //去除两端空格
            if ("int".equals(valueType)) {
                int value = Integer.parseInt(entry.getValue());
                result.put(entry.getKey(), value);
            } else if ("double".equals(valueType)) {
                double value = Double.parseDouble(entry.getValue());
                result.put(entry.getKey(), value);
            } else if ("string".equals(valueType)) {
                String value = entry.getValue();
                result.put(entry.getKey(), value);
            } else if ("decimalArray".equals(valueType)) {
                JSONArray valueArray = JSONArray.parseArray(entry.getValue());
                double[] value = new double[valueArray.size()];
                for (int i = 0; i < valueArray.size(); i++) {
                    value[i] = Double.parseDouble(valueArray.get(i).toString());
                }
                result.put(entry.getKey(), value);
            } else if ("decimal".equals(valueType)) {
                double value = Double.parseDouble(entry.getValue());
                result.put(entry.getKey(), value);
            }
        }
        return result;
    }
 
    private IAILModel composeNewModelBean(StScheduleModelEntity model) {
        IAILModel newModelBean = new IAILModel();
        newModelBean.setClassName(model.getClassName().trim());
        newModelBean.setMethodName(model.getMethodName().trim());
        //构造参数类型
        Class<?>[] paramsArray = new Class[model.getPortLength() + 1];
        for (int i = 0; i < model.getPortLength(); i++) {
            paramsArray[i] = double[][].class;
        }
        paramsArray[model.getPortLength()] = HashMap.class;
        newModelBean.setParamsArray(paramsArray);
        //
//        HashMap<String, Object> dataMap = new HashMap<>();
//        newModelBean.setDataMap(dataMap);
        return newModelBean;
    }
}