潘志宝
2024-12-23 2fe27eee95f46825fdeee267a42811a3069991c8
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.sample;
2
3 import com.iailab.module.model.mcs.pre.entity.MmModelParamEntity;
91343d 4 import com.iailab.module.model.mcs.pre.service.MmModelParamService;
5 import com.iailab.module.model.mcs.pre.service.MmPredictItemService;
7fd198 6 import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
7 import com.iailab.module.model.mdk.sample.dto.ColumnItem;
8 import com.iailab.module.model.mdk.sample.dto.ColumnItemPort;
9 import org.springframework.beans.factory.annotation.Autowired;
10 import org.springframework.stereotype.Component;
11 import org.springframework.util.CollectionUtils;
12
13 import java.util.ArrayList;
14 import java.util.Date;
15 import java.util.List;
16
17 /**
18  * @author PanZhibao
19  * @Description
20  * @createTime 2024年09月03日
21  */
22 @Component
23 public class PredictSampleInfoConstructor extends SampleInfoConstructor {
24
25     @Autowired
26     private MmPredictModelService mmPredictModelService;
27
28     @Autowired
91343d 29     private MmModelParamService mmModelParamService;
7fd198 30
31     @Autowired
91343d 32     private MmPredictItemService mmPredictItemService;
7fd198 33
34     /**
35      * 返回样本矩阵的列数
36      *
37      * @param modelId
38      * @return
39      */
40     @Override
41     protected Integer getSampleColumn(String modelId) {
42         return mmPredictModelService.getSampleLength(modelId).intValue();
43     }
44
45     /**
46      * 样本的列信息
47      *
48      * @param modelId
49      * @param predictTime
50      * @return
51      */
52     @Override
53     protected List<ColumnItemPort> getColumnInfo(String modelId, Date predictTime) {
54         List<ColumnItemPort> resultList = new ArrayList<>();
55         List<ColumnItem> columnItemList = new ArrayList<>();
56         ColumnItem columnInfo = new ColumnItem();
57         ColumnItemPort curPort = new ColumnItemPort();  //当前端口
45520a 58         List<MmModelParamEntity> modelInputParamEntityList = mmModelParamService.getByModelidFromCache(modelId);
7fd198 59         if (CollectionUtils.isEmpty(modelInputParamEntityList)) {
60             return null;
61         }
62         //设置当前端口号,初始值为最小端口(查询结果按端口号从小到达排列)
63         int curPortOrder = modelInputParamEntityList.get(0).getModelparamportorder();
64         //设置当前查询数据长度,初始值为最小端口数据长度
65         int curDataLength = modelInputParamEntityList.get(0).getDatalength();
66         for (MmModelParamEntity entry : modelInputParamEntityList) {
67             columnInfo.setParamType(entry.getModelparamtype());
1a2b62 68             columnInfo.setParamId(entry.getModelparamid());
7fd198 69             columnInfo.setDataLength(entry.getDatalength());
70             columnInfo.setModelParamOrder(entry.getModelparamorder());
71             columnInfo.setModelParamPortOrder(entry.getModelparamportorder());
72             columnInfo.setStartTime(getStartTime(columnInfo, predictTime));
73             columnInfo.setEndTime(getEndTime(columnInfo, predictTime));
74             columnInfo.setGranularity(super.getGranularity(columnInfo));
75
76             //对每一个爪进行数据项归并
77             if (curPortOrder != entry.getModelparamportorder()){
78                 //当数据项端口号不为当前端口号时,封装上一个端口类,操作下一个端口类
79                 curPort.setColumnItemList(columnItemList);
80                 curPort.setDataLength(curDataLength);
81                 curPort.setPortOrder(curPortOrder);
82                 resultList.add(curPort);
83                 curPort = new ColumnItemPort(); //对象重新初始化,防止引用拷贝导致数据覆盖
84                 //封装上一个端口类后更新当前的各个参数
85                 columnItemList = new ArrayList<>();
86                 curDataLength = entry.getDatalength();
87                 curPortOrder = entry.getModelparamportorder();
88             }
89             columnItemList.add(columnInfo);
90             columnInfo = new ColumnItem();    //对象重新初始化,防止引用拷贝导致数据覆盖
91         }
92         //当迭代到最后一个项的时候,封装最后一个端口的信息
93         curPort.setColumnItemList(columnItemList);
94         curPort.setDataLength(curDataLength);
95         curPort.setPortOrder(curPortOrder);
96         resultList.add(curPort);
97         return resultList;
98     }
99
100     /**
101      * 样本的采样周期
102      *
103      * @param modelId
104      * @return
105      */
106     @Override
107     protected Integer getSampleCycle(String modelId) {
91343d 108         return mmPredictItemService.getItemById(mmPredictModelService.getInfoFromCatch(modelId).getItemid()).getGranularity();
7fd198 109     }
110
111
112 }