package com.iailab.module.model.mdk.sample;
|
|
import com.iailab.framework.common.util.object.ConvertUtils;
|
import com.iailab.module.data.api.point.DataPointApi;
|
import com.iailab.module.data.api.point.dto.ApiPointDTO;
|
import com.iailab.module.data.api.point.dto.ApiPointValueDTO;
|
import com.iailab.module.data.api.point.dto.ApiPointValueQueryDTO;
|
import com.iailab.module.model.mcs.pre.service.MmItemResultService;
|
import com.iailab.module.model.mdk.factory.ItemEntityFactory;
|
import com.iailab.module.model.mdk.sample.dto.ColumnItem;
|
import com.iailab.module.model.mdk.sample.dto.ColumnItemPort;
|
import com.iailab.module.model.mdk.sample.dto.SampleData;
|
import com.iailab.module.model.mdk.sample.dto.SampleInfo;
|
import com.iailab.module.model.mdk.vo.DataValueVO;
|
import com.iailab.module.model.mdk.vo.MmItemOutputVO;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.stereotype.Component;
|
|
import java.math.BigDecimal;
|
import java.util.*;
|
|
/**
|
* 预测样本数据构造
|
*/
|
@Component
|
public class PredictSampleDataConstructor extends SampleDataConstructor {
|
|
private Logger logger = LoggerFactory.getLogger(getClass());
|
|
@Autowired
|
private DataPointApi dataPointApi;
|
|
@Autowired
|
private MmItemResultService mmItemResultService;
|
|
@Autowired
|
private ItemEntityFactory itemEntityFactory;
|
|
/**
|
* alter by zfc 2020.11.24 修改数据样本构造方案:sampleInfo中数据已按爪子进行分类,但爪内数据为无序的,
|
* 对爪内数据样本拼接:先基于modelParamOrder对项进行排序(重写comparator匿名函数),再逐项拼接
|
*
|
* @param sampleInfo
|
* @return
|
*/
|
@Override
|
public List<SampleData> prepareSampleData(SampleInfo sampleInfo) {
|
List<SampleData> sampleDataList = new ArrayList<>();
|
//对每个爪分别进行计算
|
int deviationIndex = 0;
|
for (ColumnItemPort entry : sampleInfo.getColumnInfo()) {
|
//先依据爪内数据项的modelParamOrder进行排序——重写comparator匿名函数
|
Collections.sort(entry.getColumnItemList(), new Comparator<ColumnItem>() {
|
@Override
|
public int compare(ColumnItem o1, ColumnItem o2) {
|
return o1.getModelParamOrder() - o2.getModelParamOrder();
|
}
|
});
|
|
//默认都是double类型的数据,且按列向量进行拼接,默认初始值为0.0
|
double[][] matrix = new double[entry.getDataLength()][entry.getColumnItemList().size()];
|
for (int i = 0; i < entry.getColumnItemList().size(); i++) {
|
for (int j = 0; j < entry.getDataLength(); j++) {
|
matrix[j][i] = -2.0;
|
}
|
}
|
|
//找出对应的调整值
|
BigDecimal[] deviationItem = null;
|
if (sampleInfo.getDeviation() != null && sampleInfo.getDeviation().length > 0) {
|
deviationItem = sampleInfo.getDeviation()[deviationIndex];
|
}
|
deviationIndex++;
|
|
//对每一项依次进行数据查询,然后将查询出的值赋给matrix对应的位置
|
for (int i = 0; i < entry.getColumnItemList().size(); i++) {
|
try {
|
List<DataValueVO> dataEntityList = getData(entry.getColumnItemList().get(i));
|
//设置调整值
|
if (deviationItem != null && deviationItem.length > 0) {
|
logger.info("设置调整值, i = " + i);
|
if (deviationItem[i] != null && deviationItem[i].compareTo(BigDecimal.ZERO) != 0) {
|
for (int dataKey = 1; dataKey < dataEntityList.size(); dataKey++) {
|
DataValueVO item = dataEntityList.get(dataKey);
|
item.setDataValue(item.getDataValue() + deviationItem[i].doubleValue());
|
}
|
}
|
}
|
//补全数据
|
ColumnItem columnItem = entry.getColumnItemList().get(i);
|
dataEntityList = super.completionData(matrix.length, dataEntityList, columnItem.startTime, columnItem.getEndTime(), columnItem.granularity);
|
|
/** 如果数据取不满,把缺失的数据点放在后面 */
|
if (dataEntityList != null && dataEntityList.size() != 0) {
|
logger.info("设置matrix, i = " + i + ", size = " + dataEntityList.size());
|
for (int k = 0; k < dataEntityList.size(); k++) {
|
matrix[k][i] = dataEntityList.get(k).getDataValue();
|
}
|
}
|
} catch (Exception e) {
|
e.printStackTrace();
|
}
|
}
|
SampleData sampleData = new SampleData();
|
sampleData.setMatrix(matrix);
|
sampleDataList.add(sampleData);
|
}
|
return sampleDataList;
|
}
|
|
/**
|
* getData
|
*
|
* @param columnItem
|
* @return
|
* @throws Exception
|
*/
|
private List<DataValueVO> getData(ColumnItem columnItem) throws Exception {
|
List<DataValueVO> dataList = new ArrayList<>();
|
String paramType = columnItem.getParamType();
|
switch (paramType) {
|
case "DATAPOINT":
|
ApiPointDTO point = dataPointApi.getPointById(columnItem.getId());
|
ApiPointValueQueryDTO queryDto = new ApiPointValueQueryDTO();
|
queryDto.setPointNo(point.getPointNo());
|
queryDto.setStart(columnItem.getStartTime());
|
queryDto.setEnd(columnItem.getEndTime());
|
List<ApiPointValueDTO> pointValueList = dataPointApi.getValue(queryDto);
|
dataList = ConvertUtils.sourceToTarget(pointValueList, DataValueVO.class);
|
break;
|
case "PREDICTITEM":
|
MmItemOutputVO outPut = itemEntityFactory.getItemOutPutById(columnItem.getId());
|
dataList = mmItemResultService.getPredictValue(outPut.getId(),
|
columnItem.getStartTime(), columnItem.getEndTime());
|
if (dataList == null) {
|
throw new Exception("没有预测值");
|
}
|
break;
|
|
|
default:
|
break;
|
}
|
return dataList;
|
}
|
}
|