潘志宝
2024-09-09 58c7491f231c8250e13c7ce705c2710a08b2cf88
提交 | 用户 | 时间
7fd198 1 package com.iailab.module.model.mdk.factory;
2
3 import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
4 import com.iailab.module.model.mcs.pre.entity.MmModelParamEntity;
5 import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
6 import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
7 import com.iailab.module.model.mcs.pre.service.MmModelParamService;
8 import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
9 import lombok.extern.slf4j.Slf4j;
10 import org.springframework.beans.factory.annotation.Autowired;
11 import org.springframework.stereotype.Component;
12
13 import java.util.HashMap;
14 import java.util.List;
15 import java.util.Map;
16
17 /**
18  * 创建和管理模型实体
19  */
20 @Slf4j
21 @Component
22 public class ModelEntityFactory {
23
24     private Map<String, Object> modelFileMap = new HashMap<>();
25     private Map<String, MmPredictModelEntity> modelEntityMap = new HashMap<>();
26     private Map<String, List<MmModelParamEntity>> modelInputParamMap = new HashMap<>();
27     private Map<String, List<MmModelArithSettingsEntity>> modelArithParamMap = new HashMap<>();
28     private Map<String, List<MmPredictModelEntity>> modelListMap = new HashMap<>();
29
30     @Autowired
31     private MmPredictModelService mmPredictModelService;
32
33     @Autowired
34     private MmModelParamService mmModelParamService;
35
36     @Autowired
37     private MmModelArithSettingsService mmModelArithSettingsService;
38
39     /**
40      * 2.根据模型ID,获取模型实体
41      *
42      * @param modelId
43      * @return
44      */
45     public MmPredictModelEntity getModelEntity(String modelId) {
46         MmPredictModelEntity modelEntity = mmPredictModelService.getInfo(modelId);
47         if (!modelEntityMap.containsKey(modelId)) {
48             if (modelEntity != null) {
49                 modelEntityMap.put(modelId, modelEntity);
50             }
51         }
52         return modelEntity;
53     }
54
55     /**
56      * 3.根据模型ID,获取模型对应的输入参数
57      *
58      * @param modelId
59      * @return
60      */
61     public List<MmModelParamEntity> getModelInputParam(String modelId) {
62         if (!modelInputParamMap.containsKey(modelId)) {
63             List<MmModelParamEntity> modelInputParamEntities = mmModelParamService.getByModelid(modelId);
64             if (modelInputParamEntities != null) {
65                 modelInputParamMap.put(modelId, modelInputParamEntities);
66             } else {
67                 return null;
68             }
69         }
70         return mmModelParamService.getByModelid(modelId);
71     }
72
73     /**
74      * 4.根据模型ID,获取模型对应的输入参数的维数
75      *
76      * @param modelId
77      * @return
78      */
79     public Integer getModelInputCount(String modelId) {
80         if (!modelInputParamMap.containsKey(modelId)) {
81             List<MmModelParamEntity> modelInputParamEntityList = mmModelParamService.getByModelid(modelId);
82             if (modelInputParamEntityList != null) {
83                 modelInputParamMap.put(modelId, modelInputParamEntityList);
84             } else {
85                 return 0;
86             }
87         }
88         return modelInputParamMap.get(modelId).size();
89     }
90
91     /**
92      * 5.根据模型ID,获取模型对应的算法参数
93      *
94      * @param modelId
95      * @return
96      */
97     public List<MmModelArithSettingsEntity> getModelArithParam(String modelId) {
98         if (!modelArithParamMap.containsKey(modelId)) {
99             List<MmModelArithSettingsEntity> modelArithParamEntityList = mmModelArithSettingsService.getByModelId(modelId);
100             if (modelArithParamEntityList != null) {
101                 modelArithParamMap.put(modelId, modelArithParamEntityList);
102             } else {
103                 return null;
104             }
105         }
106         return modelArithParamMap.get(modelId);
107     }
108
109     /**
110      * 7.根据预测项itemID,获取status=1的模型列表
111      *
112      * @param itemId
113      * @return
114      */
115     public List<MmPredictModelEntity> getActiveModelByItemId(String itemId) {
116         if (!modelListMap.containsKey(itemId)) {
117             List<MmPredictModelEntity> modelEntityList = mmPredictModelService.getActiveModelByItemId(itemId);
118             if (modelEntityList != null) {
119                 modelListMap.put(itemId, modelEntityList);
120             } else {
121                 return null;
122             }
123         }
124         return modelListMap.get(itemId);
125     }
126
127     /**
128      * 8.根据模型ID,删除模型对应的输入参数
129      *
130      * @param modelId
131      * @return
132      */
133     public void removeModelInputParam(String modelId) {
134         if (modelInputParamMap.containsKey(modelId)) {
135             log.info("removeModelInputParam:modelId=" + modelId);
136             modelInputParamMap.remove(modelId);
137         }
138     }
139
140     /**
141      * 清除缓存
142      */
143     public void removeModelEntity() {
144         modelFileMap.clear();
145         modelEntityMap.clear();
146         modelInputParamMap.clear();
147         modelArithParamMap.clear();
148         modelListMap.clear();
149     }
150 }