package com.iailab.module.model.mdk.schedule.impl; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.iail.model.IAILModel; import com.iailab.module.model.common.enums.CommonConstant; import com.iailab.module.model.mcs.sche.entity.StScheduleModelEntity; import com.iailab.module.model.mcs.sche.entity.StScheduleModelSettingEntity; import com.iailab.module.model.mcs.sche.entity.StScheduleSchemeEntity; import com.iailab.module.model.mcs.sche.service.StScheduleModelService; import com.iailab.module.model.mcs.sche.service.StScheduleModelSettingService; import com.iailab.module.model.mcs.sche.service.StScheduleSchemeService; import com.iailab.module.model.mdk.common.enums.TypeA; import com.iailab.module.model.mdk.common.exceptions.ModelInvokeException; import com.iailab.module.model.mdk.sample.SampleConstructor; import com.iailab.module.model.mdk.sample.dto.SampleData; import com.iailab.module.model.mdk.schedule.ScheduleModelHandler; import com.iailab.module.model.mdk.vo.ScheduleResultVO; import com.iailab.module.model.mpk.common.MdkConstant; import com.iailab.module.model.mpk.common.utils.DllUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import java.text.MessageFormat; import java.util.Date; import java.util.HashMap; import java.util.List; /** * @author PanZhibao * @Description * @createTime 2024年09月05日 */ @Slf4j @Component public class ScheduleModelHandlerImpl implements ScheduleModelHandler { @Autowired private StScheduleSchemeService stScheduleSchemeService; @Autowired private StScheduleModelService stScheduleModelService; @Autowired private StScheduleModelSettingService stScheduleModelSettingService; @Autowired private SampleConstructor sampleConstructor; @Override public ScheduleResultVO doSchedule(String schemeCode, Date scheduleTime) throws ModelInvokeException { ScheduleResultVO scheduleResult = new ScheduleResultVO(); StScheduleSchemeEntity scheduleScheme = stScheduleSchemeService.getByCode(schemeCode); StScheduleModelEntity scheduleModel = stScheduleModelService.get(scheduleScheme.getModelId()); if (scheduleModel == null) { throw new ModelInvokeException(MessageFormat.format("{0},modelId={1}", ModelInvokeException.errorGetModelEntity, scheduleModel.getId())); } String modelId = scheduleModel.getId(); try { //1.根据模型id构造模型输入样本 long now = System.currentTimeMillis(); List sampleDataList = sampleConstructor.constructSample(TypeA.Schedule.name(), modelId, scheduleTime); log.info("构造模型输入样本消耗时长:" + (System.currentTimeMillis() - now) / 1000 + "秒"); if (CollectionUtils.isEmpty(sampleDataList)) { log.info("调度模型构造样本失败,schemeCode=" + schemeCode); return null; } IAILModel newModelBean = composeNewModelBean(scheduleModel); HashMap settings = getScheduleSettingsByModelId(modelId); if (settings == null) { log.error("模型setting不存在,modelId=" + modelId); return null; } // 校验setting必须有pyFile,否则可能导致程序崩溃 if (!settings.containsKey(MdkConstant.PY_FILE_KEY)) { log.error("模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY + "】,请重新上传模型!"); return null; } int portLength = sampleDataList.size(); Object[] param2Values = new Object[portLength + 1]; for (int i = 0; i < portLength; i++) { param2Values[i] = sampleDataList.get(i).getMatrix(); } param2Values[portLength] = settings; log.info("#######################调度模型 " + scheduleModel.getModelName() + " ##########################"); // JSONObject jsonObjNewModelBean = new JSONObject(); // jsonObjNewModelBean.put("newModelBean", newModelBean); // log.info(String.valueOf(jsonObjNewModelBean)); // JSONObject jsonObjParam2Values = new JSONObject(); // jsonObjParam2Values.put("param2Values", param2Values); log.info("参数: " + JSON.toJSONString(param2Values)); //IAILMDK.run HashMap modelResult = DllUtils.run(newModelBean, param2Values, scheduleScheme.getMpkprojectid()); if (!modelResult.containsKey(CommonConstant.MDK_STATUS_CODE) || !modelResult.containsKey(CommonConstant.MDK_RESULT) || !modelResult.get(CommonConstant.MDK_STATUS_CODE).toString().equals(CommonConstant.MDK_STATUS_100)) { throw new RuntimeException("模型结果异常:" + modelResult); } modelResult = (HashMap) modelResult.get(CommonConstant.MDK_RESULT); //打印结果 JSONObject jsonObjResult = new JSONObject(); jsonObjResult.put("result", modelResult); log.info(String.valueOf(jsonObjResult)); //5.返回调度结果 scheduleResult.setResult(modelResult); scheduleResult.setModelId(modelId); scheduleResult.setSchemeId(scheduleScheme.getId()); scheduleResult.setScheduleTime(scheduleTime); } catch (Exception ex) { log.error("IAILMDK.run()执行失败"); log.error(ex.getMessage()); log.error("调用发生异常,异常信息为:{}", ex); ex.printStackTrace(); } return scheduleResult; } /** * 根据模型id获取参数map * * @param modelId * @return */ private HashMap getScheduleSettingsByModelId(String modelId) { List list = stScheduleModelSettingService.getByModelId(modelId); if (CollectionUtils.isEmpty(list)) { return null; } HashMap result = new HashMap<>(); for (StScheduleModelSettingEntity entry : list) { String valueType = entry.getValuetype().trim(); //去除两端空格 if ("int".equals(valueType)) { int value = Integer.parseInt(entry.getValue()); result.put(entry.getKey(), value); } else if ("double".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } else if ("string".equals(valueType)) { String value = entry.getValue(); result.put(entry.getKey(), value); } else if ("decimalArray".equals(valueType)) { JSONArray valueArray = JSONArray.parseArray(entry.getValue()); double[] value = new double[valueArray.size()]; for (int i = 0; i < valueArray.size(); i++) { value[i] = Double.parseDouble(valueArray.get(i).toString()); } result.put(entry.getKey(), value); } else if ("decimal".equals(valueType)) { double value = Double.parseDouble(entry.getValue()); result.put(entry.getKey(), value); } } return result; } private IAILModel composeNewModelBean(StScheduleModelEntity model) { IAILModel newModelBean = new IAILModel(); newModelBean.setClassName(model.getClassName().trim()); newModelBean.setMethodName(model.getMethodName().trim()); //构造参数类型 Class[] paramsArray = new Class[model.getPortLength() + 1]; for (int i = 0; i < model.getPortLength(); i++) { paramsArray[i] = double[][].class; } paramsArray[model.getPortLength()] = HashMap.class; newModelBean.setParamsArray(paramsArray); // // HashMap dataMap = new HashMap<>(); // newModelBean.setDataMap(dataMap); return newModelBean; } }