dengzedong
2025-02-21 f7e4a8c81cb019d9aef5ff55ddedf8083943ca8b
iailab-module-model/iailab-module-model-biz/src/main/java/com/iailab/module/model/mpk/controller/admin/MdkController.java
@@ -2,8 +2,14 @@
import cn.hutool.cache.CacheUtil;
import cn.hutool.cache.impl.FIFOCache;
import cn.hutool.core.io.FileUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.iail.IAILMDK;
import com.iail.bean.FieldSet;
import com.iail.bean.Property;
import com.iail.bean.SelectItem;
import com.iail.model.IAILModel;
import com.iail.utils.RSAUtils;
import com.iailab.framework.common.exception.enums.GlobalErrorCodeConstants;
import com.iailab.framework.common.pojo.CommonResult;
@@ -16,18 +22,21 @@
import com.iailab.module.model.mpk.dto.MethodSettingDTO;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.net.URLEncoder;
import java.nio.file.Files;
import java.util.*;
import java.util.stream.Collectors;
import static com.iailab.framework.common.pojo.CommonResult.error;
import static com.iailab.framework.common.pojo.CommonResult.success;
@@ -44,7 +53,7 @@
    /**
     * Value of the {@code mpk.bak-file-path} configuration property.
     * NOTE(review): presumably the location where MPK backup files are written — confirm against its usages.
     */
    @Value("${mpk.bak-file-path}")
    private String mpkBakFilePath;
    // First-in-first-out cache (capacity 100) that temporarily holds imported data;
    // once full, the oldest entry is evicted first.
    private static FIFOCache<String, String> cache = CacheUtil.newFIFOCache(100);
    /**
@@ -110,6 +119,10 @@
            }
            try {
                if (dto.getModelSettings().stream().noneMatch(e -> e.getSettingKey().equals(MdkConstant.PY_FILE_KEY))) {
                    return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"模型设置参数缺少必要信息【" + MdkConstant.PY_FILE_KEY +  "】,请重新上传模型!");
                }
                if (dto.getHasModel()) {
                    paramsValueArray[uuids.size()] = dto.getModel();
                    paramsValueArray[uuids.size() + 1] = handleModelSettings(dto.getModelSettings());
@@ -135,6 +148,113 @@
                DllUtils.unloadJar(classLoader);
            }
            System.gc();
        }
    }
    private IAILModel createModelBean(MdkDTO dto) {
        IAILModel modelBean = new IAILModel();
        //ParamPathList
        List<String> paramPathList = new ArrayList<>();
        List<String> paramNameList = new ArrayList<>();
        for (Map.Entry<String, Object> entry : dto.getModel().entrySet()) {
            paramNameList.add(entry.getKey());
            paramPathList.add(entry.getValue().toString());
        }
        modelBean.setParamNameList(paramNameList);
        modelBean.setParamPathList(paramPathList);
        //ClassName MethodName
        modelBean.setClassName(dto.getClassName());
        modelBean.setMethodName(dto.getMethodName());
        //ParamsArray
        int paramLength = dto.getHasModel() ? dto.getDataLength() + 2 : dto.getDataLength() + 1;
        Class<?>[] paramsArray = new Class[paramLength];
        for (int i = 0; i < dto.getDataLength(); i++) {
            paramsArray[i] = double[][].class;
        }
        if (dto.getHasModel()) {
            paramsArray[dto.getDataLength()] = HashMap.class;
            paramsArray[dto.getDataLength() + 1] = HashMap.class;
        }else {
            paramsArray[dto.getDataLength()] = HashMap.class;
        }
        modelBean.setParamsArray(paramsArray);
        //LoadFieldSetList
        List<FieldSet> loadFieldSetList = new ArrayList<>();
        FieldSet fieldSet = new FieldSet();
        fieldSet.setFieldName("");
        List<Property> propertyList = new ArrayList<>();
        for (MethodSettingDTO modelSetting : dto.getPredModelSettings()) {
            Property property = new Property();
            property.setKey(modelSetting.getSettingKey());
            property.setName(modelSetting.getName());
            property.setType(modelSetting.getType());
            property.setValueType(modelSetting.getValueType());
            property.setMin(modelSetting.getMin() == null ? "" : modelSetting.getMin().toString());
            property.setMax(modelSetting.getMax() == null ? "" : modelSetting.getMax().toString());
            property.setSelectItemList(CollectionUtils.isEmpty(modelSetting.getSettingSelects()) ? null : modelSetting.getSettingSelects().stream().map(e -> new SelectItem(e.getSelectKey(),e.getName())).collect(Collectors.toList()));
            property.setValue(modelSetting.getValue());
            property.setFlow(false);
            propertyList.add(property);
        }
        fieldSet.setPropertyList(propertyList);
        loadFieldSetList.add(fieldSet);
        modelBean.setLoadFieldSetList(loadFieldSetList);
        //SettingConfigMap
        Map<String, Object> settingConfigMap = new HashMap<String, Object>();
        List<com.iail.bean.Value> settingKeyList = new ArrayList<com.iail.bean.Value>();
        Map<String, Object> settingMap = new HashMap<String, Object>();
        for (MethodSettingDTO modelSetting : dto.getModelSettings()) {
            settingKeyList.add(new com.iail.bean.Value(modelSetting.getSettingKey(),modelSetting.getSettingKey()));
            settingConfigMap.put("settingKeyList", settingKeyList);
            settingConfigMap.put("settingMap", handleModelSettings(dto.getModelSettings()));
        }
        modelBean.setSettingConfigMap(settingConfigMap);
        //DataMap
        modelBean.setDataMap(dto.getModelResult());
        //ResultKey
        modelBean.setResultKey(dto.getResultKey());
        //ResultKey
        modelBean.setVersion("1.0.0");
        return modelBean;
    }
    @PostMapping("saveModel")
    public void saveModel(@RequestBody MdkDTO dto, HttpServletResponse response) {
        IAILModel modelBean = createModelBean(dto);
        try {
            //临时文件夹
            File tempFile = null;
            try {
                tempFile = Files.createTempFile(dto.getPyName(),".miail").toFile();
                log.info("生成临时文件," + tempFile.getAbsolutePath());
            } catch (IOException e) {
                throw new RuntimeException("创建临时文件异常",e);
            }
            try {
                IAILMDK.saveModel(tempFile, modelBean);
            } catch (Exception e) {
                throw new RuntimeException("IAILMDK.saveModel异常",e);
            }
            byte[] data = FileUtil.readBytes(tempFile);
            response.reset();
            response.setHeader("Content-Disposition", "attachment; filename=\"" + URLEncoder.encode(tempFile.getName(), "UTF-8") + "\"");
            response.addHeader("Content-Length", "" + data.length);
            response.setContentType("application/octet-stream; charset=UTF-8");
            IOUtils.write(data, response.getOutputStream());
        } catch (Exception e) {
            throw new RuntimeException("代码生成异常",e);
        }
    }
@@ -197,7 +317,7 @@
                return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"模型运行失败!");
            }
        }
        return CommonResult.success();
        return CommonResult.success("");
    }
    @PostMapping("/import")