dengzedong
2024-10-14 3e18d4bfbf2c657b08b21512c2d884cc9d59df7b
提交 | 用户 | 时间
449017 1 package com.iailab.module.model.mpk.controller.admin;
D 2
3 import com.alibaba.fastjson.JSON;
4 import com.alibaba.fastjson.JSONArray;
3e18d4 5 import com.iail.utils.RSAUtils;
449017 6 import com.iailab.framework.common.exception.enums.GlobalErrorCodeConstants;
D 7 import com.iailab.framework.common.pojo.CommonResult;
558ffc 8 import com.iailab.framework.tenant.core.context.TenantContextHolder;
1fea3e 9 import com.iailab.module.model.mpk.common.MdkConstant;
912a1e 10 import com.iailab.module.model.mpk.common.utils.DllUtils;
449017 11 import com.iailab.module.model.mpk.common.utils.Readtxt;
D 12 import com.iailab.module.model.mpk.dto.MdkDTO;
3e18d4 13 import com.iailab.module.model.mpk.dto.MdkRunDTO;
449017 14 import io.swagger.v3.oas.annotations.Operation;
912a1e 15 import lombok.extern.slf4j.Slf4j;
1fea3e 16 import org.springframework.beans.factory.annotation.Value;
449017 17 import org.springframework.util.CollectionUtils;
D 18 import org.springframework.web.bind.annotation.*;
19 import org.springframework.web.multipart.MultipartFile;
20
1fea3e 21 import java.io.File;
3e18d4 22 import java.lang.reflect.Method;
1fea3e 23 import java.net.URLClassLoader;
D 24 import java.util.ArrayList;
25 import java.util.HashMap;
26 import java.util.List;
449017 27
D 28 import static com.iailab.framework.common.pojo.CommonResult.error;
29 import static com.iailab.framework.common.pojo.CommonResult.success;
30
31 /**
32  * @author PanZhibao
33  * @Description
34  * @createTime 2024年08月08日
35  */
36 @RestController
912a1e 37 @Slf4j
449017 38 @RequestMapping("/model/mpk/api")
D 39 public class MdkController {
1fea3e 40     @Value("${mpk.bak-file-path}")
D 41     private String mpkBakFilePath;
42
43
3e18d4 44     /**
D 45      * @description: 模型测试运行
46      * @author: dzd
47      * @date: 2024/10/14 15:26
48      **/
49     @PostMapping("test")
50     public CommonResult<String> test(@RequestBody MdkDTO dto) {
558ffc 51         Long tenantId = TenantContextHolder.getTenantId();
D 52         // 备份文件 租户隔离
53         String mpkTenantBakFilePath = mpkBakFilePath + File.separator + tenantId;
54
912a1e 55         Class<?> clazz;
D 56         URLClassLoader classLoader;
1fea3e 57         try {
558ffc 58             File jarFile = new File(mpkTenantBakFilePath + File.separator + MdkConstant.JAR + File.separator + dto.getPyName() + ".jar");
1fea3e 59             if (!jarFile.exists()) {
D 60                 throw new RuntimeException("jar包不存在,请先生成代码。jarPath:" + jarFile.getAbsolutePath());
61             }
558ffc 62             File dllFile = new File(mpkTenantBakFilePath + File.separator + MdkConstant.DLL + File.separator + dto.getPyName() + ".dll");
1fea3e 63             if (!dllFile.exists()) {
912a1e 64                 throw new RuntimeException("dll文件不存在,请先生成代码。dllPath:" + dllFile.getAbsolutePath());
1fea3e 65             }
D 66             // 加载jar包
912a1e 67             classLoader = DllUtils.loadJar(jarFile.getAbsolutePath());
D 68             // 实现类
69             clazz = classLoader.loadClass(dto.getClassName());
70             // 加载dll到实现类
71             DllUtils.loadDll(clazz,dllFile.getAbsolutePath());
1fea3e 72         } catch (Exception e) {
912a1e 73             e.printStackTrace();
1fea3e 74             throw new RuntimeException("加载运行环境失败。");
D 75         }
76
449017 77         System.out.println("runTime=" + System.currentTimeMillis());
D 78         try {
79             List<String> datas = dto.getDatas();
80
81             int paramLength = dto.getHasModel() ? datas.size() + 2 : datas.size() + 1;
82             Object[] paramsValueArray = new Object[paramLength];
83             Class<?>[] paramsArray = new Class[paramLength];
84
85             try {
86                 for (int i = 0; i < datas.size(); i++) {
87                     String json = datas.get(i);
88                     JSONArray jsonArray = JSON.parseArray(json);
89                     double[][] data = new double[jsonArray.size()][jsonArray.getJSONArray(0).size()];
90                     for (int j = 0; j < jsonArray.size(); j++) {
91                         for (int k = 0; k < jsonArray.getJSONArray(j).size(); k++) {
92                             data[j][k] = jsonArray.getJSONArray(j).getDoubleValue(k);
93                         }
94                     }
95                     paramsValueArray[i] = data;
96                     paramsArray[i] = double[][].class;
97                 }
98             } catch (Exception e) {
99                 e.printStackTrace();
100                 return error(GlobalErrorCodeConstants.BAD_REQUEST.getCode(),"参数错误,请检查!");
101             }
102
103             if (dto.getHasModel()) {
104                 paramsValueArray[datas.size()] = dto.getModel();
105                 paramsValueArray[datas.size() + 1] = dto.getModelSettings();
106                 paramsArray[datas.size()] = HashMap.class;
107                 paramsArray[datas.size() + 1] = HashMap.class;
108             }else {
109                 paramsValueArray[datas.size()] = dto.getModelSettings();
110                 paramsArray[datas.size()] = HashMap.class;
111             }
112
912a1e 113             HashMap result = (HashMap) clazz.getDeclaredMethod(dto.getMethodName(), paramsArray).invoke(clazz.newInstance(), paramsValueArray);
449017 114             return success(JSON.toJSONString(result));
D 115         } catch (Exception ex) {
116             ex.printStackTrace();
117             return error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"运行异常");
118         } finally {
912a1e 119             if (classLoader != null) {
c12dae 120                 DllUtils.unloadDll(classLoader);
D 121                 DllUtils.unloadJar(classLoader);
912a1e 122             }
449017 123             System.gc();
D 124         }
125     }
126
3e18d4 127     /**
D 128      * @description: 模型运行
129      * @author: dzd
130      * @date: 2024/10/14 15:26
131      **/
132     @PostMapping("run")
133     public CommonResult<String> run(@RequestBody MdkRunDTO dto) {
134         if (RSAUtils.checkLisenceBean().getCode() != 1) {
135             log.error("Lisence 不可用!");
136             return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"Lisence 不可用!");
137         } else {
138             try {
139                 URLClassLoader classLoader = DllUtils.getClassLoader(dto.getProjectId());
140                 if (null == classLoader) {
141                     return CommonResult.error(GlobalErrorCodeConstants.ERROR_CONFIGURATION.getCode(),"请先发布项目!");
142                 }
143
144                 Class<?> clazz = classLoader.loadClass(dto.getClassName());
145                 Method method = clazz.getMethod(dto.getMethodName(), dto.getParamsClassArray());
146                 HashMap invoke = (HashMap) method.invoke(clazz.newInstance(), dto.getParamsValueArray());
147
148                 // todo 将结果存入数据库
149
150             } catch (Exception e) {
151                 log.error("模型运行失败",e);
152                 return CommonResult.error(GlobalErrorCodeConstants.INTERNAL_SERVER_ERROR.getCode(),"模型运行失败!");
153             }
154         }
155         return CommonResult.success();
156     }
157
449017 158     @PostMapping("/import")
D 159     @Operation(summary = "导入参数")
160     public CommonResult<List<String>> importExcel(@RequestParam("file") MultipartFile file) throws Exception {
161         List<double[][]> datas = Readtxt.readMethodExcel(file);
162         List<String> result = new ArrayList<>();
163         if (!CollectionUtils.isEmpty(datas)) {
164             for (double[][] data : datas) {
165                 if (data.length > 0) {
166                     result.add(JSON.toJSONString(data));
167                 }
168             }
169         }
170         return success(result);
1fea3e 171     }
449017 172 }