package com.iailab.module.model.mdk.sample;
|
|
import com.iailab.module.data.api.point.DataPointApi;
|
import com.iailab.module.data.api.point.dto.ApiPointDTO;
|
import com.iailab.module.model.mcs.pre.entity.MmModelParamEntity;
|
import com.iailab.module.model.mcs.pre.service.MmModelParamService;
|
import com.iailab.module.model.mcs.pre.service.MmPredictItemService;
|
import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
|
import com.iailab.module.model.mdk.common.enums.ModelParamType;
|
import com.iailab.module.model.mdk.sample.dto.ColumnItem;
|
import com.iailab.module.model.mdk.sample.dto.ColumnItemPort;
|
import com.iailab.module.model.mdk.sample.dto.SampleInfo;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.CollectionUtils;
|
|
import java.util.*;
|
import java.util.function.Function;
|
import java.util.stream.Collectors;
|
|
/**
|
* @author PanZhibao
|
* @Description
|
* @createTime 2024年09月03日
|
*/
|
@Component
|
public class PredictSampleInfoConstructor extends SampleInfoConstructor {
|
|
@Autowired
|
private MmPredictModelService mmPredictModelService;
|
|
@Autowired
|
private MmModelParamService mmModelParamService;
|
|
@Autowired
|
private MmPredictItemService mmPredictItemService;
|
|
@Autowired
|
private DataPointApi dataPointApi;
|
|
/**
|
* 返回样本矩阵的列数
|
*
|
* @param modelId
|
* @return
|
*/
|
@Override
|
protected Integer getSampleColumn(String modelId) {
|
return mmPredictModelService.getSampleLength(modelId).intValue();
|
}
|
|
/**
|
* 样本的列信息
|
*
|
* @param modelId
|
* @param predictTime
|
* @return
|
*/
|
@Override
|
protected SampleInfo getColumnInfo(String modelId, Date predictTime) {
|
SampleInfo sampleInfo = new SampleInfo();
|
List<ColumnItemPort> resultList = new ArrayList<>();
|
List<ColumnItem> columnItemList = new ArrayList<>();
|
ColumnItem columnInfo = new ColumnItem();
|
ColumnItemPort curPort = new ColumnItemPort(); //当前端口
|
List<MmModelParamEntity> modelInputParamEntityList = mmModelParamService.getByModelidFromCache(modelId);
|
if (CollectionUtils.isEmpty(modelInputParamEntityList)) {
|
return null;
|
}
|
//设置当前端口号,初始值为最小端口(查询结果按端口号从小到达排列)
|
int curPortOrder = modelInputParamEntityList.get(0).getModelparamportorder();
|
//设置当前查询数据长度,初始值为最小端口数据长度
|
int curDataLength = modelInputParamEntityList.get(0).getDatalength();
|
// 统一获取测点的信息
|
Set<String> pointIds = modelInputParamEntityList.stream().filter(e -> ModelParamType.getEumByCode(e.getModelparamtype()).equals(ModelParamType.DATAPOINT)).map(MmModelParamEntity::getModelparamid).collect(Collectors.toSet());
|
List<ApiPointDTO> points = dataPointApi.getInfoByIds(pointIds);
|
Map<String, ApiPointDTO> pointMap = points.stream().collect(Collectors.toMap(ApiPointDTO::getId, Function.identity(), (e1,e2) -> e1));
|
|
for (MmModelParamEntity entry : modelInputParamEntityList) {
|
columnInfo.setParamType(entry.getModelparamtype());
|
columnInfo.setParamId(entry.getModelparamid());
|
columnInfo.setDataLength(entry.getDatalength());
|
columnInfo.setModelParamOrder(entry.getModelparamorder());
|
columnInfo.setModelParamPortOrder(entry.getModelparamportorder());
|
columnInfo.setStartTime(getStartTime(columnInfo, predictTime,pointMap));
|
columnInfo.setEndTime(getEndTime(columnInfo, predictTime,pointMap));
|
columnInfo.setGranularity(super.getGranularity(columnInfo));
|
|
//对每一个爪进行数据项归并
|
if (curPortOrder != entry.getModelparamportorder()){
|
//当数据项端口号不为当前端口号时,封装上一个端口类,操作下一个端口类
|
curPort.setColumnItemList(columnItemList);
|
curPort.setDataLength(curDataLength);
|
curPort.setPortOrder(curPortOrder);
|
resultList.add(curPort);
|
curPort = new ColumnItemPort(); //对象重新初始化,防止引用拷贝导致数据覆盖
|
//封装上一个端口类后更新当前的各个参数
|
columnItemList = new ArrayList<>();
|
curDataLength = entry.getDatalength();
|
curPortOrder = entry.getModelparamportorder();
|
}
|
columnItemList.add(columnInfo);
|
columnInfo = new ColumnItem(); //对象重新初始化,防止引用拷贝导致数据覆盖
|
}
|
//当迭代到最后一个项的时候,封装最后一个端口的信息
|
curPort.setColumnItemList(columnItemList);
|
curPort.setDataLength(curDataLength);
|
curPort.setPortOrder(curPortOrder);
|
resultList.add(curPort);
|
sampleInfo.setColumnInfo(resultList);
|
sampleInfo.setPointMap(pointMap);
|
return sampleInfo;
|
}
|
|
/**
|
* 样本的采样周期
|
*
|
* @param modelId
|
* @return
|
*/
|
@Override
|
protected Integer getSampleCycle(String modelId) {
|
return mmPredictItemService.getItemByIdFromCache(mmPredictModelService.getInfoFromCatch(modelId).getItemid()).getGranularity();
|
}
|
|
|
}
|