package com.iailab.module.model.matlab.service.impl;
|
|
|
import cn.hutool.cache.CacheUtil;
|
import cn.hutool.cache.impl.FIFOCache;
|
import cn.hutool.core.io.FileUtil;
|
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSONArray;
|
import com.baomidou.dynamic.datasource.annotation.DSTransactional;
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.baomidou.mybatisplus.core.metadata.IPage;
|
import com.iailab.framework.common.exception.enums.GlobalErrorCodeConstants;
|
import com.iailab.framework.common.page.PageData;
|
import com.iailab.framework.common.pojo.CommonResult;
|
import com.iailab.framework.common.service.impl.BaseServiceImpl;
|
import com.iailab.framework.common.util.object.ConvertUtils;
|
import com.iailab.framework.security.core.util.SecurityFrameworkUtils;
|
import com.iailab.framework.tenant.core.context.TenantContextHolder;
|
import com.iailab.module.model.common.utils.DateUtils;
|
import com.iailab.module.model.matlab.common.exceptions.IllegalityJarException;
|
import com.iailab.module.model.matlab.common.utils.MatlabUtils;
|
import com.iailab.module.model.matlab.dao.MlModelDao;
|
import com.iailab.module.model.matlab.dto.MatlabJarFileInfoDTO;
|
import com.iailab.module.model.matlab.dto.MatlabRunDTO;
|
import com.iailab.module.model.matlab.dto.MlModelDTO;
|
import com.iailab.module.model.matlab.entity.MlModelEntity;
|
import com.iailab.module.model.matlab.service.MlModelMethodService;
|
import com.iailab.module.model.matlab.service.MlModelService;
|
import com.iailab.module.model.mpk.common.MdkConstant;
|
import com.iailab.module.model.mpk.common.utils.Readtxt;
|
import com.mathworks.toolbox.javabuilder.MWStructArray;
|
import lombok.extern.slf4j.Slf4j;
|
import org.apache.commons.lang3.StringUtils;
|
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.stereotype.Service;
|
import org.springframework.util.CollectionUtils;
|
import org.springframework.web.multipart.MultipartFile;
|
|
import java.io.File;
|
import java.io.IOException;
|
import java.net.URLClassLoader;
|
import java.util.*;
|
|
import static com.iailab.framework.common.pojo.CommonResult.error;
|
import static com.iailab.framework.common.pojo.CommonResult.success;
|
|
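/**
 * Service implementation for MATLAB models: CRUD on {@code MlModelEntity},
 * upload and backup of the packaged model jar, importing input data into a
 * temporary cache, and test-running a model method from the backed-up jar.
 *
 * @author Dzd
 * @since 1.0.0 2025-02-08
 */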
|
@Slf4j
|
@Service
|
public class MlModelServiceImpl extends BaseServiceImpl<MlModelDao, MlModelEntity> implements MlModelService {
|
|
@Value("${mablab.bak-file-path}")
|
private String mablabBakFilePath;
|
|
@Autowired
|
private MlModelMethodService mlModelMethodService;
|
|
// 先进先出缓存 临时保存导入的数据
|
private static FIFOCache<String, String> cache = CacheUtil.newFIFOCache(100);
|
|
@Override
|
public PageData<MlModelDTO> page(Map<String, Object> params) {
|
IPage<MlModelEntity> page = baseDao.selectPage(
|
getPage(params, "create_date", false),
|
getWrapper(params)
|
);
|
|
return getPageData(page, MlModelDTO.class);
|
}
|
|
@Override
|
public List<MlModelDTO> list(Map<String, Object> params) {
|
List<MlModelDTO> list = baseDao.list(params);
|
|
return list;
|
}
|
|
private QueryWrapper<MlModelEntity> getWrapper(Map<String, Object> params) {
|
String modelName = (String) params.get("modelName");
|
String modelFileName = (String) params.get("modelFileName");
|
|
QueryWrapper<MlModelEntity> wrapper = new QueryWrapper<>();
|
wrapper.like(StringUtils.isNotBlank(modelName), "model_name", modelName)
|
.like(StringUtils.isNotBlank(modelFileName), "model_file_name", modelFileName);
|
return wrapper;
|
}
|
|
@Override
|
public MlModelDTO get(String id) {
|
MlModelDTO entity = baseDao.get(id);
|
|
return entity;
|
}
|
|
@Override
|
@DSTransactional(rollbackFor = Exception.class)
|
public void save(MlModelDTO dto) {
|
MlModelEntity entity = ConvertUtils.sourceToTarget(dto, MlModelEntity.class);
|
entity.setId(UUID.randomUUID().toString());
|
entity.setModelFilePath(dto.getModelFilePath().trim());
|
entity.setModelFileName(dto.getModelFileName().trim());
|
entity.setCreator(SecurityFrameworkUtils.getLoginUserId());
|
entity.setCreateDate(new Date());
|
insert(entity);
|
|
mlModelMethodService.insertList(dto.getModelMethods(), entity.getId());
|
|
saveJarFile(dto.getModelFilePath(),dto.getModelFileName());
|
}
|
|
@Override
|
@DSTransactional(rollbackFor = Exception.class)
|
public void update(MlModelDTO dto) {
|
MlModelEntity entity = ConvertUtils.sourceToTarget(dto, MlModelEntity.class);
|
entity.setUpdater(SecurityFrameworkUtils.getLoginUserId());
|
entity.setUpdateDate(new Date());
|
updateById(entity);
|
mlModelMethodService.deleteModelMethod(entity.getId());
|
mlModelMethodService.insertList(dto.getModelMethods(), entity.getId());
|
|
saveJarFile(dto.getModelFilePath(),dto.getModelFileName());
|
}
|
|
/**
|
* @description: 保存最新jar文件,用于测试运行
|
* @author: dzd
|
* @date: 2025/2/24 16:25
|
**/
|
private void saveJarFile(String modelFilePath,String modelFileName) {
|
String matlabTenantBakFilePath = getMatlabTenantBakFilePath();
|
|
String jarBakPath = matlabTenantBakFilePath + File.separator + MdkConstant.JAR + File.separator + modelFileName + ".jar";
|
FileUtil.copy(modelFilePath, jarBakPath, true);
|
}
|
|
@Override
|
@DSTransactional(rollbackFor = Exception.class)
|
public void delete(String id) {
|
|
//删除源文件
|
MlModelEntity MlModelEntity = selectById(id);
|
if (StringUtils.isNoneBlank(MlModelEntity.getModelFilePath())) {
|
File mpkFile = new File(MlModelEntity.getModelFilePath());
|
if (mpkFile.exists()) {
|
mpkFile.delete();
|
log.info("删除源文件备份文件:" + MlModelEntity.getModelFilePath());
|
}
|
}
|
//删除 会级联删除掉关联表
|
deleteById(id);
|
}
|
|
@Override
|
public MatlabJarFileInfoDTO uploadJarFile(MultipartFile file) throws IllegalityJarException {
|
String matlabTenantBakFilePath = getMatlabTenantBakFilePath();
|
File bakDir = new File(matlabTenantBakFilePath);
|
|
String jarName = null;
|
File jarBakFile = null;
|
try {
|
String fileName = file.getOriginalFilename();
|
jarName = fileName.substring(0, fileName.lastIndexOf("."));
|
String pyName_time = jarName + "_" + DateUtils.format(new Date(),DateUtils.DATE_TIME_STRING);
|
jarBakFile = new File(bakDir.getAbsolutePath() + File.separator + pyName_time + ".jar");
|
file.transferTo(jarBakFile);
|
} catch (IOException e) {
|
throw new RuntimeException("保存算法封装jar文件失败!");
|
}
|
|
//解析jar info
|
MatlabJarFileInfoDTO result = new MatlabJarFileInfoDTO();
|
result.setFilePath(jarBakFile.getAbsolutePath());
|
result.setFileName(jarName);
|
result.setClassInfos(MatlabUtils.parseJarInfo(jarBakFile.getAbsolutePath(),jarName));
|
return result;
|
}
|
|
@Override
|
public CommonResult<String> test(MatlabRunDTO dto) {
|
String matlabTenantBakFilePath = getMatlabTenantBakFilePath();
|
|
Class<?> clazz;
|
URLClassLoader classLoader;
|
try {
|
File jarFile = new File(matlabTenantBakFilePath + File.separator + MdkConstant.JAR + File.separator + dto.getModelFileName() + ".jar");
|
if (!jarFile.exists()) {
|
throw new RuntimeException("jar包不存在,请检查。jarPath:" + jarFile.getAbsolutePath());
|
}
|
// 加载jar包
|
classLoader = MatlabUtils.loadJar(null,jarFile.getAbsolutePath());
|
// 实现类
|
clazz = classLoader.loadClass(dto.getClassName());
|
} catch (Exception e) {
|
e.printStackTrace();
|
throw new RuntimeException("加载运行环境失败。");
|
}
|
|
try {
|
List<String> uuids = dto.getUuids();
|
|
Object[] paramsValueArray = new Object[uuids.size() + 1];
|
|
try {
|
for (int i = 0; i < uuids.size(); i++) {
|
String uuid = uuids.get(i);
|
if (!cache.containsKey(uuid)) {
|
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"请重新导入模型参数");
|
}
|
JSONArray jsonArray = JSON.parseArray(cache.get(uuid));
|
double[][] data = new double[jsonArray.size()][jsonArray.getJSONArray(0).size()];
|
for (int j = 0; j < jsonArray.size(); j++) {
|
for (int k = 0; k < jsonArray.getJSONArray(j).size(); k++) {
|
data[j][k] = jsonArray.getJSONArray(j).getDoubleValue(k);
|
}
|
}
|
paramsValueArray[i] = data;
|
}
|
} catch (Exception e) {
|
e.printStackTrace();
|
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型参数错误,请检查!");
|
}
|
|
MWStructArray mwStructArraySetting;
|
try {
|
|
HashMap<String, Object> settings = MatlabUtils.handleModelSettings(dto.getModelSettings());
|
mwStructArraySetting = MatlabUtils.convertMapToStruct(settings);
|
paramsValueArray[uuids.size()] = mwStructArraySetting;
|
} catch (Exception e) {
|
e.printStackTrace();
|
return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型设置错误,请检查!");
|
}
|
|
Class<?>[] paramsArray = new Class[2];
|
paramsArray[0] = int.class;
|
paramsArray[1] = Object[].class;
|
Object[] objects = (Object[]) clazz.getDeclaredMethod(dto.getMethodName(), paramsArray).invoke(clazz.newInstance(), new Object[]{dto.getOutLength(), paramsValueArray});
|
mwStructArraySetting.dispose();
|
Map<String, Object> result = MatlabUtils.convertStructToMap((MWStructArray) objects[0]);
|
return success(JSON.toJSONString(result));
|
} catch (Exception ex) {
|
ex.printStackTrace();
|
return error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"运行异常");
|
} finally {
|
if (classLoader != null) {
|
MatlabUtils.unloadJar(classLoader);
|
}
|
}
|
}
|
|
@Override
|
public List<HashMap<String, Object>> importData(MultipartFile file) throws IOException {
|
List<double[][]> datas = Readtxt.readMethodExcel(file);
|
List<HashMap<String,Object>> result = new ArrayList<>();
|
if (!CollectionUtils.isEmpty(datas)) {
|
for (double[][] data : datas) {
|
if (data.length > 0) {
|
HashMap<String,Object> map = new HashMap<>();
|
String uuid = UUID.randomUUID().toString();
|
map.put("uuid",uuid);
|
map.put("data", JSON.toJSONString(data));
|
cache.put(uuid,JSON.toJSONString(data));
|
result.add(map);
|
}
|
}
|
}
|
return result;
|
}
|
|
private String getMatlabTenantBakFilePath() {
|
Long tenantId = TenantContextHolder.getTenantId();
|
// 备份文件夹 租户隔离
|
String matlabTenantBakFilePath = mablabBakFilePath + File.separator + tenantId;
|
File bakDir = new File(matlabTenantBakFilePath);
|
if (!bakDir.exists()) {
|
bakDir.mkdirs();
|
}
|
return matlabTenantBakFilePath;
|
}
|
|
}
|