潘志宝
2024-09-23 0a2f6f78683ba1c4e07f1359c1e7bf105a4bd507
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package com.iailab.module.model.mdk.sample;
 
import com.iailab.framework.common.util.object.ConvertUtils;
import com.iailab.module.data.api.point.DataPointApi;
import com.iailab.module.data.api.point.dto.ApiPointDTO;
import com.iailab.module.data.api.point.dto.ApiPointValueDTO;
import com.iailab.module.data.api.point.dto.ApiPointValueQueryDTO;
import com.iailab.module.model.mcs.pre.service.MmItemResultService;
import com.iailab.module.model.mdk.factory.ItemEntityFactory;
import com.iailab.module.model.mdk.sample.dto.ColumnItem;
import com.iailab.module.model.mdk.sample.dto.ColumnItemPort;
import com.iailab.module.model.mdk.sample.dto.SampleData;
import com.iailab.module.model.mdk.sample.dto.SampleInfo;
import com.iailab.module.model.mdk.vo.DataValueVO;
import com.iailab.module.model.mdk.vo.MmItemOutputVO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
 
import java.math.BigDecimal;
import java.util.*;
 
/**
 * 预测样本数据构造
 */
@Component
public class PredictSampleDataConstructor extends SampleDataConstructor {
 
    private Logger logger = LoggerFactory.getLogger(getClass());
 
    @Autowired
    private DataPointApi dataPointApi;
 
    @Autowired
    private MmItemResultService mmItemResultService;
 
    @Autowired
    private ItemEntityFactory itemEntityFactory;
 
    /**
     * alter by zfc 2020.11.24 修改数据样本构造方案:sampleInfo中数据已按爪子进行分类,但爪内数据为无序的,
     * 对爪内数据样本拼接:先基于modelParamOrder对项进行排序(重写comparator匿名函数),再逐项拼接
     *
     * @param sampleInfo
     * @return
     */
    @Override
    public List<SampleData> prepareSampleData(SampleInfo sampleInfo) {
        List<SampleData> sampleDataList = new ArrayList<>();
        //对每个爪分别进行计算
        int deviationIndex = 0;
        for (ColumnItemPort entry : sampleInfo.getColumnInfo()) {
            //先依据爪内数据项的modelParamOrder进行排序——重写comparator匿名函数
            Collections.sort(entry.getColumnItemList(), new Comparator<ColumnItem>() {
                @Override
                public int compare(ColumnItem o1, ColumnItem o2) {
                    return o1.getModelParamOrder() - o2.getModelParamOrder();
                }
            });
 
            //默认都是double类型的数据,且按列向量进行拼接,默认初始值为0.0
            double[][] matrix = new double[entry.getDataLength()][entry.getColumnItemList().size()];
            for (int i = 0; i < entry.getColumnItemList().size(); i++) {
                for (int j = 0; j < entry.getDataLength(); j++) {
                    matrix[j][i] = -2.0;
                }
            }
 
            //找出对应的调整值
            BigDecimal[] deviationItem = null;
            if (sampleInfo.getDeviation() != null && sampleInfo.getDeviation().length > 0) {
                deviationItem = sampleInfo.getDeviation()[deviationIndex];
            }
            deviationIndex++;
 
            //对每一项依次进行数据查询,然后将查询出的值赋给matrix对应的位置
            for (int i = 0; i < entry.getColumnItemList().size(); i++) {
                try {
                    List<DataValueVO> dataEntityList = getData(entry.getColumnItemList().get(i));
                    //设置调整值
                    if (deviationItem != null && deviationItem.length > 0) {
                        logger.info("设置调整值, i = " + i);
                        if (deviationItem[i] != null && deviationItem[i].compareTo(BigDecimal.ZERO) != 0) {
                            for (int dataKey = 1; dataKey < dataEntityList.size(); dataKey++) {
                                DataValueVO item = dataEntityList.get(dataKey);
                                item.setDataValue(item.getDataValue() + deviationItem[i].doubleValue());
                            }
                        }
                    }
                    //补全数据
                    ColumnItem columnItem = entry.getColumnItemList().get(i);
                    dataEntityList = super.completionData(matrix.length, dataEntityList, columnItem.startTime, columnItem.getEndTime(), columnItem.granularity);
 
                    /** 如果数据取不满,把缺失的数据点放在后面 */
                    if (dataEntityList != null && dataEntityList.size() != 0) {
                        logger.info("设置matrix, i = " + i + ", size = " + dataEntityList.size());
                        for (int k = 0; k < dataEntityList.size(); k++) {
                            matrix[k][i] = dataEntityList.get(k).getDataValue();
                        }
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
            SampleData sampleData = new SampleData();
            sampleData.setMatrix(matrix);
            sampleDataList.add(sampleData);
        }
        return sampleDataList;
    }
 
    /**
     * getData
     *
     * @param columnItem
     * @return
     * @throws Exception
     */
    private List<DataValueVO> getData(ColumnItem columnItem) throws Exception {
        List<DataValueVO> dataList = new ArrayList<>();
        String paramType = columnItem.getParamType();
        switch (paramType) {
            case "DATAPOINT":
                ApiPointDTO point = dataPointApi.getPointById(columnItem.getId());
                ApiPointValueQueryDTO queryDto = new ApiPointValueQueryDTO();
                queryDto.setPointNo(point.getPointNo());
                queryDto.setStart(columnItem.getStartTime());
                queryDto.setEnd(columnItem.getEndTime());
                List<ApiPointValueDTO> pointValueList = dataPointApi.getValue(queryDto);
                dataList = ConvertUtils.sourceToTarget(pointValueList, DataValueVO.class);
                break;
            case "PREDICTITEM":
                MmItemOutputVO outPut = itemEntityFactory.getItemOutPutById(columnItem.getId());
                dataList = mmItemResultService.getPredictValue(outPut.getId(),
                        columnItem.getStartTime(), columnItem.getEndTime());
                if (dataList == null) {
                    throw new Exception("没有预测值");
                }
                break;
 
 
            default:
                break;
        }
        return dataList;
    }
}