package com.iailab.module.model.mdk.schedule.impl;
|
|
import com.alibaba.fastjson.JSONArray;
|
import com.alibaba.fastjson.JSONObject;
|
import com.iail.IAILMDK;
|
import com.iail.model.IAILModel;
|
import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity;
|
import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity;
|
import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity;
|
import com.iailab.module.model.mcs.sche.service.StScheduleModelService;
|
import com.iailab.module.model.mcs.sche.service.StScheduleModelSettingService;
|
import com.iailab.module.model.mcs.sche.service.StScheduleSchemeService;
|
import com.iailab.module.model.mdk.common.enums.TypeA;
|
import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException;
|
import com.iailab.module.model.mdk.sample.SampleConstructor;
|
import com.iailab.module.model.mdk.sample.dto.SampleData;
|
import com.iailab.module.model.mdk.schedule.ScheduleModelHandler;
|
import com.iailab.module.model.mdk.vo.ScheduleResultVO;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.CollectionUtils;
|
|
import java.text.MessageFormat;
|
import java.util.*;
|
|
/**
|
* @author PanZhibao
|
* @Description
|
* @createTime 2024年09月05日
|
*/
|
@Slf4j
|
@Component
|
public class ScheduleModelHandlerImpl implements ScheduleModelHandler {
|
|
@Autowired
|
private StScheduleSchemeService stScheduleSchemeService;
|
|
@Autowired
|
private StScheduleModelService stScheduleModelService;
|
|
@Autowired
|
private StScheduleModelSettingService stScheduleModelSettingService;
|
|
@Autowired
|
private SampleConstructor sampleConstructor;
|
|
@Override
|
public ScheduleResultVO doSchedule(String schemeCode, Date scheduleTime) throws ModelInvokeException {
|
ScheduleResultVO scheduleResult = new ScheduleResultVO();
|
StScheduleSchemeEntity scheduleScheme = stScheduleSchemeService.getByCode(schemeCode);
|
StScheduleModelEntity scheduleModel = stScheduleModelService.get(scheduleScheme.getModelId());
|
if (scheduleModel == null) {
|
throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}",
|
ModelInvokeException.errorGetModelEntity, scheduleModel.getId()));
|
}
|
String modelId = scheduleModel.getId();
|
try {
|
IAILModel newModelBean = new IAILModel();
|
//1.根据模型id构造模型输入样本
|
List<SampleData> sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime);
|
if (CollectionUtils.isEmpty(sampleDataList)) {
|
log.info("调度模型构造样本失败,schemeCode=" + schemeCode);
|
return null;
|
}
|
|
//2.拼接newModelBean的参数结构:a.类名、方法名 b.参数类型
|
String className = scheduleModel.getClassName() .trim();
|
String methodName = scheduleModel.getMethodName().trim();
|
newModelBean.setClassName(className);
|
newModelBean.setMethodName(methodName);
|
|
Class<?>[] paramsArray = new Class[3];
|
paramsArray[0] = double[][].class;
|
paramsArray[1] = double[][].class;
|
paramsArray[2] = HashMap.class;
|
newModelBean.setParamsArray(paramsArray);
|
|
//3.拼接settings参数
|
HashMap<String, Object> settings_predict = getPredictSettingsByModelId(modelId);
|
|
//4.构造param2Values参数结构
|
int count = sampleDataList.size();
|
Object[] param2Values = new Object[count + 1];
|
for (int i = 0; i < count; i++) {
|
param2Values[i] = sampleDataList.get(i).getMatrix();
|
}
|
param2Values[count] = settings_predict;
|
|
//打印参数
|
log.info("##############调度模型:modelId=" + modelId + " ##########################");
|
JSONObject jsonObjNewModelBean = new JSONObject();
|
jsonObjNewModelBean.put("newModelBean", newModelBean);
|
log.info(String.valueOf(jsonObjNewModelBean));
|
JSONObject jsonObjParam2Values = new JSONObject();
|
jsonObjParam2Values.put("param2Values", param2Values);
|
log.info(String.valueOf(jsonObjParam2Values));
|
|
//IAILMDK.run
|
HashMap<String, Object> result = IAILMDK.run(newModelBean, param2Values);
|
|
//打印结果
|
JSONObject jsonObjResult = new JSONObject();
|
jsonObjResult.put("result", result);
|
log.info(String.valueOf(jsonObjResult));
|
log.info("调度模型计算完成:modelId=" + modelId + result);
|
|
//5.返回调度结果
|
scheduleResult.setResult(result);
|
scheduleResult.setModelId(modelId);
|
scheduleResult.setSchemeId(scheduleScheme.getId());
|
scheduleResult.setScheduleTime(scheduleTime);
|
} catch (Exception ex) {
|
log.error("IAILMDK.run()执行失败");
|
log.error(ex.getMessage());
|
log.error("调用发生异常,异常信息为:{}", ex);
|
ex.printStackTrace();
|
}
|
return scheduleResult;
|
}
|
|
/**
|
* 根据模型id获取参数map
|
*
|
* @param modelId
|
* @return
|
*/
|
private HashMap<String, Object> getPredictSettingsByModelId(String modelId) {
|
List<StScheduleModelSettingEntity> list = stScheduleModelSettingService.getByModelId(modelId);
|
if (CollectionUtils.isEmpty(list)) {
|
return null;
|
}
|
HashMap<String, Object> result = new HashMap<>();
|
for (StScheduleModelSettingEntity entry : list) {
|
String valueType = entry.getValuetype().trim();
|
String valueStr = entry.getValue().trim();
|
if ("int".equals(valueType)) {
|
int value = Integer.parseInt(valueStr);
|
result.put(entry.getKey(), value);
|
} else if ("double".equals(valueType)) {
|
double value = Double.parseDouble(valueStr);
|
result.put(entry.getKey(), value);
|
} else if ("string".equals(valueType)) {
|
String value = valueStr;
|
result.put(entry.getKey(), value);
|
} else if ("float".equals(valueType)) {
|
float value = Float.parseFloat(valueStr);
|
result.put(entry.getKey(), value);
|
} else if ("[[D".equals(valueType)) {
|
String valueStrTemp = entry.getValue();
|
try {
|
//1.二位数组的行按照"/"来分割
|
String[] rowList = valueStrTemp.split("/");
|
int row = rowList.length;
|
int col = rowList[0].split(",").length;
|
double[][] value1 = new double[row][col];
|
for (int i = 0; i < rowList.length; i++) {
|
//2.二位数组的列按照","来分割
|
String[] colList = rowList[i].split(",");
|
for (int j = 0; j < colList.length; j++) {
|
value1[i][j] = Double.parseDouble(colList[j]);
|
}
|
}
|
//把从数据库的得到的参数的二维数组降为一维数组
|
//int len =0;
|
double[] value = new double[row * col];
|
/*for (int j = 0; j <value1.length ; j++) {
|
len+= value1.length;
|
}*/
|
//value = new double[len];
|
int index = 0;
|
for (int i = 0; i < value1.length; i++) {
|
for (int j = 0; j < value1[i].length; j++) {
|
value[index++] = value1[i][j];
|
}
|
}
|
result.put(entry.getKey(), value);
|
} catch (Exception ex) {
|
System.out.println("二维数组类型的setting格式不正确");
|
ex.printStackTrace();
|
}
|
} else if ("decimalArray".equals(valueType)) {
|
JSONArray valueArray = JSONArray.parseArray(entry.getValue());
|
double[] value = new double[valueArray.size()];
|
for (int i = 0; i < valueArray.size(); i++) {
|
value[i] = Double.parseDouble(valueArray.get(i).toString());
|
}
|
result.put(entry.getKey(), value);
|
} else if ("decimal".equals(valueType)) {
|
double value = Double.parseDouble(entry.getValue());
|
//BigDecimal value = new BigDecimal(entry.getValue());
|
result.put(entry.getKey(), value);
|
}
|
}
|
return result;
|
}
|
}
|