dengzedong
2024-11-22 0425a35c2378f9c024ce4d50ee13a896f1eee3ac
动态加载pyd
已修改4个文件
104 ■■■■ 文件已修改
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/common/MdkConstant.java 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/service/impl/MethodSettingServiceImpl.java 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/service/impl/MpkFileServiceImpl.java 86 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/resources/template/cpp.vm 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/common/MdkConstant.java
@@ -54,4 +54,8 @@
     * 编译pyd生成文件后缀
     */
    String PYD_SUFFIX = ".cp37-win_amd64.pyd";
    /**
     * 默认模型路径key
     */
    String PY_FILE_KEY = "pyFile";
}
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/service/impl/MethodSettingServiceImpl.java
@@ -2,6 +2,7 @@
import com.iailab.framework.common.service.impl.BaseServiceImpl;
import com.iailab.framework.common.util.object.ConvertUtils;
import com.iailab.module.model.mpk.common.MdkConstant;
import com.iailab.module.model.mpk.dao.MethodSettingDao;
import com.iailab.module.model.mpk.dto.MethodSettingDTO;
import com.iailab.module.model.mpk.entity.MethodSettingEntity;
@@ -30,6 +31,8 @@
    @Override
    public void insertList(List<MethodSettingDTO> list, String MethodId) {
        List<MethodSettingEntity> entityList = ConvertUtils.sourceToTarget(list, MethodSettingEntity.class);
        // pyFile排第一
        entityList.sort((e1,e2) -> e1.getSettingKey().equals(MdkConstant.PY_FILE_KEY) ? -1 : e2.getSettingKey().equals(MdkConstant.PY_FILE_KEY) ? 1 : 0);
        for(int i = 0; i < entityList.size(); i++){
            MethodSettingEntity entity = entityList.get(i);
            if (StringUtils.isNotBlank(entity.getValue())){
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/service/impl/MpkFileServiceImpl.java
@@ -20,6 +20,7 @@
import com.iailab.module.model.mpk.common.utils.GenUtils;
import com.iailab.module.model.mpk.dao.MpkFileDao;
import com.iailab.module.model.mpk.dto.GeneratorCodeHistoryDTO;
import com.iailab.module.model.mpk.dto.MethodSettingDTO;
import com.iailab.module.model.mpk.dto.ModelMethodDTO;
import com.iailab.module.model.mpk.dto.MpkFileDTO;
import com.iailab.module.model.mpk.entity.GeneratorCodeHistoryEntity;
@@ -128,30 +129,82 @@
        entity.setCreator(SecurityFrameworkUtils.getLoginUserId());
        entity.setCreateDate(new Date());
        insert(entity);
        modelMethodService.insertList(dto.getModelMethods(), entity.getId());
        // 替换环境变量MDK_PKGS下的py文件
//        String mdkPkgs = System.getenv("MDK_PKGS");
//        String pyFilePath = mdkPkgs + File.separator + entity.getPyModule().replace(".", File.separator) + File.separator + entity.getPyName() + ".pyd";
//        FileUtil.mkParentDirs(pyFilePath);
//        FileUtil.copy(entity.getFilePath(), pyFilePath, true);
        // 将备份的pyd文件,转移到MDK_PKGS环境变量下,并添加方法的默认参数(pyFile,模型路径)
        String mdkPkgs = System.getenv("MDK_PKGS");
        String fileName = entity.getFilePath().substring(entity.getFilePath().lastIndexOf("\\") + 1,entity.getFilePath().lastIndexOf(".pyd"));
        String pyFilePath = mdkPkgs + File.separator + entity.getPyModule().replace(".", File.separator) + File.separator + fileName + ".pyd";
        FileUtil.mkParentDirs(pyFilePath);
        FileUtil.copy(entity.getFilePath(), pyFilePath, true);
        // 添加参数
        for (ModelMethodDTO method : dto.getModelMethods()) {
            List<MethodSettingDTO> methodSettings = method.getMethodSettings();
            if (methodSettings.stream().anyMatch(e -> e.getSettingKey().equals(MdkConstant.PY_FILE_KEY))) {
                methodSettings.forEach(e -> {
                    if (e.getSettingKey().equals(MdkConstant.PY_FILE_KEY)) {
                        e.setValue(entity.getPyModule() + "." + fileName);
                    }
                });
            }else {
                MethodSettingDTO setting = new MethodSettingDTO();
                setting.setId(UUID.randomUUID().toString());
                setting.setMethodId(method.getId());
                setting.setSettingKey(MdkConstant.PY_FILE_KEY);
                setting.setValue(entity.getPyModule() + "." + fileName);
                setting.setName("模型路径");
                setting.setType("input");
                setting.setValueType("string");
                methodSettings.add(setting);
            }
        }
        modelMethodService.insertList(dto.getModelMethods(), entity.getId());
    }
    @Override
    @DSTransactional(rollbackFor = Exception.class)
    public void update(MpkFileDTO dto) {
        // 判断py文件是否修改
        MpkFileEntity oldEntity = selectById(dto.getId());
        if (!oldEntity.getFilePath().equals(dto.getFilePath())) {
            // 将备份的pyd文件,转移到MDK_PKGS环境变量下,并添加方法的默认参数(pyFile,模型路径)
            String mdkPkgs = System.getenv("MDK_PKGS");
            String fileName = dto.getFilePath().substring(dto.getFilePath().lastIndexOf("\\") + 1,dto.getFilePath().lastIndexOf(".pyd"));
            String pyFilePath = mdkPkgs + File.separator + dto.getPyModule().replace(".", File.separator) + File.separator + fileName + ".pyd";
            FileUtil.mkParentDirs(pyFilePath);
            FileUtil.copy(dto.getFilePath(), pyFilePath, true);
            // 添加/修改参数
            for (ModelMethodDTO method : dto.getModelMethods()) {
                List<MethodSettingDTO> methodSettings = method.getMethodSettings();
                if (methodSettings.stream().anyMatch(e -> e.getSettingKey().equals(MdkConstant.PY_FILE_KEY))) {
                    methodSettings.forEach(e -> {
                        if (e.getSettingKey().equals(MdkConstant.PY_FILE_KEY)) {
                            e.setValue(dto.getPyModule() + "." + fileName);
                        }
                    });
                }else {
                    MethodSettingDTO setting = new MethodSettingDTO();
                    setting.setId(UUID.randomUUID().toString());
                    setting.setMethodId(method.getId());
                    setting.setSettingKey(MdkConstant.PY_FILE_KEY);
                    setting.setValue(dto.getPyModule() + "." + fileName);
                    setting.setName("模型路径");
                    setting.setType("input");
                    setting.setValueType("string");
                    methodSettings.add(setting);
                }
            }
        }
        MpkFileEntity entity = ConvertUtils.sourceToTarget(dto, MpkFileEntity.class);
        entity.setUpdater(SecurityFrameworkUtils.getLoginUserId());
        entity.setUpdateDate(new Date());
        updateById(entity);
        modelMethodService.deleteModelMethod(entity.getId());
        modelMethodService.insertList(dto.getModelMethods(), entity.getId());
        // 替换环境变量MDK_PKGS下的py文件
//        String mdkPkgs = System.getenv("MDK_PKGS");
//        String pyFilePath = mdkPkgs + File.separator + entity.getPyModule().replace(".", File.separator) + File.separator + entity.getPyName() + ".pyd";
//        FileUtil.mkParentDirs(pyFilePath);
//        FileUtil.copy(entity.getFilePath(), pyFilePath, true);
    }
    @Override
@@ -509,9 +562,10 @@
            throw new RuntimeException("创建临时文件夹异常",e);
        }
        String fileName = file.getOriginalFilename();
        File pydBakFile = new File(bakDir.getAbsolutePath() + File.separator + UUID.randomUUID() + ".pyd");
        String pyName = fileName.substring(0, fileName.lastIndexOf("."));
        String pyName_uuTime = pyName + "_" + new Date().getTime();
        File pydBakFile = new File(bakDir.getAbsolutePath() + File.separator + pyName_uuTime + ".pyd");
        try {
            // py文件存入临时文件夹
            File saveFile = new File(tempDir.getAbsolutePath() + File.separator + fileName);
@@ -520,12 +574,12 @@
            // 临时文件夹中生成pyd文件
            // 生成setup.py文件
            createSetUpFile(tempDir.getAbsolutePath(),fileName,pyName);
            createSetUpFile(tempDir.getAbsolutePath(),fileName,pyName_uuTime);
            // 执行cmd命令编译pyd文件
            createPydFile(tempDir.getAbsolutePath());
            // 临时文件夹中pyd文件移到dir下,修改为uuid.pyd
            File pydFile = new File(tempDir.getAbsolutePath() + File.separator + pyName + MdkConstant.PYD_SUFFIX);
            File pydFile = new File(tempDir.getAbsolutePath() + File.separator + pyName_uuTime + MdkConstant.PYD_SUFFIX);
            if (!pydFile.exists()) {
                throw new RuntimeException("编译pyd文件失败!");
            }
iailab-module-model/iailab-module-model-biz/src/main/resources/template/cpp.vm
@@ -17,7 +17,16 @@
    {
        PyGILThreadLock lock;
        PyObject* pModule = create_py_module("${pyModule}.${pyName}");
        jclass hashmapClass = env->FindClass("java/util/HashMap");
        jmethodID getMID = env->GetMethodID(hashmapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
        jstring keyJString = env->NewStringUTF("pyFile");
        jobject javaValueObj = env->CallObjectMethod(settings, getMID, keyJString);
        const char* strValue = env->GetStringUTFChars((jstring)javaValueObj, NULL);
        cout << strValue << endl;
        PyObject* pModule = create_py_module(strValue);
        /*PyObject* pModule = create_py_module("${pyModule}.${pyName}");*/
        if (pModule == NULL)
        {
            cout << "model error" << endl;