package com.iailab.module.model.mdk.factory;

import com.iailab.module.model.mcs.pre.entity.MmModelArithSettingsEntity;
import com.iailab.module.model.mcs.pre.entity.MmModelParamEntity;
import com.iailab.module.model.mcs.pre.entity.MmPredictModelEntity;
import com.iailab.module.model.mcs.pre.service.MmModelArithSettingsService;
import com.iailab.module.model.mcs.pre.service.MmModelParamService;
import com.iailab.module.model.mcs.pre.service.MmPredictModelService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Creates and manages model entities, caching per-model lookups made through the model services.
 *
 * <p>Cached values are never {@code null}: a {@code null} service result is returned to the
 * caller but not stored, so the next call queries the service again.
 *
 * <p>NOTE(review): the caches are plain {@link HashMap}s with no synchronization. If this
 * component is invoked from several threads concurrently, access is not thread-safe — confirm
 * how callers use it.
 */
@Slf4j
@Component
public class ModelEntityFactory {

    // NOTE(review): never populated in this class; only cleared by removeModelEntity().
    // Value type assumed — confirm against the original declaration.
    private final Map<String, String> modelFileMap = new HashMap<>();

    /** Model ID -> model entity (written by getModelEntity, not read by this class). */
    private final Map<String, MmPredictModelEntity> modelEntityMap = new HashMap<>();

    /** Model ID -> model input parameters. */
    private final Map<String, List<MmModelParamEntity>> modelInputParamMap = new HashMap<>();

    /** Model ID -> model algorithm settings. */
    private final Map<String, List<MmModelArithSettingsEntity>> modelArithParamMap = new HashMap<>();

    /** Predict item ID -> active models for that item. */
    private final Map<String, List<MmPredictModelEntity>> modelListMap = new HashMap<>();

    @Autowired
    private MmPredictModelService mmPredictModelService;

    @Autowired
    private MmModelParamService mmModelParamService;

    @Autowired
    private MmModelArithSettingsService mmModelArithSettingsService;

    /**
     * Returns the model entity for the given model ID.
     *
     * <p>The entity is always fetched from the service, so the caller gets current data. The
     * first non-null result for each ID is also recorded in {@code modelEntityMap}, but this
     * method never reads from that map.
     *
     * @param modelId model ID
     * @return the model entity, or {@code null} if the service returns none
     */
    public MmPredictModelEntity getModelEntity(String modelId) {
        MmPredictModelEntity modelEntity = mmPredictModelService.getInfo(modelId);
        if (modelEntity != null) {
            // Only the first result per ID is recorded; later calls leave the map unchanged.
            modelEntityMap.putIfAbsent(modelId, modelEntity);
        }
        return modelEntity;
    }

    /**
     * Returns the input parameters of the given model.
     *
     * <p>The parameters are always fetched from the service, so the caller gets current data.
     * The first non-null result for each ID is also cached for {@link #getModelInputCount}.
     *
     * @param modelId model ID
     * @return the model's input parameters, or {@code null} if the service returns none
     */
    public List<MmModelParamEntity> getModelInputParam(String modelId) {
        // Query once per call; the previous version hit the service twice on a cache miss.
        List<MmModelParamEntity> inputParams = mmModelParamService.getByModelid(modelId);
        if (inputParams != null) {
            modelInputParamMap.putIfAbsent(modelId, inputParams);
        }
        return inputParams;
    }

    /**
     * Returns the number of input parameters (the input dimension) of the given model.
     *
     * <p>Served from the input-parameter cache. On a miss, the service is queried and a non-null
     * result is cached.
     *
     * @param modelId model ID
     * @return the number of input parameters, or {@code 0} if the service returns none
     */
    public Integer getModelInputCount(String modelId) {
        // computeIfAbsent stores nothing when the service returns null, matching the old
        // containsKey/put logic.
        List<MmModelParamEntity> inputParams =
                modelInputParamMap.computeIfAbsent(modelId, mmModelParamService::getByModelid);
        return inputParams == null ? 0 : inputParams.size();
    }

    /**
     * Returns the algorithm settings of the given model.
     *
     * <p>Served from cache. On a miss, the service is queried and a non-null result is cached.
     * The cached list instance itself is returned.
     *
     * @param modelId model ID
     * @return the model's algorithm settings, or {@code null} if the service returns none
     */
    public List<MmModelArithSettingsEntity> getModelArithParam(String modelId) {
        return modelArithParamMap.computeIfAbsent(modelId, mmModelArithSettingsService::getByModelId);
    }

    /**
     * Returns the list of active models (per the original note, those with status=1) for a
     * predict item.
     *
     * <p>Served from cache. On a miss, the service is queried and a non-null result is cached.
     * The cached list instance itself is returned.
     *
     * @param itemId predict item ID
     * @return the item's active models, or {@code null} if the service returns none
     */
    public List<MmPredictModelEntity> getActiveModelByItemId(String itemId) {
        return modelListMap.computeIfAbsent(itemId, mmPredictModelService::getActiveModelByItemId);
    }

    /**
     * Evicts the cached input parameters of the given model. Logs only if an entry was present.
     *
     * @param modelId model ID
     */
    public void removeModelInputParam(String modelId) {
        // Cached values are never null, so a non-null result means an entry was actually removed.
        if (modelInputParamMap.remove(modelId) != null) {
            log.info("removeModelInputParam:modelId={}", modelId);
        }
    }

    /**
     * Clears all caches held by this factory.
     */
    public void removeModelEntity() {
        modelFileMap.clear();
        modelEntityMap.clear();
        modelInputParamMap.clear();
        modelArithParamMap.clear();
        modelListMap.clear();
    }
}