package com.iailab.module.model.mdk.predict;
|
|
import com.iailab.module.model.common.exception.ModelResultErrorException;
|
import com.iailab.module.model.mcs.pre.entity.MmItemOutputEntity;
|
import com.iailab.module.model.mcs.pre.enums.ItemRunStatusEnum;
|
import com.iailab.module.model.mcs.pre.enums.ItemStatus;
|
import com.iailab.module.model.mcs.pre.enums.PredGranularityEnum;
|
import com.iailab.module.model.mcs.pre.service.MmItemStatusService;
|
import com.iailab.module.model.mdk.factory.PredictItemFactory;
|
import com.iailab.module.model.mdk.vo.ItemVO;
|
import com.iailab.module.model.mdk.vo.PredictResultVO;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.CollectionUtils;
|
|
import java.text.MessageFormat;
|
import java.util.*;
|
|
/**
|
* @author PanZhibao
|
* @Description
|
* @createTime 2024年08月30日
|
*/
|
@Slf4j
|
@Component
|
public class PredictModuleHandler {
|
|
|
@Autowired
|
private PredictItemFactory predictItemFactory;
|
|
@Autowired
|
private PredictResultHandler predictResultHandler;
|
|
@Autowired
|
private MmItemStatusService mmItemStatusService;
|
|
|
/**
|
* 预测处理
|
*
|
* @param predictItemList
|
* @param predictTime
|
* @param intervalTime
|
* @return
|
*/
|
public void predict(List<ItemVO> predictItemList, Date predictTime, int intervalTime,Map<String, PredictResultVO> predictResultMap) {
|
Map<String, double[]> predictValueMap = null;
|
if (!CollectionUtils.isEmpty(predictResultMap)) {
|
// 将predictResultMap处理成Map<outPutId, double[]>
|
predictValueMap = new HashMap<>();
|
for (Map.Entry<String, PredictResultVO> entry : predictResultMap.entrySet()) {
|
for (Map.Entry<MmItemOutputEntity, double[]> mmItemOutputEntityEntry : entry.getValue().getPredictMatrixs().entrySet()) {
|
predictValueMap.put(mmItemOutputEntityEntry.getKey().getId(),mmItemOutputEntityEntry.getValue());
|
}
|
}
|
}
|
for (ItemVO predictItem : predictItemList) {
|
// 根据item粒度处理预测时间
|
Calendar calendar = Calendar.getInstance();
|
calendar.setTime(predictTime);
|
calendar.set(Calendar.MILLISECOND, 0);
|
calendar.set(Calendar.SECOND, 0);
|
if (PredGranularityEnum.H1.getCode().equals(predictItem.getGranularity())) {
|
calendar.set(Calendar.MINUTE,0);
|
}else if (PredGranularityEnum.D1.getCode().equals(predictItem.getGranularity())) {
|
calendar.set(Calendar.MINUTE,0);
|
calendar.set(Calendar.HOUR_OF_DAY,0);
|
}
|
PredictResultVO predictResult;
|
if (!predictItem.getStatus().equals(ItemStatus.STATUS1.getCode())) {
|
continue;
|
}
|
Long totalDur = 0L;
|
ItemRunStatusEnum itemRunStatusEnum = ItemRunStatusEnum.PROCESSING;
|
try {
|
mmItemStatusService.recordStatus(predictItem.getId(), itemRunStatusEnum, totalDur, calendar.getTime());
|
PredictItemHandler predictItemHandler = predictItemFactory.create(predictItem.getId());
|
long start = System.currentTimeMillis();
|
try {
|
// 预测项开始预测
|
predictResult = predictItemHandler.predict(calendar.getTime(), predictItem, predictValueMap);
|
} catch (ModelResultErrorException e) {
|
itemRunStatusEnum = ItemRunStatusEnum.MODELRESULTERROR;
|
continue;
|
} catch (Exception e) {
|
itemRunStatusEnum = ItemRunStatusEnum.FAIL;
|
continue;
|
}
|
long end = System.currentTimeMillis();
|
Long drtPre = end - start;
|
log.info(MessageFormat.format("预测项:{0},预测时间:{1}ms", predictItem.getItemName(), drtPre));
|
totalDur = totalDur + drtPre;
|
predictResult.setGranularity(predictItem.getGranularity());
|
predictResult.setT(intervalTime);
|
predictResult.setSaveIndex(predictItem.getSaveIndex());
|
predictResult.setLt(1);
|
predictResultMap.put(predictItem.getItemNo(), predictResult);
|
|
// 保存预测结果
|
try {
|
predictResultHandler.savePredictResult(predictResult);
|
} catch (Exception e) {
|
itemRunStatusEnum = ItemRunStatusEnum.MODELRESULTSAVEERROR;
|
throw new RuntimeException("模型结果保存异常,result:" + predictResult);
|
}
|
itemRunStatusEnum = ItemRunStatusEnum.SUCCESS;
|
// long endSave = System.currentTimeMillis();
|
// Long drtSave = endSave - end;
|
// log.info(MessageFormat.format("预测项:{0},保存时间:{1}ms", predictItem.getItemName(),
|
// drtSave));
|
// totalDur = totalDur + drtSave;
|
} catch (Exception e) {
|
e.printStackTrace();
|
log.error(MessageFormat.format("预测项编号:{0},预测项名称:{1},预测失败:{2} 预测时刻:{3}",
|
predictItem.getId(), predictItem.getItemName(), e.getMessage(), predictTime));
|
} finally {
|
mmItemStatusService.recordStatus(predictItem.getId(), itemRunStatusEnum, totalDur, calendar.getTime());
|
}
|
}
|
}
|
}
|