dengzedong
2024-09-14 e18f2001fda0eccfbf2aa617e127c70b92083909
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package com.iailab.module.model.mdk.factory;
 
import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
import com.iailab.module.model.mcs.pre.entity.MmModelParamEntity;
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
import com.iailab.module.model.mcs.pre.service.MmModelParamService;
import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
 
import java.util.HashMap;
import java.util.List;
import java.util.Map;
 
/**
 * 创建和管理模型实体
 */
@Slf4j
@Component
public class ModelEntityFactory {
 
    private Map<String, Object> modelFileMap = new HashMap<>();
    private Map<String, MmPredictModelEntity> modelEntityMap = new HashMap<>();
    private Map<String, List<MmModelParamEntity>> modelInputParamMap = new HashMap<>();
    private Map<String, List<MmModelArithSettingsEntity>> modelArithParamMap = new HashMap<>();
    private Map<String, List<MmPredictModelEntity>> modelListMap = new HashMap<>();
 
    @Autowired
    private MmPredictModelService mmPredictModelService;
 
    @Autowired
    private MmModelParamService mmModelParamService;
 
    @Autowired
    private MmModelArithSettingsService mmModelArithSettingsService;
 
    /**
     * 2.根据模型ID,获取模型实体
     *
     * @param modelId
     * @return
     */
    public MmPredictModelEntity getModelEntity(String modelId) {
        MmPredictModelEntity modelEntity = mmPredictModelService.getInfo(modelId);
        if (!modelEntityMap.containsKey(modelId)) {
            if (modelEntity != null) {
                modelEntityMap.put(modelId, modelEntity);
            }
        }
        return modelEntity;
    }
 
    /**
     * 3.根据模型ID,获取模型对应的输入参数
     *
     * @param modelId
     * @return
     */
    public List<MmModelParamEntity> getModelInputParam(String modelId) {
        if (!modelInputParamMap.containsKey(modelId)) {
            List<MmModelParamEntity> modelInputParamEntities = mmModelParamService.getByModelid(modelId);
            if (modelInputParamEntities != null) {
                modelInputParamMap.put(modelId, modelInputParamEntities);
            } else {
                return null;
            }
        }
        return mmModelParamService.getByModelid(modelId);
    }
 
    /**
     * 4.根据模型ID,获取模型对应的输入参数的维数
     *
     * @param modelId
     * @return
     */
    public Integer getModelInputCount(String modelId) {
        if (!modelInputParamMap.containsKey(modelId)) {
            List<MmModelParamEntity> modelInputParamEntityList = mmModelParamService.getByModelid(modelId);
            if (modelInputParamEntityList != null) {
                modelInputParamMap.put(modelId, modelInputParamEntityList);
            } else {
                return 0;
            }
        }
        return modelInputParamMap.get(modelId).size();
    }
 
    /**
     * 5.根据模型ID,获取模型对应的算法参数
     *
     * @param modelId
     * @return
     */
    public List<MmModelArithSettingsEntity> getModelArithParam(String modelId) {
        if (!modelArithParamMap.containsKey(modelId)) {
            List<MmModelArithSettingsEntity> modelArithParamEntityList = mmModelArithSettingsService.getByModelId(modelId);
            if (modelArithParamEntityList != null) {
                modelArithParamMap.put(modelId, modelArithParamEntityList);
            } else {
                return null;
            }
        }
        return modelArithParamMap.get(modelId);
    }
 
    /**
     * 7.根据预测项itemID,获取status=1的模型列表
     *
     * @param itemId
     * @return
     */
    public List<MmPredictModelEntity> getActiveModelByItemId(String itemId) {
        if (!modelListMap.containsKey(itemId)) {
            List<MmPredictModelEntity> modelEntityList = mmPredictModelService.getActiveModelByItemId(itemId);
            if (modelEntityList != null) {
                modelListMap.put(itemId, modelEntityList);
            } else {
                return null;
            }
        }
        return modelListMap.get(itemId);
    }
 
    /**
     * 8.根据模型ID,删除模型对应的输入参数
     *
     * @param modelId
     * @return
     */
    public void removeModelInputParam(String modelId) {
        if (modelInputParamMap.containsKey(modelId)) {
            log.info("removeModelInputParam:modelId=" + modelId);
            modelInputParamMap.remove(modelId);
        }
    }
 
    /**
     * 清除缓存
     */
    public void removeModelEntity() {
        modelFileMap.clear();
        modelEntityMap.clear();
        modelInputParamMap.clear();
        modelArithParamMap.clear();
        modelListMap.clear();
    }
}